qlearning
This commit is contained in:
parent
4444af9d07
commit
b4660d9f45
8 changed files with 312 additions and 56 deletions
77
Cargo.lock
generated
77
Cargo.lock
generated
|
@ -90,6 +90,17 @@ dependencies = [
|
|||
"syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clicolors-control"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "1.9.3"
|
||||
|
@ -100,6 +111,26 @@ dependencies = [
|
|||
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"clicolors-control 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"encode_unicode 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode-width 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "encode_unicode"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.6"
|
||||
|
@ -158,6 +189,17 @@ dependencies = [
|
|||
"autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indicatif"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"console 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iovec"
|
||||
version = "0.1.4"
|
||||
|
@ -293,6 +335,11 @@ dependencies = [
|
|||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "number_prefix"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.1.4"
|
||||
|
@ -385,6 +432,19 @@ name = "redox_syscall"
|
|||
version = "0.1.56"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "sdl2"
|
||||
version = "0.33.0"
|
||||
|
@ -467,11 +527,20 @@ dependencies = [
|
|||
"syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termios"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tetris"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)",
|
||||
"indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"sdl2 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
|
@ -610,7 +679,10 @@ dependencies = [
|
|||
"checksum chrono 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "80094f509cf8b5ae86a4966a39b3ff66cd7e2a3e594accec3743ff3fabeab5b2"
|
||||
"checksum clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)" = "<none>"
|
||||
"checksum clap_derive 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)" = "<none>"
|
||||
"checksum clicolors-control 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90082ee5dcdd64dc4e9e0d37fbf3ee325419e39c0092191e0393df65518f741e"
|
||||
"checksum colored 1.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "f4ffc801dacf156c5854b9df4f425a626539c3a6ef7893cc0c5084a23f0b6c59"
|
||||
"checksum console 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "6728a28023f207181b193262711102bfbaf47cc9d13bc71d0736607ef8efe88c"
|
||||
"checksum encode_unicode 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
|
||||
"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
|
||||
"checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82"
|
||||
"checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
|
||||
|
@ -619,6 +691,7 @@ dependencies = [
|
|||
"checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
|
||||
"checksum hermit-abi 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1010591b26bbfe835e9faeabeb11866061cc7dcebffd56ad7d0942d0e61aefd8"
|
||||
"checksum indexmap 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "076f042c5b7b98f31d205f1249267e12a6518c1481e9dae9764af19b707d2292"
|
||||
"checksum indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "49a68371cf417889c9d7f98235b7102ea7c54fc59bcbd22f3dea785be9d27e40"
|
||||
"checksum iovec 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e"
|
||||
"checksum kernel32-sys 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d"
|
||||
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
@ -634,6 +707,7 @@ dependencies = [
|
|||
"checksum num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "3f6ea62e9d81a77cd3ee9a2a5b9b609447857f3d358704331e4ef39eb247fcba"
|
||||
"checksum num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "c62be47e61d1842b9170f0fdeec8eba98e60e90e5446449a0545e5152acd7096"
|
||||
"checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6"
|
||||
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
|
||||
"checksum pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae"
|
||||
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
|
||||
"checksum proc-macro-error 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7"
|
||||
|
@ -645,6 +719,8 @@ dependencies = [
|
|||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
|
||||
"checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3"
|
||||
"checksum regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)" = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae"
|
||||
"checksum sdl2 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1f74124048ea86b5cd50236b2443f6f57cf4625a8e8818009b4e50dbb8729a43"
|
||||
"checksum sdl2-sys 0.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c2e1deb61ff274d29fb985017d4611d4004b113676eaa9c06754194caf82094e"
|
||||
"checksum signal-hook-registry 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "94f478ede9f64724c5d173d7bb56099ec3e2d9fc2774aac65d34b8b890405f41"
|
||||
|
@ -654,6 +730,7 @@ dependencies = [
|
|||
"checksum strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
|
||||
"checksum syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)" = "123bd9499cfb380418d509322d7a6d52e5315f064fe4b3ad18a53d6b92c07859"
|
||||
"checksum syn-mid 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7be3539f6c128a931cf19dcee741c1af532c7fd387baa739c03dd2e96479338a"
|
||||
"checksum termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "72b620c5ea021d75a735c943269bb07d30c9b77d6ac6b236bc8b5c496ef05625"
|
||||
"checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
|
||||
"checksum time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f"
|
||||
"checksum tokio 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)" = "0fa5e81d6bc4e67fe889d5783bd2a128ab2e0cfa487e0be16b6a8d177b101616"
|
||||
|
|
|
@ -12,4 +12,5 @@ tokio = { version = "0.2", features = ["full"] }
|
|||
log = "0.4"
|
||||
simple_logger = "1.6"
|
||||
sdl2 = { version = "0.33.0", features = ["ttf"] }
|
||||
clap = { git = "https://github.com/clap-rs/clap/" }
|
||||
clap = { git = "https://github.com/clap-rs/clap/", features = ["color"] }
|
||||
indicatif = "0.14"
|
|
@ -5,7 +5,7 @@ use rand::RngCore;
|
|||
|
||||
pub mod qlearning;
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, Clone)]
|
||||
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
|
||||
pub struct State {
|
||||
matrix: Matrix,
|
||||
active_piece: Option<Tetromino>,
|
||||
|
@ -41,4 +41,12 @@ pub trait Actor {
|
|||
state: &State,
|
||||
legal_actions: &[Action],
|
||||
) -> Action;
|
||||
|
||||
fn update(&mut self, state: State, action: Action, next_state: State, reward: f64);
|
||||
|
||||
fn set_learning_rate(&mut self, learning_rate: f64);
|
||||
fn set_exploration_prob(&mut self, exploration_prob: f64);
|
||||
fn set_discount_rate(&mut self, discount_rate: f64);
|
||||
|
||||
fn dbg(&self);
|
||||
}
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
use crate::actors::{Actor, State};
|
||||
use crate::game::Action;
|
||||
use log::debug;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
use rand::RngCore;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub struct QLearningAgent {
|
||||
pub learning_rate: f64,
|
||||
pub exploration_prob: f64,
|
||||
discount_rate: f64,
|
||||
pub discount_rate: f64,
|
||||
q_values: HashMap<State, HashMap<Action, f64>>,
|
||||
}
|
||||
|
||||
|
@ -51,21 +53,6 @@ impl QLearningAgent {
|
|||
})
|
||||
.unwrap_or_else(|| &0.0)
|
||||
}
|
||||
|
||||
pub fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) {
|
||||
let cur_q_val = self.get_q_value(&state, action);
|
||||
let new_q_val = cur_q_val
|
||||
+ self.learning_rate
|
||||
* (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
|
||||
- cur_q_val);
|
||||
if !self.q_values.contains_key(&state) {
|
||||
self.q_values.insert(state.clone(), HashMap::default());
|
||||
}
|
||||
self.q_values
|
||||
.get_mut(&state)
|
||||
.unwrap()
|
||||
.insert(action, new_q_val);
|
||||
}
|
||||
}
|
||||
|
||||
impl Actor for QLearningAgent {
|
||||
|
@ -83,4 +70,33 @@ impl Actor for QLearningAgent {
|
|||
self.get_action_from_q_values(state, legal_actions)
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) {
|
||||
let cur_q_val = self.get_q_value(&state, action);
|
||||
let new_q_val = cur_q_val
|
||||
+ self.learning_rate
|
||||
* (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
|
||||
- cur_q_val);
|
||||
if !self.q_values.contains_key(&state) {
|
||||
self.q_values.insert(state.clone(), HashMap::default());
|
||||
}
|
||||
self.q_values
|
||||
.get_mut(&state)
|
||||
.unwrap()
|
||||
.insert(action, new_q_val);
|
||||
}
|
||||
|
||||
fn dbg(&self) {
|
||||
debug!("Total states: {}", self.q_values.len());
|
||||
}
|
||||
|
||||
fn set_learning_rate(&mut self, learning_rate: f64) {
|
||||
self.learning_rate = learning_rate;
|
||||
}
|
||||
fn set_exploration_prob(&mut self, exploration_prob: f64) {
|
||||
self.exploration_prob = exploration_prob;
|
||||
}
|
||||
fn set_discount_rate(&mut self, discount_rate: f64) {
|
||||
self.discount_rate = discount_rate;
|
||||
}
|
||||
}
|
||||
|
|
92
src/cli.rs
92
src/cli.rs
|
@ -1,4 +1,92 @@
|
|||
use clap::Clap;
|
||||
use crate::actors::*;
|
||||
use clap::{arg_enum, Clap};
|
||||
use log::Level;
|
||||
use simple_logger::init_with_level;
|
||||
|
||||
#[derive(Clap)]
|
||||
pub struct Ops {}
|
||||
pub struct Opts {
|
||||
/// Add more flags to increase verbosity to log at the debug or trace level.
|
||||
#[clap(
|
||||
short = "v",
|
||||
long = "verbose",
|
||||
parse(from_occurrences),
|
||||
conflicts_with("quiet")
|
||||
)]
|
||||
pub verbose: u8,
|
||||
/// Add more flags to decrease verbosity to the warn, error, or silent level.
|
||||
#[clap(short = "q", long = "quiet", parse(from_occurrences))]
|
||||
pub quiet: u8,
|
||||
#[clap(subcommand)]
|
||||
pub subcmd: SubCommand,
|
||||
}
|
||||
|
||||
#[derive(Clap)]
|
||||
pub enum SubCommand {
|
||||
/// Play the game, without training an actor
|
||||
Play(Play),
|
||||
/// Train an actor to play the game.
|
||||
Train(Train),
|
||||
}
|
||||
|
||||
#[derive(Clap)]
|
||||
pub struct Play {
|
||||
/// Disable gravity. Useful for debugging.
|
||||
#[clap(short = "G", long = "no-gravity")]
|
||||
pub no_gravity: bool,
|
||||
}
|
||||
|
||||
#[derive(Clap)]
|
||||
pub struct Train {
|
||||
/// Which agent to use for training.
|
||||
pub agent: Agent,
|
||||
|
||||
/// The rate at which temporal agents learn at.
|
||||
#[clap(short = "a", long = "alpha", default_value = "0.3")]
|
||||
pub learning_rate: f64,
|
||||
|
||||
/// The rate at which agents explore new actions. Range is [0, 1].
|
||||
#[clap(short = "e", long = "epsilon", default_value = "0.1")]
|
||||
pub exploration_prob: f64,
|
||||
|
||||
/// The discount rate for future states.
|
||||
#[clap(short = "g", long = "gamma", default_value = "0.7")]
|
||||
pub discount_rate: f64,
|
||||
|
||||
/// Stop learning during the evaluation of the agent. This sets the learning
|
||||
/// rate to 0 when displaying the results.
|
||||
#[clap(short = "L", long = "no-learn")]
|
||||
pub no_learn_during_evaluation: bool,
|
||||
|
||||
/// Stop exploring during the evaluation of the agent. This sets the
|
||||
/// exploration rate to 0 when displaying the results.
|
||||
#[clap(short = "E", long = "no-explore")]
|
||||
pub no_explore_during_evaluation: bool,
|
||||
|
||||
/// Number of episodes to train the agent
|
||||
#[clap(short = "n", long = "num", default_value = "10")]
|
||||
pub episodes: usize,
|
||||
}
|
||||
|
||||
arg_enum! {
|
||||
#[derive(Debug)]
|
||||
pub enum Agent {
|
||||
QLearning
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match (opts.quiet, opts.verbose) {
|
||||
(0, 0) => init_with_level(Level::Info)?,
|
||||
(0, 1) => init_with_level(Level::Debug)?,
|
||||
(0, _) => init_with_level(Level::Trace)?,
|
||||
(1, 0) => init_with_level(Level::Warn)?,
|
||||
(2, 0) => init_with_level(Level::Error)?,
|
||||
_ => (),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_actor() -> impl Actor {
|
||||
qlearning::QLearningAgent::default()
|
||||
}
|
||||
|
|
|
@ -116,6 +116,7 @@ impl Game {
|
|||
pub fn score(&self) -> u32 {
|
||||
self.score
|
||||
}
|
||||
|
||||
pub fn is_game_over(&self) -> Option<LossReason> {
|
||||
self.is_game_over.or_else(|| {
|
||||
self.playfield
|
||||
|
@ -125,7 +126,6 @@ impl Game {
|
|||
}
|
||||
|
||||
fn update_gravity_tick(&mut self) {
|
||||
// self.next_gravity_tick = (-1 as i64) as u64;
|
||||
self.next_gravity_tick = self.tick + TICKS_PER_SECOND as u64;
|
||||
}
|
||||
|
||||
|
@ -178,6 +178,7 @@ impl Game {
|
|||
if cleared_lines > 0 {
|
||||
trace!("Lines were cleared.");
|
||||
self.line_clears += cleared_lines as u32;
|
||||
self.score += (cleared_lines * self.level as usize) as u32;
|
||||
self.level = (self.line_clears / 10) as u8;
|
||||
self.playfield.active_piece = None;
|
||||
self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
|
||||
|
@ -230,6 +231,7 @@ impl Controllable for Game {
|
|||
self.update_lock_tick();
|
||||
}
|
||||
}
|
||||
|
||||
fn move_right(&mut self) {
|
||||
if self.playfield.active_piece.is_none() {
|
||||
return;
|
||||
|
@ -240,6 +242,7 @@ impl Controllable for Game {
|
|||
self.update_lock_tick();
|
||||
}
|
||||
}
|
||||
|
||||
fn move_down(&mut self) {
|
||||
if self.playfield.active_piece.is_none() {
|
||||
return;
|
||||
|
@ -251,6 +254,7 @@ impl Controllable for Game {
|
|||
self.update_lock_tick();
|
||||
}
|
||||
}
|
||||
|
||||
fn rotate_left(&mut self) {
|
||||
if self.playfield.active_piece.is_none() {
|
||||
return;
|
||||
|
@ -267,6 +271,7 @@ impl Controllable for Game {
|
|||
Err(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn rotate_right(&mut self) {
|
||||
if self.playfield.active_piece.is_none() {
|
||||
return;
|
||||
|
@ -283,6 +288,7 @@ impl Controllable for Game {
|
|||
Err(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn hard_drop(&mut self) {
|
||||
if self.playfield.active_piece.is_none() {
|
||||
return;
|
||||
|
@ -323,6 +329,7 @@ impl Controllable for Game {
|
|||
|
||||
fn get_legal_actions(&self) -> Vec<Action> {
|
||||
let mut legal_actions = vec![
|
||||
Action::Nothing,
|
||||
Action::MoveLeft,
|
||||
Action::MoveRight,
|
||||
Action::SoftDrop,
|
||||
|
|
|
@ -4,6 +4,7 @@ use crate::{
|
|||
playfield::{PlayField, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
|
||||
tetromino::{Position, RotationState, Tetromino, TetrominoType},
|
||||
};
|
||||
use log::error;
|
||||
use sdl2::{
|
||||
pixels::Color,
|
||||
rect::Rect,
|
||||
|
@ -21,7 +22,10 @@ pub fn render(canvas: &mut Canvas<Window>, game: &Game) {
|
|||
game_width as u32,
|
||||
game_height as u32,
|
||||
)));
|
||||
game.render(canvas);
|
||||
match game.render(canvas) {
|
||||
Ok(_) => (),
|
||||
Err(e) => error!("{}", e),
|
||||
};
|
||||
}
|
||||
|
||||
impl Renderable for Game {
|
||||
|
@ -74,7 +78,7 @@ impl Renderable for PlayField {
|
|||
match self.hold_piece() {
|
||||
Some(p) => {
|
||||
canvas.set_draw_color(p);
|
||||
canvas.fill_rect(Rect::new(-32 - UI_PADDING as i32, 0, CELL_SIZE, CELL_SIZE));
|
||||
canvas.fill_rect(Rect::new(-32 - UI_PADDING as i32, 0, CELL_SIZE, CELL_SIZE))?;
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
|
|
121
src/main.rs
121
src/main.rs
|
@ -1,16 +1,19 @@
|
|||
use actors::*;
|
||||
use clap::Clap;
|
||||
use cli::*;
|
||||
use game::{Action, Controllable, Game, Tickable};
|
||||
use graphics::standard_renderer;
|
||||
use graphics::COLOR_BACKGROUND;
|
||||
use indicatif::ProgressIterator;
|
||||
use log::{debug, info, trace};
|
||||
use rand::SeedableRng;
|
||||
use sdl2::event::Event;
|
||||
use sdl2::keyboard::Keycode;
|
||||
use simple_logger;
|
||||
use std::time::Duration;
|
||||
use tokio::time::interval;
|
||||
|
||||
mod actors;
|
||||
mod cli;
|
||||
mod game;
|
||||
mod graphics;
|
||||
mod playfield;
|
||||
|
@ -22,17 +25,50 @@ const TICKS_PER_SECOND: usize = 60;
|
|||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// simple_logger::init()?;
|
||||
simple_logger::init_with_level(log::Level::Info)?;
|
||||
let opts = crate::cli::Opts::parse();
|
||||
|
||||
let mut actor = qlearning::QLearningAgent::default();
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(1337);
|
||||
init_verbosity(&opts)?;
|
||||
let mut actor = None;
|
||||
|
||||
match opts.subcmd {
|
||||
SubCommand::Play(sub_opts) => {}
|
||||
SubCommand::Train(sub_opts) => {
|
||||
let mut to_train = get_actor();
|
||||
to_train.set_learning_rate(sub_opts.learning_rate);
|
||||
to_train.set_discount_rate(sub_opts.discount_rate);
|
||||
to_train.set_exploration_prob(sub_opts.exploration_prob);
|
||||
|
||||
info!(
|
||||
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
|
||||
sub_opts.learning_rate,
|
||||
sub_opts.discount_rate,
|
||||
sub_opts.exploration_prob
|
||||
);
|
||||
let mut trained_actor = train_actor(sub_opts.episodes, to_train);
|
||||
if sub_opts.no_explore_during_evaluation {
|
||||
trained_actor.set_exploration_prob(0.0);
|
||||
}
|
||||
|
||||
if sub_opts.no_learn_during_evaluation {
|
||||
trained_actor.set_learning_rate(0.0);
|
||||
}
|
||||
|
||||
actor = Some(trained_actor);
|
||||
}
|
||||
}
|
||||
|
||||
play_game(actor).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor {
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
let mut avg = 0.0;
|
||||
|
||||
for i in 0..100000 {
|
||||
if i % 100 == 0 {
|
||||
info!("Last 100 scores avg: {}", avg);
|
||||
for i in (0..episodes).progress() {
|
||||
if i != 0 && i % (episodes / 10) == 0 {
|
||||
info!("Last {} scores avg: {}", (episodes / 10), avg);
|
||||
println!();
|
||||
avg = 0.0;
|
||||
}
|
||||
let mut game = Game::default();
|
||||
|
@ -53,18 +89,28 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
}
|
||||
|
||||
let new_state = (&game).into();
|
||||
let reward = game.score() - cur_score;
|
||||
actor.update(cur_state, action, new_state, reward as f64);
|
||||
let mut reward = game.score() as f64 - cur_score as f64;
|
||||
if action != Action::Nothing {
|
||||
reward -= 10.0;
|
||||
}
|
||||
|
||||
if game.is_game_over().is_some() {
|
||||
reward = -100.0;
|
||||
}
|
||||
|
||||
actor.update(cur_state, action, new_state, reward);
|
||||
|
||||
game.tick();
|
||||
}
|
||||
|
||||
avg += game.score() as f64 / 100.0;
|
||||
// info!("Game over with score of {}", game.score());
|
||||
avg += game.score() as f64 / (episodes / 10) as f64;
|
||||
}
|
||||
|
||||
actor.exploration_prob = 0.0;
|
||||
actor
|
||||
}
|
||||
|
||||
async fn play_game(mut actor: Option<impl Actor>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
let sdl_context = sdl2::init()?;
|
||||
let video_subsystem = sdl_context.video()?;
|
||||
let window = video_subsystem
|
||||
|
@ -86,18 +132,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
}
|
||||
|
||||
let cur_state = (&game).into();
|
||||
let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
|
||||
match action {
|
||||
Action::Nothing => (),
|
||||
Action::MoveLeft => game.move_left(),
|
||||
Action::MoveRight => game.move_right(),
|
||||
Action::SoftDrop => game.move_down(),
|
||||
Action::HardDrop => game.hard_drop(),
|
||||
Action::Hold => game.hold(),
|
||||
Action::RotateLeft => game.rotate_left(),
|
||||
Action::RotateRight => game.rotate_right(),
|
||||
}
|
||||
|
||||
// If there's an actor, the player action will get overridden. If not,
|
||||
// then then the player action falls through, if there is one. This is
|
||||
// to allow for restarting and quitting the game from the GUI.
|
||||
let mut action = None;
|
||||
for event in event_pump.poll_iter() {
|
||||
match event {
|
||||
Event::Quit { .. }
|
||||
|
@ -113,35 +152,35 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
..
|
||||
} => {
|
||||
debug!("Move left registered");
|
||||
game.move_left();
|
||||
action = Some(Action::MoveLeft);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Right),
|
||||
..
|
||||
} => {
|
||||
debug!("Move right registered");
|
||||
game.move_right();
|
||||
action = Some(Action::MoveRight);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Down),
|
||||
..
|
||||
} => {
|
||||
debug!("Soft drop registered");
|
||||
game.move_down();
|
||||
action = Some(Action::SoftDrop);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Z),
|
||||
..
|
||||
} => {
|
||||
debug!("Rotate left registered");
|
||||
game.rotate_left();
|
||||
action = Some(Action::RotateLeft);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::X),
|
||||
..
|
||||
} => {
|
||||
debug!("Rotate right registered");
|
||||
game.rotate_right();
|
||||
action = Some(Action::RotateRight);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Space),
|
||||
|
@ -152,14 +191,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
..
|
||||
} => {
|
||||
debug!("Hard drop registered");
|
||||
game.hard_drop();
|
||||
action = Some(Action::HardDrop);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::LShift),
|
||||
..
|
||||
} => {
|
||||
debug!("Hold registered");
|
||||
game.hold();
|
||||
action = Some(Action::Hold);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::R),
|
||||
|
@ -174,6 +213,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
actor.as_mut().map(|actor| {
|
||||
action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
|
||||
});
|
||||
|
||||
action.map(|action| match action {
|
||||
Action::Nothing => (),
|
||||
Action::MoveLeft => game.move_left(),
|
||||
Action::MoveRight => game.move_right(),
|
||||
Action::SoftDrop => game.move_down(),
|
||||
Action::HardDrop => game.hard_drop(),
|
||||
Action::Hold => game.hold(),
|
||||
Action::RotateLeft => game.rotate_left(),
|
||||
Action::RotateRight => game.rotate_right(),
|
||||
});
|
||||
|
||||
game.tick();
|
||||
canvas.set_draw_color(COLOR_BACKGROUND);
|
||||
canvas.clear();
|
||||
|
@ -182,7 +237,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
interval.tick().await;
|
||||
}
|
||||
|
||||
dbg!(game);
|
||||
|
||||
info!("Final score: {}", game.score());
|
||||
actor.map(|a| a.dbg());
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue