qlearning

This commit is contained in:
Edward Shen 2020-03-30 17:23:51 -04:00
parent 4444af9d07
commit b4660d9f45
Signed by: edward
GPG key ID: 19182661E818369F
8 changed files with 312 additions and 56 deletions

77
Cargo.lock generated
View file

@ -90,6 +90,17 @@ dependencies = [
"syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "clicolors-control"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "colored"
version = "1.9.3"
@ -100,6 +111,26 @@ dependencies = [
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "console"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"clicolors-control 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
"encode_unicode 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
"termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-width 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "encode_unicode"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "fnv"
version = "1.0.6"
@ -158,6 +189,17 @@ dependencies = [
"autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "indicatif"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"console 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "iovec"
version = "0.1.4"
@ -293,6 +335,11 @@ dependencies = [
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "number_prefix"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "pin-project-lite"
version = "0.1.4"
@ -385,6 +432,19 @@ name = "redox_syscall"
version = "0.1.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "regex"
version = "1.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "regex-syntax"
version = "0.6.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "sdl2"
version = "0.33.0"
@ -467,11 +527,20 @@ dependencies = [
"syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "termios"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "tetris"
version = "0.1.0"
dependencies = [
"clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)",
"indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"sdl2 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)",
@ -610,7 +679,10 @@ dependencies = [
"checksum chrono 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "80094f509cf8b5ae86a4966a39b3ff66cd7e2a3e594accec3743ff3fabeab5b2"
"checksum clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)" = "<none>"
"checksum clap_derive 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)" = "<none>"
"checksum clicolors-control 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90082ee5dcdd64dc4e9e0d37fbf3ee325419e39c0092191e0393df65518f741e"
"checksum colored 1.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "f4ffc801dacf156c5854b9df4f425a626539c3a6ef7893cc0c5084a23f0b6c59"
"checksum console 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "6728a28023f207181b193262711102bfbaf47cc9d13bc71d0736607ef8efe88c"
"checksum encode_unicode 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
"checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82"
"checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
@ -619,6 +691,7 @@ dependencies = [
"checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
"checksum hermit-abi 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1010591b26bbfe835e9faeabeb11866061cc7dcebffd56ad7d0942d0e61aefd8"
"checksum indexmap 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "076f042c5b7b98f31d205f1249267e12a6518c1481e9dae9764af19b707d2292"
"checksum indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "49a68371cf417889c9d7f98235b7102ea7c54fc59bcbd22f3dea785be9d27e40"
"checksum iovec 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e"
"checksum kernel32-sys 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
@ -634,6 +707,7 @@ dependencies = [
"checksum num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "3f6ea62e9d81a77cd3ee9a2a5b9b609447857f3d358704331e4ef39eb247fcba"
"checksum num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "c62be47e61d1842b9170f0fdeec8eba98e60e90e5446449a0545e5152acd7096"
"checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6"
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
"checksum pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae"
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
"checksum proc-macro-error 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7"
@ -645,6 +719,8 @@ dependencies = [
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
"checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"
"checksum regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)" = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae"
"checksum sdl2 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1f74124048ea86b5cd50236b2443f6f57cf4625a8e8818009b4e50dbb8729a43"
"checksum sdl2-sys 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c2e1deb61ff274d29fb985017d4611d4004b113676eaa9c06754194caf82094e"
"checksum signal-hook-registry 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "94f478ede9f64724c5d173d7bb56099ec3e2d9fc2774aac65d34b8b890405f41"
@ -654,6 +730,7 @@ dependencies = [
"checksum strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
"checksum syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)" = "123bd9499cfb380418d509322d7a6d52e5315f064fe4b3ad18a53d6b92c07859"
"checksum syn-mid 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7be3539f6c128a931cf19dcee741c1af532c7fd387baa739c03dd2e96479338a"
"checksum termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "72b620c5ea021d75a735c943269bb07d30c9b77d6ac6b236bc8b5c496ef05625"
"checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
"checksum time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f"
"checksum tokio 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)" = "0fa5e81d6bc4e67fe889d5783bd2a128ab2e0cfa487e0be16b6a8d177b101616"

View file

@ -12,4 +12,5 @@ tokio = { version = "0.2", features = ["full"] }
log = "0.4"
simple_logger = "1.6"
sdl2 = { version = "0.33.0", features = ["ttf"] }
clap = { git = "https://github.com/clap-rs/clap/" }
clap = { git = "https://github.com/clap-rs/clap/", features = ["color"] }
indicatif = "0.14"

View file

@ -5,7 +5,7 @@ use rand::RngCore;
pub mod qlearning;
#[derive(Hash, PartialEq, Eq, Clone)]
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
pub struct State {
matrix: Matrix,
active_piece: Option<Tetromino>,
@ -41,4 +41,12 @@ pub trait Actor {
state: &State,
legal_actions: &[Action],
) -> Action;
/// Reports one observed transition to the actor: `action` was taken in
/// `state`, the game moved to `next_state`, and `reward` was received.
/// What the actor does with it is implementation-defined.
fn update(&mut self, state: State, action: Action, next_state: State, reward: f64);
/// Sets the actor's learning rate.
fn set_learning_rate(&mut self, learning_rate: f64);
/// Sets the actor's exploration probability.
// NOTE(review): presumably expected to lie in [0, 1]; nothing here enforces it.
fn set_exploration_prob(&mut self, exploration_prob: f64);
/// Sets the actor's discount rate.
fn set_discount_rate(&mut self, discount_rate: f64);
/// Emits implementation-specific diagnostic information about the actor.
fn dbg(&self);
}

View file

@ -1,13 +1,15 @@
use crate::actors::{Actor, State};
use crate::game::Action;
use log::debug;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::RngCore;
use std::collections::HashMap;
pub struct QLearningAgent {
pub learning_rate: f64,
pub exploration_prob: f64,
discount_rate: f64,
pub discount_rate: f64,
q_values: HashMap<State, HashMap<Action, f64>>,
}
@ -51,21 +53,6 @@ impl QLearningAgent {
})
.unwrap_or_else(|| &0.0)
}
pub fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) {
let cur_q_val = self.get_q_value(&state, action);
let new_q_val = cur_q_val
+ self.learning_rate
* (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
- cur_q_val);
if !self.q_values.contains_key(&state) {
self.q_values.insert(state.clone(), HashMap::default());
}
self.q_values
.get_mut(&state)
.unwrap()
.insert(action, new_q_val);
}
}
impl Actor for QLearningAgent {
@ -83,4 +70,33 @@ impl Actor for QLearningAgent {
self.get_action_from_q_values(state, legal_actions)
}
}
fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) {
let cur_q_val = self.get_q_value(&state, action);
let new_q_val = cur_q_val
+ self.learning_rate
* (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
- cur_q_val);
if !self.q_values.contains_key(&state) {
self.q_values.insert(state.clone(), HashMap::default());
}
self.q_values
.get_mut(&state)
.unwrap()
.insert(action, new_q_val);
}
/// Logs, at debug level, how many states currently have an entry in the
/// Q-value table.
fn dbg(&self) {
    let total_states = self.q_values.len();
    debug!("Total states: {}", total_states);
}
/// Overwrites the agent's learning rate with `learning_rate`.
/// The value is stored as given; no range check is performed here.
fn set_learning_rate(&mut self, learning_rate: f64) {
self.learning_rate = learning_rate;
}
/// Overwrites the agent's exploration probability with `exploration_prob`.
/// The value is stored as given; no range check is performed here.
fn set_exploration_prob(&mut self, exploration_prob: f64) {
self.exploration_prob = exploration_prob;
}
/// Overwrites the agent's discount rate with `discount_rate`.
/// The value is stored as given; no range check is performed here.
fn set_discount_rate(&mut self, discount_rate: f64) {
self.discount_rate = discount_rate;
}
}

View file

@ -1,4 +1,92 @@
use clap::Clap;
use crate::actors::*;
use clap::{arg_enum, Clap};
use log::Level;
use simple_logger::init_with_level;
#[derive(Clap)]
pub struct Ops {}
// Top-level command-line options shared by every subcommand.
// NOTE(review): with clap's derive macros the `///` comments on fields are
// surfaced as help text, so they are user-facing; maintainer notes in this
// struct are deliberately written as plain `//` comments.
pub struct Opts {
// Occurrence count of -v/--verbose; `init_verbosity` maps 1 to Debug and
// anything higher to Trace. Declared as conflicting with `quiet`.
/// Add more flags to increase verbosity to log at the debug or trace level.
#[clap(
short = "v",
long = "verbose",
parse(from_occurrences),
conflicts_with("quiet")
)]
pub verbose: u8,
// Occurrence count of -q/--quiet; `init_verbosity` maps 1 to Warn, 2 to
// Error, and installs no logger at all for higher counts.
/// Add more flags to decrease verbosity to the warn, error, or silent level.
#[clap(short = "q", long = "quiet", parse(from_occurrences))]
pub quiet: u8,
// The selected mode of operation (play or train).
#[clap(subcommand)]
pub subcmd: SubCommand,
}
// Modes of operation selectable on the command line. Each variant wraps the
// option struct holding that mode's own flags.
// NOTE(review): the `///` comments on variants are presumably shown by clap as
// the subcommand descriptions, so treat them as user-facing text.
#[derive(Clap)]
pub enum SubCommand {
/// Play the game, without training an actor
Play(Play),
/// Train an actor to play the game.
Train(Train),
}
// Options for the `Play` subcommand.
// NOTE(review): the `SubCommand::Play` arm visible in `main` has an empty body
// and does not read `no_gravity` — confirm the flag is wired up somewhere.
#[derive(Clap)]
pub struct Play {
/// Disable gravity. Useful for debugging.
#[clap(short = "G", long = "no-gravity")]
pub no_gravity: bool,
}
// Options for the `Train` subcommand: which agent to train, its
// hyper-parameters, and how it behaves once training is finished.
// The `///` comments below are presumably rendered by clap as help text, so
// maintainer notes are written as plain `//` comments.
#[derive(Clap)]
pub struct Train {
// No `#[clap(...)]` attribute here, so this is presumably a required
// positional argument parsed through the `arg_enum!`-generated impls.
// NOTE(review): `get_actor()` takes no parameters, so this choice does not
// currently influence which actor is built.
/// Which agent to use for training.
pub agent: Agent,
// Defaults to 0.3; passed to `Actor::set_learning_rate` before training.
/// The rate at which temporal agents learn at.
#[clap(short = "a", long = "alpha", default_value = "0.3")]
pub learning_rate: f64,
// Defaults to 0.1. The documented [0, 1] range is not validated here.
/// The rate at which agents explore new actions. Range is [0, 1].
#[clap(short = "e", long = "epsilon", default_value = "0.1")]
pub exploration_prob: f64,
// Defaults to 0.7; passed to `Actor::set_discount_rate` before training.
/// The discount rate for future states.
#[clap(short = "g", long = "gamma", default_value = "0.7")]
pub discount_rate: f64,
/// Stop learning during the evaluation of the agent. This sets the learning
/// rate to 0 when displaying the results.
#[clap(short = "L", long = "no-learn")]
pub no_learn_during_evaluation: bool,
/// Stop exploring during the evaluation of the agent. This sets the
/// exploration rate to 0 when displaying the results.
#[clap(short = "E", long = "no-explore")]
pub no_explore_during_evaluation: bool,
// Defaults to 10.
// NOTE(review): `train_actor` evaluates `i % (episodes / 10)`; values below
// 10 make that divisor 0 — confirm small values are rejected or handled.
/// Number of episodes to train the agent
#[clap(short = "n", long = "num", default_value = "10")]
pub episodes: usize,
}
// Agent implementations selectable from the command line. Declared through
// clap's `arg_enum!` macro, which presumably generates the string conversions
// needed to use the enum as an argument value.
// Only one agent kind exists at the moment.
arg_enum! {
#[derive(Debug)]
pub enum Agent {
QLearning
}
}
pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
match (opts.quiet, opts.verbose) {
(0, 0) => init_with_level(Level::Info)?,
(0, 1) => init_with_level(Level::Debug)?,
(0, _) => init_with_level(Level::Trace)?,
(1, 0) => init_with_level(Level::Warn)?,
(2, 0) => init_with_level(Level::Error)?,
_ => (),
};
Ok(())
}
/// Builds the actor to be trained: a `QLearningAgent` in its default
/// configuration, returned behind `impl Actor`.
// NOTE(review): the `Agent` value parsed from the command line is not passed
// in, so this always yields the Q-learning agent — fine while that is the only
// variant, but it will need a parameter once more agents exist.
pub fn get_actor() -> impl Actor {
qlearning::QLearningAgent::default()
}

View file

@ -116,6 +116,7 @@ impl Game {
pub fn score(&self) -> u32 {
self.score
}
pub fn is_game_over(&self) -> Option<LossReason> {
self.is_game_over.or_else(|| {
self.playfield
@ -125,7 +126,6 @@ impl Game {
}
fn update_gravity_tick(&mut self) {
// self.next_gravity_tick = (-1 as i64) as u64;
self.next_gravity_tick = self.tick + TICKS_PER_SECOND as u64;
}
@ -178,6 +178,7 @@ impl Game {
if cleared_lines > 0 {
trace!("Lines were cleared.");
self.line_clears += cleared_lines as u32;
self.score += (cleared_lines * self.level as usize) as u32;
self.level = (self.line_clears / 10) as u8;
self.playfield.active_piece = None;
self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
@ -230,6 +231,7 @@ impl Controllable for Game {
self.update_lock_tick();
}
}
fn move_right(&mut self) {
if self.playfield.active_piece.is_none() {
return;
@ -240,6 +242,7 @@ impl Controllable for Game {
self.update_lock_tick();
}
}
fn move_down(&mut self) {
if self.playfield.active_piece.is_none() {
return;
@ -251,6 +254,7 @@ impl Controllable for Game {
self.update_lock_tick();
}
}
fn rotate_left(&mut self) {
if self.playfield.active_piece.is_none() {
return;
@ -267,6 +271,7 @@ impl Controllable for Game {
Err(_) => (),
}
}
fn rotate_right(&mut self) {
if self.playfield.active_piece.is_none() {
return;
@ -283,6 +288,7 @@ impl Controllable for Game {
Err(_) => (),
}
}
fn hard_drop(&mut self) {
if self.playfield.active_piece.is_none() {
return;
@ -323,6 +329,7 @@ impl Controllable for Game {
fn get_legal_actions(&self) -> Vec<Action> {
let mut legal_actions = vec![
Action::Nothing,
Action::MoveLeft,
Action::MoveRight,
Action::SoftDrop,

View file

@ -4,6 +4,7 @@ use crate::{
playfield::{PlayField, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
tetromino::{Position, RotationState, Tetromino, TetrominoType},
};
use log::error;
use sdl2::{
pixels::Color,
rect::Rect,
@ -21,7 +22,10 @@ pub fn render(canvas: &mut Canvas<Window>, game: &Game) {
game_width as u32,
game_height as u32,
)));
game.render(canvas);
match game.render(canvas) {
Ok(_) => (),
Err(e) => error!("{}", e),
};
}
impl Renderable for Game {
@ -74,7 +78,7 @@ impl Renderable for PlayField {
match self.hold_piece() {
Some(p) => {
canvas.set_draw_color(p);
canvas.fill_rect(Rect::new(-32 - UI_PADDING as i32, 0, CELL_SIZE, CELL_SIZE));
canvas.fill_rect(Rect::new(-32 - UI_PADDING as i32, 0, CELL_SIZE, CELL_SIZE))?;
}
None => (),
}

View file

@ -1,16 +1,19 @@
use actors::*;
use clap::Clap;
use cli::*;
use game::{Action, Controllable, Game, Tickable};
use graphics::standard_renderer;
use graphics::COLOR_BACKGROUND;
use indicatif::ProgressIterator;
use log::{debug, info, trace};
use rand::SeedableRng;
use sdl2::event::Event;
use sdl2::keyboard::Keycode;
use simple_logger;
use std::time::Duration;
use tokio::time::interval;
mod actors;
mod cli;
mod game;
mod graphics;
mod playfield;
@ -22,17 +25,50 @@ const TICKS_PER_SECOND: usize = 60;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// simple_logger::init()?;
simple_logger::init_with_level(log::Level::Info)?;
let opts = crate::cli::Opts::parse();
let mut actor = qlearning::QLearningAgent::default();
let mut rng = rand::rngs::StdRng::seed_from_u64(1337);
init_verbosity(&opts)?;
let mut actor = None;
match opts.subcmd {
SubCommand::Play(sub_opts) => {}
SubCommand::Train(sub_opts) => {
let mut to_train = get_actor();
to_train.set_learning_rate(sub_opts.learning_rate);
to_train.set_discount_rate(sub_opts.discount_rate);
to_train.set_exploration_prob(sub_opts.exploration_prob);
info!(
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
sub_opts.learning_rate,
sub_opts.discount_rate,
sub_opts.exploration_prob
);
let mut trained_actor = train_actor(sub_opts.episodes, to_train);
if sub_opts.no_explore_during_evaluation {
trained_actor.set_exploration_prob(0.0);
}
if sub_opts.no_learn_during_evaluation {
trained_actor.set_learning_rate(0.0);
}
actor = Some(trained_actor);
}
}
play_game(actor).await?;
Ok(())
}
fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor {
let mut rng = rand::rngs::StdRng::from_entropy();
let mut avg = 0.0;
for i in 0..100000 {
if i % 100 == 0 {
info!("Last 100 scores avg: {}", avg);
for i in (0..episodes).progress() {
if i != 0 && i % (episodes / 10) == 0 {
info!("Last {} scores avg: {}", (episodes / 10), avg);
println!();
avg = 0.0;
}
let mut game = Game::default();
@ -53,18 +89,28 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
let new_state = (&game).into();
let reward = game.score() - cur_score;
actor.update(cur_state, action, new_state, reward as f64);
let mut reward = game.score() as f64 - cur_score as f64;
if action != Action::Nothing {
reward -= 10.0;
}
if game.is_game_over().is_some() {
reward = -100.0;
}
actor.update(cur_state, action, new_state, reward);
game.tick();
}
avg += game.score() as f64 / 100.0;
// info!("Game over with score of {}", game.score());
avg += game.score() as f64 / (episodes / 10) as f64;
}
actor.exploration_prob = 0.0;
actor
}
async fn play_game(mut actor: Option<impl Actor>) -> Result<(), Box<dyn std::error::Error>> {
let mut rng = rand::rngs::StdRng::from_entropy();
let sdl_context = sdl2::init()?;
let video_subsystem = sdl_context.video()?;
let window = video_subsystem
@ -86,18 +132,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
let cur_state = (&game).into();
let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
}
// If there's an actor, the player action will get overridden. If not,
// then the player action falls through, if there is one. This is
// to allow for restarting and quitting the game from the GUI.
let mut action = None;
for event in event_pump.poll_iter() {
match event {
Event::Quit { .. }
@ -113,35 +152,35 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
..
} => {
debug!("Move left registered");
game.move_left();
action = Some(Action::MoveLeft);
}
Event::KeyDown {
keycode: Some(Keycode::Right),
..
} => {
debug!("Move right registered");
game.move_right();
action = Some(Action::MoveRight);
}
Event::KeyDown {
keycode: Some(Keycode::Down),
..
} => {
debug!("Soft drop registered");
game.move_down();
action = Some(Action::SoftDrop);
}
Event::KeyDown {
keycode: Some(Keycode::Z),
..
} => {
debug!("Rotate left registered");
game.rotate_left();
action = Some(Action::RotateLeft);
}
Event::KeyDown {
keycode: Some(Keycode::X),
..
} => {
debug!("Rotate right registered");
game.rotate_right();
action = Some(Action::RotateRight);
}
Event::KeyDown {
keycode: Some(Keycode::Space),
@ -152,14 +191,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
..
} => {
debug!("Hard drop registered");
game.hard_drop();
action = Some(Action::HardDrop);
}
Event::KeyDown {
keycode: Some(Keycode::LShift),
..
} => {
debug!("Hold registered");
game.hold();
action = Some(Action::Hold);
}
Event::KeyDown {
keycode: Some(Keycode::R),
@ -174,6 +213,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
_ => (),
}
}
actor.as_mut().map(|actor| {
action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
});
action.map(|action| match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
});
game.tick();
canvas.set_draw_color(COLOR_BACKGROUND);
canvas.clear();
@ -182,7 +237,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
interval.tick().await;
}
dbg!(game);
info!("Final score: {}", game.score());
actor.map(|a| a.dbg());
Ok(())
}