244 lines
7.7 KiB
Rust
244 lines
7.7 KiB
Rust
use actors::*;
|
|
use clap::Clap;
|
|
use cli::*;
|
|
use game::{Action, Controllable, Game, Tickable};
|
|
use graphics::standard_renderer;
|
|
use graphics::COLOR_BACKGROUND;
|
|
use indicatif::ProgressIterator;
|
|
use log::{debug, info, trace};
|
|
use rand::SeedableRng;
|
|
use sdl2::event::Event;
|
|
use sdl2::keyboard::Keycode;
|
|
use std::time::Duration;
|
|
use tokio::time::interval;
|
|
|
|
mod actors;
|
|
mod cli;
|
|
mod game;
|
|
mod graphics;
|
|
mod playfield;
|
|
mod random;
|
|
mod srs;
|
|
mod tetromino;
|
|
|
|
/// Target rate of the interactive `play_game` loop. Each loop iteration
/// advances the game by one `tick()` and renders one frame, so this is both
/// the tick rate and the frame rate.
const TICKS_PER_SECOND: usize = 60;
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
let opts = crate::cli::Opts::parse();
|
|
|
|
init_verbosity(&opts)?;
|
|
let mut actor = None;
|
|
|
|
match opts.subcmd {
|
|
SubCommand::Play(sub_opts) => {}
|
|
SubCommand::Train(sub_opts) => {
|
|
let mut to_train = get_actor(sub_opts.agent);
|
|
to_train.set_learning_rate(sub_opts.learning_rate);
|
|
to_train.set_discount_rate(sub_opts.discount_rate);
|
|
to_train.set_exploration_prob(sub_opts.exploration_prob);
|
|
|
|
info!(
|
|
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
|
|
sub_opts.learning_rate,
|
|
sub_opts.discount_rate,
|
|
sub_opts.exploration_prob
|
|
);
|
|
let mut trained_actor = train_actor(sub_opts.episodes, to_train);
|
|
if sub_opts.no_explore_during_evaluation {
|
|
trained_actor.set_exploration_prob(0.0);
|
|
}
|
|
|
|
if sub_opts.no_learn_during_evaluation {
|
|
trained_actor.set_learning_rate(0.0);
|
|
}
|
|
|
|
actor = Some(trained_actor);
|
|
}
|
|
}
|
|
|
|
play_game(actor).await?;
|
|
Ok(())
|
|
}
|
|
|
|
fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
|
|
let mut rng = rand::rngs::SmallRng::from_entropy();
|
|
let mut avg = 0.0;
|
|
|
|
for i in (0..episodes).progress() {
|
|
if i != 0 && i % (episodes / 10) == 0 {
|
|
info!("Last {} scores avg: {}", (episodes / 10), avg);
|
|
println!();
|
|
avg = 0.0;
|
|
}
|
|
let mut game = Game::default();
|
|
while (&game).is_game_over().is_none() {
|
|
let cur_state = (&game).into();
|
|
let cur_score = game.score();
|
|
let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
|
|
|
|
match action {
|
|
Action::Nothing => (),
|
|
Action::MoveLeft => game.move_left(),
|
|
Action::MoveRight => game.move_right(),
|
|
Action::SoftDrop => game.move_down(),
|
|
Action::HardDrop => game.hard_drop(),
|
|
Action::Hold => game.hold(),
|
|
Action::RotateLeft => game.rotate_left(),
|
|
Action::RotateRight => game.rotate_right(),
|
|
}
|
|
|
|
let new_state = (&game).into();
|
|
let mut reward = game.score() as f64 - cur_score as f64;
|
|
if action != Action::Nothing {
|
|
reward -= 10.0;
|
|
}
|
|
|
|
if game.is_game_over().is_some() {
|
|
reward = -100.0;
|
|
}
|
|
|
|
actor.update(cur_state, action, new_state, reward);
|
|
|
|
game.tick();
|
|
}
|
|
|
|
avg += game.score() as f64 / (episodes / 10) as f64;
|
|
}
|
|
|
|
actor
|
|
}
|
|
|
|
async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
|
|
let mut rng = rand::rngs::SmallRng::from_entropy();
|
|
let sdl_context = sdl2::init()?;
|
|
let video_subsystem = sdl_context.video()?;
|
|
let window = video_subsystem
|
|
.window("retris", 800, 800)
|
|
.position_centered()
|
|
.build()?;
|
|
let mut canvas = window.into_canvas().build()?;
|
|
let mut event_pump = sdl_context.event_pump()?;
|
|
let mut game = Game::default();
|
|
let mut interval = interval(Duration::from_millis(1000 / TICKS_PER_SECOND as u64));
|
|
|
|
'running: loop {
|
|
match game.is_game_over() {
|
|
Some(e) => {
|
|
println!("Lost due to: {:?}", e);
|
|
break;
|
|
}
|
|
None => (),
|
|
}
|
|
|
|
let cur_state = (&game).into();
|
|
|
|
// If there's an actor, the player action will get overridden. If not,
|
|
// then then the player action falls through, if there is one. This is
|
|
// to allow for restarting and quitting the game from the GUI.
|
|
let mut action = None;
|
|
for event in event_pump.poll_iter() {
|
|
match event {
|
|
Event::Quit { .. }
|
|
| Event::KeyDown {
|
|
keycode: Some(Keycode::Escape),
|
|
..
|
|
} => {
|
|
debug!("Escape registered");
|
|
break 'running;
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(Keycode::Left),
|
|
..
|
|
} => {
|
|
debug!("Move left registered");
|
|
action = Some(Action::MoveLeft);
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(Keycode::Right),
|
|
..
|
|
} => {
|
|
debug!("Move right registered");
|
|
action = Some(Action::MoveRight);
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(Keycode::Down),
|
|
..
|
|
} => {
|
|
debug!("Soft drop registered");
|
|
action = Some(Action::SoftDrop);
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(Keycode::Z),
|
|
..
|
|
} => {
|
|
debug!("Rotate left registered");
|
|
action = Some(Action::RotateLeft);
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(Keycode::X),
|
|
..
|
|
} => {
|
|
debug!("Rotate right registered");
|
|
action = Some(Action::RotateRight);
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(Keycode::Space),
|
|
..
|
|
}
|
|
| Event::KeyDown {
|
|
keycode: Some(Keycode::Up),
|
|
..
|
|
} => {
|
|
debug!("Hard drop registered");
|
|
action = Some(Action::HardDrop);
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(Keycode::LShift),
|
|
..
|
|
} => {
|
|
debug!("Hold registered");
|
|
action = Some(Action::Hold);
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(Keycode::R),
|
|
..
|
|
} => {
|
|
info!("Restarting game");
|
|
game = Game::default();
|
|
}
|
|
Event::KeyDown {
|
|
keycode: Some(e), ..
|
|
} => trace!("Ignoring keycode {}", e),
|
|
_ => (),
|
|
}
|
|
}
|
|
|
|
actor.as_mut().map(|actor| {
|
|
action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
|
|
});
|
|
|
|
action.map(|action| match action {
|
|
Action::Nothing => (),
|
|
Action::MoveLeft => game.move_left(),
|
|
Action::MoveRight => game.move_right(),
|
|
Action::SoftDrop => game.move_down(),
|
|
Action::HardDrop => game.hard_drop(),
|
|
Action::Hold => game.hold(),
|
|
Action::RotateLeft => game.rotate_left(),
|
|
Action::RotateRight => game.rotate_right(),
|
|
});
|
|
|
|
game.tick();
|
|
canvas.set_draw_color(COLOR_BACKGROUND);
|
|
canvas.clear();
|
|
standard_renderer::render(&mut canvas, &game);
|
|
canvas.present();
|
|
interval.tick().await;
|
|
}
|
|
|
|
info!("Final score: {}", game.score());
|
|
actor.map(|a| a.dbg());
|
|
Ok(())
|
|
}
|