diff --git a/Cargo.lock b/Cargo.lock index bb89b2b..90df455 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -90,6 +90,17 @@ dependencies = [ "syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "clicolors-control" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "colored" version = "1.9.3" @@ -100,6 +111,26 @@ dependencies = [ "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "console" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "clicolors-control 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", + "encode_unicode 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", + "regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)", + "termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", + "unicode-width 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "fnv" version = "1.0.6" @@ -158,6 +189,17 @@ dependencies = [ "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "indicatif" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "console 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "iovec" version = "0.1.4" @@ -293,6 +335,11 @@ dependencies = [ "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "number_prefix" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "pin-project-lite" version = "0.1.4" @@ -385,6 +432,19 @@ name = "redox_syscall" version = "0.1.56" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "regex" +version = "1.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "regex-syntax" +version = "0.6.17" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "sdl2" version = "0.33.0" @@ -467,11 +527,20 @@ dependencies = [ "syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "termios" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "tetris" version = "0.1.0" dependencies = [ "clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)", + "indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", "sdl2 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -610,7 +679,10 @@ dependencies = [ "checksum chrono 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "80094f509cf8b5ae86a4966a39b3ff66cd7e2a3e594accec3743ff3fabeab5b2" "checksum clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)" = "" "checksum clap_derive 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)" = "" +"checksum clicolors-control 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90082ee5dcdd64dc4e9e0d37fbf3ee325419e39c0092191e0393df65518f741e" "checksum colored 1.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "f4ffc801dacf156c5854b9df4f425a626539c3a6ef7893cc0c5084a23f0b6c59" +"checksum console 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "6728a28023f207181b193262711102bfbaf47cc9d13bc71d0736607ef8efe88c" +"checksum encode_unicode 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" "checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3" "checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82" "checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" @@ -619,6 +691,7 @@ dependencies = [ "checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205" "checksum hermit-abi 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1010591b26bbfe835e9faeabeb11866061cc7dcebffd56ad7d0942d0e61aefd8" "checksum indexmap 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "076f042c5b7b98f31d205f1249267e12a6518c1481e9dae9764af19b707d2292" +"checksum indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "49a68371cf417889c9d7f98235b7102ea7c54fc59bcbd22f3dea785be9d27e40" "checksum iovec 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e" "checksum kernel32-sys 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d" "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" @@ -634,6 +707,7 @@ dependencies = [ "checksum num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "3f6ea62e9d81a77cd3ee9a2a5b9b609447857f3d358704331e4ef39eb247fcba" "checksum num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "c62be47e61d1842b9170f0fdeec8eba98e60e90e5446449a0545e5152acd7096" "checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6" +"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a" "checksum pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae" "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" "checksum proc-macro-error 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7" @@ -645,6 +719,8 @@ dependencies = [ "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" "checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" +"checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3" +"checksum regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)" = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae" "checksum sdl2 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1f74124048ea86b5cd50236b2443f6f57cf4625a8e8818009b4e50dbb8729a43" "checksum sdl2-sys 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c2e1deb61ff274d29fb985017d4611d4004b113676eaa9c06754194caf82094e" "checksum signal-hook-registry 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "94f478ede9f64724c5d173d7bb56099ec3e2d9fc2774aac65d34b8b890405f41" @@ -654,6 +730,7 @@ dependencies = [ "checksum strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c" "checksum syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)" = "123bd9499cfb380418d509322d7a6d52e5315f064fe4b3ad18a53d6b92c07859" "checksum syn-mid 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7be3539f6c128a931cf19dcee741c1af532c7fd387baa739c03dd2e96479338a" +"checksum termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "72b620c5ea021d75a735c943269bb07d30c9b77d6ac6b236bc8b5c496ef05625" "checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" "checksum time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f" "checksum tokio 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)" = "0fa5e81d6bc4e67fe889d5783bd2a128ab2e0cfa487e0be16b6a8d177b101616" diff --git a/Cargo.toml b/Cargo.toml index e18140c..b91cf54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,4 +12,5 @@ tokio = { version = "0.2", features = ["full"] } log = "0.4" simple_logger = "1.6" sdl2 = { version = "0.33.0", features = ["ttf"] } -clap = { git = "https://github.com/clap-rs/clap/" } \ No newline at end of file +clap = { git = "https://github.com/clap-rs/clap/", features = ["color"] } +indicatif = "0.14" \ No newline at end of file diff --git a/src/actors/mod.rs b/src/actors/mod.rs index 81ff3a2..6c14aaf 100644 --- a/src/actors/mod.rs +++ b/src/actors/mod.rs @@ -5,7 +5,7 @@ use rand::RngCore; pub mod qlearning; -#[derive(Hash, PartialEq, Eq, Clone)] +#[derive(Hash, PartialEq, Eq, Clone, Debug)] pub struct State { matrix: Matrix, active_piece: Option, @@ -41,4 +41,12 @@ pub trait Actor { state: &State, legal_actions: &[Action], ) -> Action; + + fn update(&mut self, state: State, action: Action, next_state: State, reward: f64); + + fn set_learning_rate(&mut self, learning_rate: f64); + fn set_exploration_prob(&mut self, exploration_prob: f64); + fn set_discount_rate(&mut self, discount_rate: f64); + + fn dbg(&self); } diff --git a/src/actors/qlearning.rs b/src/actors/qlearning.rs index bdcfebb..ead215b 100644 --- a/src/actors/qlearning.rs +++ b/src/actors/qlearning.rs @@ -1,13 +1,15 @@ use crate::actors::{Actor, State}; use crate::game::Action; +use log::debug; use rand::seq::SliceRandom; use rand::Rng; use rand::RngCore; use std::collections::HashMap; + pub struct QLearningAgent { pub learning_rate: f64, pub exploration_prob: f64, - discount_rate: f64, + pub discount_rate: f64, q_values: HashMap>, } @@ -51,21 +53,6 @@ impl QLearningAgent { }) .unwrap_or_else(|| &0.0) } - - pub fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) { - let cur_q_val = self.get_q_value(&state, action); - let new_q_val = cur_q_val - + self.learning_rate - * (reward + self.discount_rate * self.get_value_from_q_values(&next_state) - - cur_q_val); - if !self.q_values.contains_key(&state) { - self.q_values.insert(state.clone(), HashMap::default()); - } - self.q_values - .get_mut(&state) - .unwrap() - .insert(action, new_q_val); - } } impl Actor for QLearningAgent { @@ -83,4 +70,33 @@ impl Actor for QLearningAgent { self.get_action_from_q_values(state, legal_actions) } } + + fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) { + let cur_q_val = self.get_q_value(&state, action); + let new_q_val = cur_q_val + + self.learning_rate + * (reward + self.discount_rate * self.get_value_from_q_values(&next_state) + - cur_q_val); + if !self.q_values.contains_key(&state) { + self.q_values.insert(state.clone(), HashMap::default()); + } + self.q_values + .get_mut(&state) + .unwrap() + .insert(action, new_q_val); + } + + fn dbg(&self) { + debug!("Total states: {}", self.q_values.len()); + } + + fn set_learning_rate(&mut self, learning_rate: f64) { + self.learning_rate = learning_rate; + } + fn set_exploration_prob(&mut self, exploration_prob: f64) { + self.exploration_prob = exploration_prob; + } + fn set_discount_rate(&mut self, discount_rate: f64) { + self.discount_rate = discount_rate; + } } diff --git a/src/cli.rs b/src/cli.rs index fff787e..f1bd2e8 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,4 +1,92 @@ -use clap::Clap; +use crate::actors::*; +use clap::{arg_enum, Clap}; +use log::Level; +use simple_logger::init_with_level; #[derive(Clap)] -pub struct Ops {} +pub struct Opts { + /// Add more flags to increase verbosity to log at the debug or trace level. + #[clap( + short = "v", + long = "verbose", + parse(from_occurrences), + conflicts_with("quiet") + )] + pub verbose: u8, + /// Add more flags to decrease verbosity to the warn, error, or silent level. + #[clap(short = "q", long = "quiet", parse(from_occurrences))] + pub quiet: u8, + #[clap(subcommand)] + pub subcmd: SubCommand, +} + +#[derive(Clap)] +pub enum SubCommand { + /// Play the game, without training an actor + Play(Play), + /// Train an actor to play the game. + Train(Train), +} + +#[derive(Clap)] +pub struct Play { + /// Disable gravity. Useful for debugging. + #[clap(short = "G", long = "no-gravity")] + pub no_gravity: bool, +} + +#[derive(Clap)] +pub struct Train { + /// Which agent to use for training. + pub agent: Agent, + + /// The rate at which temporal agents learn at. + #[clap(short = "a", long = "alpha", default_value = "0.3")] + pub learning_rate: f64, + + /// The rate at which agents explore new actions. Range is [0, 1]. + #[clap(short = "e", long = "epsilon", default_value = "0.1")] + pub exploration_prob: f64, + + /// The discount rate for future states. + #[clap(short = "g", long = "gamma", default_value = "0.7")] + pub discount_rate: f64, + + /// Stop learning during the evaluation of the agent. This sets the learning + /// rate to 0 when displaying the results. + #[clap(short = "L", long = "no-learn")] + pub no_learn_during_evaluation: bool, + + /// Stop exploring during the evaluation of the agent. This sets the + /// exploration rate to 0 when displaying the results. + #[clap(short = "E", long = "no-explore")] + pub no_explore_during_evaluation: bool, + + /// Number of episodes to train the agent + #[clap(short = "n", long = "num", default_value = "10")] + pub episodes: usize, +} + +arg_enum! { + #[derive(Debug)] + pub enum Agent { + QLearning + } +} + +pub fn init_verbosity(opts: &Opts) -> Result<(), Box> { + match (opts.quiet, opts.verbose) { + (0, 0) => init_with_level(Level::Info)?, + (0, 1) => init_with_level(Level::Debug)?, + (0, _) => init_with_level(Level::Trace)?, + (1, 0) => init_with_level(Level::Warn)?, + (2, 0) => init_with_level(Level::Error)?, + _ => (), + }; + + Ok(()) +} + +pub fn get_actor() -> impl Actor { + qlearning::QLearningAgent::default() +} diff --git a/src/game.rs b/src/game.rs index f3d1970..e8b4681 100644 --- a/src/game.rs +++ b/src/game.rs @@ -116,6 +116,7 @@ impl Game { pub fn score(&self) -> u32 { self.score } + pub fn is_game_over(&self) -> Option { self.is_game_over.or_else(|| { self.playfield @@ -125,7 +126,6 @@ impl Game { } fn update_gravity_tick(&mut self) { - // self.next_gravity_tick = (-1 as i64) as u64; self.next_gravity_tick = self.tick + TICKS_PER_SECOND as u64; } @@ -178,6 +178,7 @@ impl Game { if cleared_lines > 0 { trace!("Lines were cleared."); self.line_clears += cleared_lines as u32; + self.score += (cleared_lines * self.level as usize) as u32; self.level = (self.line_clears / 10) as u8; self.playfield.active_piece = None; self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY; @@ -230,6 +231,7 @@ impl Controllable for Game { self.update_lock_tick(); } } + fn move_right(&mut self) { if self.playfield.active_piece.is_none() { return; @@ -240,6 +242,7 @@ impl Controllable for Game { self.update_lock_tick(); } } + fn move_down(&mut self) { if self.playfield.active_piece.is_none() { return; @@ -251,6 +254,7 @@ impl Controllable for Game { self.update_lock_tick(); } } + fn rotate_left(&mut self) { if self.playfield.active_piece.is_none() { return; @@ -267,6 +271,7 @@ impl Controllable for Game { Err(_) => (), } } + fn rotate_right(&mut self) { if self.playfield.active_piece.is_none() { return; @@ -283,6 +288,7 @@ impl Controllable for Game { Err(_) => (), } } + fn hard_drop(&mut self) { if self.playfield.active_piece.is_none() { return; @@ -323,6 +329,7 @@ impl Controllable for Game { fn get_legal_actions(&self) -> Vec { let mut legal_actions = vec![ + Action::Nothing, Action::MoveLeft, Action::MoveRight, Action::SoftDrop, diff --git a/src/graphics/standard_renderer.rs b/src/graphics/standard_renderer.rs index fa8122d..fa0ef18 100644 --- a/src/graphics/standard_renderer.rs +++ b/src/graphics/standard_renderer.rs @@ -4,6 +4,7 @@ use crate::{ playfield::{PlayField, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH}, tetromino::{Position, RotationState, Tetromino, TetrominoType}, }; +use log::error; use sdl2::{ pixels::Color, rect::Rect, @@ -21,7 +22,10 @@ pub fn render(canvas: &mut Canvas, game: &Game) { game_width as u32, game_height as u32, ))); - game.render(canvas); + match game.render(canvas) { + Ok(_) => (), + Err(e) => error!("{}", e), + }; } impl Renderable for Game { @@ -74,7 +78,7 @@ impl Renderable for PlayField { match self.hold_piece() { Some(p) => { canvas.set_draw_color(p); - canvas.fill_rect(Rect::new(-32 - UI_PADDING as i32, 0, CELL_SIZE, CELL_SIZE)); + canvas.fill_rect(Rect::new(-32 - UI_PADDING as i32, 0, CELL_SIZE, CELL_SIZE))?; } None => (), } diff --git a/src/main.rs b/src/main.rs index 51cbb1c..4d3fd27 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,19 @@ use actors::*; +use clap::Clap; +use cli::*; use game::{Action, Controllable, Game, Tickable}; use graphics::standard_renderer; use graphics::COLOR_BACKGROUND; +use indicatif::ProgressIterator; use log::{debug, info, trace}; use rand::SeedableRng; use sdl2::event::Event; use sdl2::keyboard::Keycode; -use simple_logger; use std::time::Duration; use tokio::time::interval; mod actors; +mod cli; mod game; mod graphics; mod playfield; @@ -22,17 +25,50 @@ const TICKS_PER_SECOND: usize = 60; #[tokio::main] async fn main() -> Result<(), Box> { - // simple_logger::init()?; - simple_logger::init_with_level(log::Level::Info)?; + let opts = crate::cli::Opts::parse(); - let mut actor = qlearning::QLearningAgent::default(); - let mut rng = rand::rngs::StdRng::seed_from_u64(1337); + init_verbosity(&opts)?; + let mut actor = None; + match opts.subcmd { + SubCommand::Play(sub_opts) => {} + SubCommand::Train(sub_opts) => { + let mut to_train = get_actor(); + to_train.set_learning_rate(sub_opts.learning_rate); + to_train.set_discount_rate(sub_opts.discount_rate); + to_train.set_exploration_prob(sub_opts.exploration_prob); + + info!( + "Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}", + sub_opts.learning_rate, + sub_opts.discount_rate, + sub_opts.exploration_prob + ); + let mut trained_actor = train_actor(sub_opts.episodes, to_train); + if sub_opts.no_explore_during_evaluation { + trained_actor.set_exploration_prob(0.0); + } + + if sub_opts.no_learn_during_evaluation { + trained_actor.set_learning_rate(0.0); + } + + actor = Some(trained_actor); + } + } + + play_game(actor).await?; + Ok(()) +} + +fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor { + let mut rng = rand::rngs::StdRng::from_entropy(); let mut avg = 0.0; - for i in 0..100000 { - if i % 100 == 0 { - info!("Last 100 scores avg: {}", avg); + for i in (0..episodes).progress() { + if i != 0 && i % (episodes / 10) == 0 { + info!("Last {} scores avg: {}", (episodes / 10), avg); + println!(); avg = 0.0; } let mut game = Game::default(); @@ -53,18 +89,28 @@ async fn main() -> Result<(), Box> { } let new_state = (&game).into(); - let reward = game.score() - cur_score; - actor.update(cur_state, action, new_state, reward as f64); + let mut reward = game.score() as f64 - cur_score as f64; + if action != Action::Nothing { + reward -= 10.0; + } + + if game.is_game_over().is_some() { + reward = -100.0; + } + + actor.update(cur_state, action, new_state, reward); game.tick(); } - avg += game.score() as f64 / 100.0; - // info!("Game over with score of {}", game.score()); + avg += game.score() as f64 / (episodes / 10) as f64; } - actor.exploration_prob = 0.0; + actor +} +async fn play_game(mut actor: Option) -> Result<(), Box> { + let mut rng = rand::rngs::StdRng::from_entropy(); let sdl_context = sdl2::init()?; let video_subsystem = sdl_context.video()?; let window = video_subsystem @@ -86,18 +132,11 @@ async fn main() -> Result<(), Box> { } let cur_state = (&game).into(); - let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())); - match action { - Action::Nothing => (), - Action::MoveLeft => game.move_left(), - Action::MoveRight => game.move_right(), - Action::SoftDrop => game.move_down(), - Action::HardDrop => game.hard_drop(), - Action::Hold => game.hold(), - Action::RotateLeft => game.rotate_left(), - Action::RotateRight => game.rotate_right(), - } + // If there's an actor, the player action will get overridden. If not, + // then then the player action falls through, if there is one. This is + // to allow for restarting and quitting the game from the GUI. + let mut action = None; for event in event_pump.poll_iter() { match event { Event::Quit { .. } @@ -113,35 +152,35 @@ async fn main() -> Result<(), Box> { .. } => { debug!("Move left registered"); - game.move_left(); + action = Some(Action::MoveLeft); } Event::KeyDown { keycode: Some(Keycode::Right), .. } => { debug!("Move right registered"); - game.move_right(); + action = Some(Action::MoveRight); } Event::KeyDown { keycode: Some(Keycode::Down), .. } => { debug!("Soft drop registered"); - game.move_down(); + action = Some(Action::SoftDrop); } Event::KeyDown { keycode: Some(Keycode::Z), .. } => { debug!("Rotate left registered"); - game.rotate_left(); + action = Some(Action::RotateLeft); } Event::KeyDown { keycode: Some(Keycode::X), .. } => { debug!("Rotate right registered"); - game.rotate_right(); + action = Some(Action::RotateRight); } Event::KeyDown { keycode: Some(Keycode::Space), @@ -152,14 +191,14 @@ async fn main() -> Result<(), Box> { .. } => { debug!("Hard drop registered"); - game.hard_drop(); + action = Some(Action::HardDrop); } Event::KeyDown { keycode: Some(Keycode::LShift), .. } => { debug!("Hold registered"); - game.hold(); + action = Some(Action::Hold); } Event::KeyDown { keycode: Some(Keycode::R), @@ -174,6 +213,22 @@ async fn main() -> Result<(), Box> { _ => (), } } + + actor.as_mut().map(|actor| { + action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()))); + }); + + action.map(|action| match action { + Action::Nothing => (), + Action::MoveLeft => game.move_left(), + Action::MoveRight => game.move_right(), + Action::SoftDrop => game.move_down(), + Action::HardDrop => game.hard_drop(), + Action::Hold => game.hold(), + Action::RotateLeft => game.rotate_left(), + Action::RotateRight => game.rotate_right(), + }); + game.tick(); canvas.set_draw_color(COLOR_BACKGROUND); canvas.clear(); @@ -182,7 +237,7 @@ async fn main() -> Result<(), Box> { interval.tick().await; } - dbg!(game); - + info!("Final score: {}", game.score()); + actor.map(|a| a.dbg()); Ok(()) }