better output for logging
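Swaps the dbg!() dumps and the indicatif progress bar in the training loop for plain println!/eprintln! calls: the periodic score averages now go to stdout and the iteration markers to stderr, so the numbers can be redirected or piped on their own. As a rough sketch only (the episode count and score bookkeeping below are illustrative, not code from this commit), the logging pattern in train_actor now looks like:

    fn main() {
        let episodes = 1_000usize;
        let mut avg = 0.0;
        for i in 0..episodes {
            if i != 0 && i % (episodes / 10) == 0 {
                println!("{}", avg);          // average score on stdout, easy to redirect
                eprintln!("iteration {}", i); // progress marker on stderr
                avg = 0.0;
            }
            // ...play one training episode here and fold its score into `avg`...
        }
    }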

Edward Shen 2020-04-21 13:41:10 -04:00
parent 6e8099c57f
commit c95af24390
Signed by: edward
GPG key ID: 19182661E818369F
3 changed files with 40 additions and 26 deletions

@@ -290,8 +290,10 @@ pub async fn train_actor(opts: &Train) -> Box<dyn Actor> {
             .collect::<Vec<_>>();
         random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1));
         let best_two = random_selection.iter().rev().take(2).collect::<Vec<_>>();
-        let parent1 = dbg!(best_two[0]);
-        let parent2 = dbg!(best_two[1]);
+        let parent1 = &best_two[0];
+        println!("{:?}", &best_two[0]);
+        let parent2 = &best_two[1];
+        println!("{:?}", &best_two[1]);
         for _ in 0..new_pop_size / 3 {
             let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
             let mut cloned = breeded.clone();

@@ -5,8 +5,7 @@ use crate::{
     game::{Action, Controllable, Game, Tickable},
     playfield::{Matrix, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
 };
-use indicatif::ProgressIterator;
-use log::{debug, info, trace};
+use log::{debug, error, info, trace};
 use rand::rngs::SmallRng;
 use rand::seq::SliceRandom;
 use rand::Rng;
@@ -28,6 +27,7 @@ pub trait QLearningActor: Actor {
     fn set_discount_rate(&mut self, discount_rate: f64);
 }
 
+#[derive(Debug)]
 pub struct QLearningAgent {
     pub learning_rate: f64,
     pub exploration_prob: f64,
@@ -150,6 +150,7 @@ impl QLearningActor for QLearningAgent {
     }
 }
 
+#[derive(Debug)]
 pub struct ApproximateQLearning {
     pub learning_rate: f64,
     pub exploration_prob: f64,
@@ -183,7 +184,7 @@ enum Feature {
 
 impl ApproximateQLearning {
     fn get_features(&self, game: &Game, action: &Action) -> HashMap<Feature, f64> {
-        // let game = game.get_next_state(*action);
+        let game = game.get_next_state(*action);
         let mut features = HashMap::default();
 
         let field = game.playfield().field();
@@ -255,6 +256,8 @@ impl ApproximateQLearning {
             .map(|action| (action, self.get_q_value(game, action)))
             .collect::<Vec<_>>();
 
+        // dbg!(&legal_actions);
+
         let max_val = legal_actions
             .iter()
             .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
@@ -273,10 +276,16 @@ impl ApproximateQLearning {
             );
         }
 
-        *actions_to_choose
-            .choose(&mut SmallRng::from_entropy())
-            .unwrap()
-            .0
+        let action = actions_to_choose.choose(&mut SmallRng::from_entropy());
+
+        match action {
+            Some(a) => *a.0,
+            None => {
+                dbg!(&legal_actions);
+                dbg!(&actions_to_choose);
+                panic!("wtf???");
+            }
+        }
     }
 
     fn get_value(&self, game: &Game) -> f64 {
@@ -325,7 +334,7 @@ impl QLearningActor for ApproximateQLearning {
         game_state: Game,
         action: Action,
         next_game_state: Game,
-        next_legal_actions: &[Action],
+        _: &[Action],
         reward: f64,
     ) {
         let difference = reward + self.discount_rate * self.get_value(&next_game_state)
@@ -352,7 +361,7 @@ impl QLearningActor for ApproximateQLearning {
     }
 }
 
-pub fn train_actor<T: 'static + QLearningActor + Actor>(
+pub fn train_actor<T: std::fmt::Debug + 'static + QLearningActor + Actor>(
     mut actor: T,
     opts: &Train,
 ) -> Box<dyn Actor> {
@@ -368,10 +377,11 @@ pub fn train_actor<T: 'static + QLearningActor + Actor>(
         opts.learning_rate, opts.discount_rate, opts.exploration_prob
     );
 
-    for i in (0..episodes).progress() {
+    for i in 0..episodes {
         if i != 0 && i % (episodes / 10) == 0 {
-            info!("Last {} scores avg: {}", (episodes / 10), avg);
-            println!();
+            println!("{}", avg);
+            eprintln!("iteration {}", i);
+            // println!("{:?}", &actor);
             avg = 0.0;
         }
         let mut game = Game::default();

@@ -54,16 +54,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
     let mut rng = rand::rngs::SmallRng::from_entropy();
     let sdl_context = sdl2::init()?;
-    let video_subsystem = sdl_context.video()?;
-    let window = video_subsystem
-        .window("retris", 800, 800)
-        .position_centered()
-        .build()?;
-    let mut canvas = window.into_canvas().build()?;
+    // let video_subsystem = sdl_context.video()?;
+    // let window = video_subsystem
+    //     .window("retris", 800, 800)
+    //     .position_centered()
+    //     .build()?;
+    // let mut canvas = window.into_canvas().build()?;
     let mut event_pump = sdl_context.event_pump()?;
 
     let mut interval = interval(Duration::from_millis(1000 / TICKS_PER_SECOND as u64));
-    'escape: loop {
+    'escape: for _ in 0..10 {
         let mut game = Game::default();
 
         loop {
@@ -89,7 +89,7 @@ async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std:
                     ..
                 } => {
                     debug!("Escape registered");
-                    break 'escape Ok(());
+                    break 'escape;
                 }
                 Event::KeyDown {
                     keycode: Some(Keycode::Left),
@@ -175,13 +175,15 @@ async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std:
             });
 
            game.tick();
-            canvas.set_draw_color(COLOR_BACKGROUND);
-            canvas.clear();
-            standard_renderer::render(&mut canvas, &game);
-            canvas.present();
+            // canvas.set_draw_color(COLOR_BACKGROUND);
+            // canvas.clear();
+            // standard_renderer::render(&mut canvas, &game);
+            // canvas.present();
             interval.tick().await;
         }
 
         info!("Final score: {}", game.score());
     }
+
+    Ok(())
 }