Compare commits
No commits in common. "5a9e2538aa08e796614cfb8193d18a16a29b7393" and "a65f48f5856d58803cee67369c0e63cf2b7623f9" have entirely different histories.
5a9e2538aa
...
a65f48f585
9 changed files with 296 additions and 734 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,4 +1,2 @@
|
|||
/target
|
||||
**/*.rs.bk
|
||||
flamegraph.svg
|
||||
perf.*
|
||||
|
|
104
Cargo.lock
generated
104
Cargo.lock
generated
|
@ -150,88 +150,11 @@ name = "fuchsia-zircon-sys"
|
|||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.13"
|
||||
|
@ -422,11 +345,6 @@ name = "pin-project-lite"
|
|||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
version = "0.1.0-alpha.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.6"
|
||||
|
@ -456,16 +374,6 @@ dependencies = [
|
|||
"version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-hack"
|
||||
version = "0.5.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-nested"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.9"
|
||||
|
@ -641,7 +549,6 @@ name = "tetris"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)",
|
||||
"futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
|
@ -788,15 +695,7 @@ dependencies = [
|
|||
"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
|
||||
"checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82"
|
||||
"checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
|
||||
"checksum futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5c329ae8753502fb44ae4fc2b622fa2a94652c41e795143765ba0927f92ab780"
|
||||
"checksum futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f0c77d04ce8edd9cb903932b608268b3fffec4163dc053b3b402bf47eac1f1a8"
|
||||
"checksum futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f25592f769825e89b92358db00d26f965761e094951ac44d3663ef25b7ac464a"
|
||||
"checksum futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f674f3e1bcb15b37284a90cedf55afdba482ab061c407a9c0ebbd0f3109741ba"
|
||||
"checksum futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "a638959aa96152c7a4cddf50fcb1e3fede0583b27157c26e67d6f99904090dc6"
|
||||
"checksum futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "9a5081aa3de1f7542a794a397cde100ed903b0630152d0973479018fd85423a7"
|
||||
"checksum futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "3466821b4bc114d95b087b850a724c6f83115e929bc88f1fa98a3304a944c8a6"
|
||||
"checksum futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "7b0a34e53cf6cdcd0178aa573aed466b646eb3db769570841fda0c7ede375a27"
|
||||
"checksum futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "22766cf25d64306bedf0384da004d05c9974ab104fcc4528f1236181c18004c5"
|
||||
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
|
||||
"checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
|
||||
"checksum hermit-abi 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1010591b26bbfe835e9faeabeb11866061cc7dcebffd56ad7d0942d0e61aefd8"
|
||||
|
@ -819,12 +718,9 @@ dependencies = [
|
|||
"checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6"
|
||||
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
|
||||
"checksum pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae"
|
||||
"checksum pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5894c618ce612a3fa23881b152b608bafb8c56cfc22f434a3ba3120b40f7b587"
|
||||
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
|
||||
"checksum proc-macro-error 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7"
|
||||
"checksum proc-macro-error-attr 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de"
|
||||
"checksum proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)" = "0d659fe7c6d27f25e9d80a1a094c223f5246f6a6596453e09d7229bf42750b63"
|
||||
"checksum proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "8e946095f9d3ed29ec38de908c22f95d9ac008e424c7bcae54c75a79c527c694"
|
||||
"checksum proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)" = "6c09721c6781493a2a492a96b5a5bf19b65917fe6728884e7c44dd0c60ca3435"
|
||||
"checksum quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2bdc6c187c65bca4260c9011c9e3132efe4909da44726bad24cf7572ae338d7f"
|
||||
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
|
||||
|
|
|
@ -13,5 +13,4 @@ log = "0.4"
|
|||
simple_logger = "1.6"
|
||||
sdl2 = { version = "0.33.0", features = ["ttf"] }
|
||||
clap = { git = "https://github.com/clap-rs/clap/", features = ["color"] }
|
||||
indicatif = "0.14"
|
||||
futures = "0.3"
|
||||
indicatif = "0.14"
|
|
@ -1,17 +1,9 @@
|
|||
// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
|
||||
|
||||
use super::{Actor, Predictable, State};
|
||||
use crate::{
|
||||
cli::Train,
|
||||
game::{Action, Controllable, Game, Tickable},
|
||||
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
|
||||
};
|
||||
use indicatif::ProgressBar;
|
||||
use log::debug;
|
||||
use super::Actor;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{seq::SliceRandom, Rng, SeedableRng};
|
||||
use rand::Rng;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Parameters {
|
||||
total_height: f64,
|
||||
bumpiness: f64,
|
||||
|
@ -19,17 +11,6 @@ pub struct Parameters {
|
|||
complete_lines: f64,
|
||||
}
|
||||
|
||||
impl Default for Parameters {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
total_height: 1.0,
|
||||
bumpiness: 1.0,
|
||||
holes: 1.0,
|
||||
complete_lines: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
fn mutate(mut self, rng: &mut SmallRng) {
|
||||
let mutation_amt = rng.gen_range(-0.2, 0.2);
|
||||
|
@ -52,220 +33,39 @@ impl Parameters {
|
|||
self.holes /= normalization_factor;
|
||||
self.complete_lines /= normalization_factor;
|
||||
}
|
||||
|
||||
fn dot_multiply(&self, other: &Self) -> f64 {
|
||||
self.total_height * other.total_height
|
||||
+ self.bumpiness * other.bumpiness
|
||||
+ self.holes * other.holes
|
||||
+ self.complete_lines * other.complete_lines
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct GeneticHeuristicAgent {
|
||||
params: Parameters,
|
||||
}
|
||||
|
||||
impl Default for GeneticHeuristicAgent {
|
||||
fn default() -> Self {
|
||||
let mut rng = SmallRng::from_entropy();
|
||||
Self {
|
||||
params: Parameters {
|
||||
total_height: rng.gen::<f64>(),
|
||||
bumpiness: rng.gen::<f64>(),
|
||||
holes: rng.gen::<f64>(),
|
||||
complete_lines: rng.gen::<f64>(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GeneticHeuristicAgent {
|
||||
fn extract_features_from_state(state: &State) -> Parameters {
|
||||
let mut heights = [None; PLAYFIELD_WIDTH];
|
||||
for r in 0..PLAYFIELD_HEIGHT {
|
||||
for c in 0..PLAYFIELD_WIDTH {
|
||||
if heights[c].is_none() && state.matrix[r][c].is_some() {
|
||||
heights[c] = Some(PLAYFIELD_HEIGHT - r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let total_height = heights
|
||||
.iter()
|
||||
.map(|o| o.unwrap_or_else(|| 0))
|
||||
.sum::<usize>() as f64;
|
||||
|
||||
let bumpiness = heights
|
||||
.iter()
|
||||
.map(|o| o.unwrap_or_else(|| 0) as isize)
|
||||
.fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
|
||||
.0 as f64;
|
||||
|
||||
let complete_lines = state
|
||||
.matrix
|
||||
.iter()
|
||||
.map(|row| row.iter().all(Option::is_some))
|
||||
.map(|c| if c { 1.0 } else { 0.0 })
|
||||
.sum::<f64>();
|
||||
|
||||
let mut holes = 0;
|
||||
for r in 1..PLAYFIELD_HEIGHT {
|
||||
for c in 0..PLAYFIELD_WIDTH {
|
||||
if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
|
||||
holes += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Parameters {
|
||||
total_height,
|
||||
bumpiness,
|
||||
complete_lines,
|
||||
holes: holes as f64,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_heuristic(&self, game: &Game, action: &Action) -> f64 {
|
||||
self.params.dot_multiply(&Self::extract_features_from_state(
|
||||
&game.get_next_state(*action).into(),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self {
|
||||
let weight = self_fitness + other_fitness;
|
||||
|
||||
if weight != 0 {
|
||||
let self_weight = self_fitness as f64 / weight as f64;
|
||||
let other_weight = other_fitness as f64 / weight as f64;
|
||||
Self {
|
||||
params: Parameters {
|
||||
total_height: self.params.total_height * self_weight
|
||||
+ other.params.total_height * other_weight,
|
||||
bumpiness: self.params.total_height * self_weight
|
||||
+ other.params.total_height * other_weight,
|
||||
holes: self.params.total_height * self_weight
|
||||
+ other.params.total_height * other_weight,
|
||||
complete_lines: self.params.total_height * self_weight
|
||||
+ other.params.total_height * other_weight,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn mutate(&mut self, rng: &mut SmallRng) {
|
||||
self.params.mutate(rng);
|
||||
}
|
||||
}
|
||||
pub struct GeneticHeuristicAgent {}
|
||||
|
||||
impl Actor for GeneticHeuristicAgent {
|
||||
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
|
||||
let actions = legal_actions
|
||||
.iter()
|
||||
.map(|action| {
|
||||
(
|
||||
action,
|
||||
(self.get_heuristic(game, action) * 1_000_000.0) as usize,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let max_val = actions
|
||||
.iter()
|
||||
.max_by_key(|(_, heuristic)| heuristic)
|
||||
.unwrap()
|
||||
.1;
|
||||
|
||||
*actions
|
||||
.iter()
|
||||
.filter(|e| e.1 == max_val)
|
||||
.collect::<Vec<_>>()
|
||||
.choose(rng)
|
||||
.unwrap()
|
||||
.0
|
||||
fn get_action(
|
||||
&self,
|
||||
rng: &mut SmallRng,
|
||||
state: &super::State,
|
||||
legal_actions: &[crate::game::Action],
|
||||
) -> crate::game::Action {
|
||||
unimplemented!()
|
||||
}
|
||||
fn update(
|
||||
&mut self,
|
||||
state: super::State,
|
||||
action: crate::game::Action,
|
||||
next_state: super::State,
|
||||
next_legal_actions: &[crate::game::Action],
|
||||
reward: f64,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
fn set_learning_rate(&mut self, learning_rate: f64) {
|
||||
unimplemented!()
|
||||
}
|
||||
fn set_exploration_prob(&mut self, exploration_prob: f64) {
|
||||
unimplemented!()
|
||||
}
|
||||
fn set_discount_rate(&mut self, discount_rate: f64) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn dbg(&self) {
|
||||
debug!("{:?}", self.params);
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn train_actor(opts: &Train) -> Box<dyn Actor> {
|
||||
use std::sync::{Arc, Mutex};
|
||||
let rng = Arc::new(Mutex::new(SmallRng::from_entropy()));
|
||||
let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000];
|
||||
let total_pb = Arc::new(Mutex::new(ProgressBar::new(
|
||||
population.len() as u64 * opts.episodes as u64,
|
||||
)));
|
||||
|
||||
for _ in 0..opts.episodes {
|
||||
let mut new_pop_futs = Vec::with_capacity(population.len());
|
||||
let new_population = Arc::new(Mutex::new(Vec::with_capacity(population.len())));
|
||||
let num_rounds = 10;
|
||||
for i in 0..population.len() {
|
||||
let rng = rng.clone();
|
||||
let new_population = new_population.clone();
|
||||
let total_pb = total_pb.clone();
|
||||
let agent = population[i].0;
|
||||
new_pop_futs.push(tokio::task::spawn(async move {
|
||||
let mut fitness = 0;
|
||||
for _ in 0..num_rounds {
|
||||
let mut game = Game::default();
|
||||
game.set_piece_limit(500);
|
||||
game.set_time_limit(std::time::Duration::from_secs(60));
|
||||
while (&game).is_game_over().is_none() {
|
||||
let mut rng = rng.lock().expect("rng failed");
|
||||
super::apply_action_to_game(
|
||||
agent.get_action(
|
||||
&mut rng,
|
||||
&game.clone().into(),
|
||||
&game.get_legal_actions(),
|
||||
),
|
||||
&mut game,
|
||||
);
|
||||
game.tick();
|
||||
}
|
||||
fitness += game.line_clears;
|
||||
}
|
||||
new_population.lock().unwrap().push((agent, fitness));
|
||||
|
||||
total_pb.lock().expect("progressbar failed").inc(1);
|
||||
}));
|
||||
}
|
||||
futures::future::join_all(new_pop_futs).await;
|
||||
|
||||
let mut rng = SmallRng::from_entropy();
|
||||
|
||||
let mut new_population = new_population.lock().unwrap().clone();
|
||||
let new_pop_size = population.len() * 3 / 10;
|
||||
let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> =
|
||||
Vec::with_capacity(new_pop_size);
|
||||
for _ in 0..3 {
|
||||
let mut random_selection = new_population
|
||||
.choose_multiple(&mut rng, new_pop_size / 3)
|
||||
.collect::<Vec<_>>();
|
||||
random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1));
|
||||
let best_two = random_selection.iter().rev().take(2).collect::<Vec<_>>();
|
||||
let parent1 = dbg!(best_two[0]);
|
||||
let parent2 = dbg!(best_two[1]);
|
||||
for _ in 0..new_pop_size / 3 {
|
||||
let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
|
||||
let mut cloned = breeded.clone();
|
||||
if rng.gen::<f64>() < 0.05 {
|
||||
cloned.mutate(&mut rng);
|
||||
}
|
||||
breeded_population.push((cloned, 0));
|
||||
}
|
||||
}
|
||||
|
||||
new_population.sort_unstable_by_key(|e| e.1);
|
||||
new_population.splice(..breeded_population.len(), breeded_population);
|
||||
population = new_population;
|
||||
}
|
||||
total_pb.lock().unwrap().finish();
|
||||
|
||||
population.sort_unstable_by_key(|e| e.1);
|
||||
Box::new(population.iter().rev().next().unwrap().0)
|
||||
}
|
||||
|
|
|
@ -40,7 +40,20 @@ impl From<PlayField> for State {
|
|||
}
|
||||
|
||||
pub trait Actor {
|
||||
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action;
|
||||
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
|
||||
|
||||
fn update(
|
||||
&mut self,
|
||||
state: State,
|
||||
action: Action,
|
||||
next_state: State,
|
||||
next_legal_actions: &[Action],
|
||||
reward: f64,
|
||||
);
|
||||
|
||||
fn set_learning_rate(&mut self, learning_rate: f64);
|
||||
fn set_exploration_prob(&mut self, exploration_prob: f64);
|
||||
fn set_discount_rate(&mut self, discount_rate: f64);
|
||||
|
||||
fn dbg(&self);
|
||||
}
|
||||
|
@ -67,16 +80,3 @@ impl Predictable for Game {
|
|||
game
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply_action_to_game(action: Action, game: &mut Game) {
|
||||
match action {
|
||||
Action::Nothing => (),
|
||||
Action::MoveLeft => game.move_left(),
|
||||
Action::MoveRight => game.move_right(),
|
||||
Action::SoftDrop => game.move_down(),
|
||||
Action::HardDrop => game.hard_drop(),
|
||||
Action::Hold => game.hold(),
|
||||
Action::RotateLeft => game.rotate_left(),
|
||||
Action::RotateRight => game.rotate_right(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,33 +1,16 @@
|
|||
use super::Predictable;
|
||||
use crate::actors::{Actor, State};
|
||||
use crate::{
|
||||
cli::Train,
|
||||
game::{Action, Controllable, Game, Tickable},
|
||||
playfield::{Matrix, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
|
||||
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
|
||||
};
|
||||
use indicatif::ProgressIterator;
|
||||
use log::{debug, info, trace};
|
||||
use log::{debug, info};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
use rand::SeedableRng;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub trait QLearningActor: Actor {
|
||||
fn update(
|
||||
&mut self,
|
||||
game_state: Game,
|
||||
action: Action,
|
||||
next_game_state: Game,
|
||||
next_legal_actions: &[Action],
|
||||
reward: f64,
|
||||
);
|
||||
|
||||
fn set_learning_rate(&mut self, learning_rate: f64);
|
||||
fn set_exploration_prob(&mut self, exploration_prob: f64);
|
||||
fn set_discount_rate(&mut self, discount_rate: f64);
|
||||
}
|
||||
|
||||
pub struct QLearningAgent {
|
||||
pub learning_rate: f64,
|
||||
pub exploration_prob: f64,
|
||||
|
@ -55,32 +38,11 @@ impl QLearningAgent {
|
|||
}
|
||||
|
||||
fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
|
||||
let legal_actions = legal_actions
|
||||
.iter()
|
||||
.map(|action| (action, self.get_q_value(state, *action)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let max_val = legal_actions
|
||||
*legal_actions
|
||||
.iter()
|
||||
.map(|action| (action, self.get_q_value(&state, *action)))
|
||||
.max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
|
||||
.expect("Failed to select an action")
|
||||
.1;
|
||||
|
||||
let actions_to_choose = legal_actions
|
||||
.iter()
|
||||
.filter(|(_, v)| max_val == *v)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if actions_to_choose.len() != 1 {
|
||||
trace!(
|
||||
"more than one best option, choosing randomly: {:?}",
|
||||
actions_to_choose
|
||||
);
|
||||
}
|
||||
|
||||
*actions_to_choose
|
||||
.choose(&mut SmallRng::from_entropy())
|
||||
.unwrap()
|
||||
.0
|
||||
}
|
||||
|
||||
|
@ -101,30 +63,22 @@ impl QLearningAgent {
|
|||
impl Actor for QLearningAgent {
|
||||
// Because doing (Nothing) is in the set of legal actions, this will never
|
||||
// be empty
|
||||
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
|
||||
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
|
||||
if rng.gen::<f64>() < self.exploration_prob {
|
||||
*legal_actions.choose(rng).unwrap()
|
||||
} else {
|
||||
self.get_action_from_q_values(&game.into(), legal_actions)
|
||||
self.get_action_from_q_values(state, legal_actions)
|
||||
}
|
||||
}
|
||||
|
||||
fn dbg(&self) {
|
||||
debug!("Total states: {}", self.q_values.len());
|
||||
}
|
||||
}
|
||||
|
||||
impl QLearningActor for QLearningAgent {
|
||||
fn update(
|
||||
&mut self,
|
||||
game_state: Game,
|
||||
state: State,
|
||||
action: Action,
|
||||
next_game_state: Game,
|
||||
next_state: State,
|
||||
_next_legal_actions: &[Action],
|
||||
reward: f64,
|
||||
) {
|
||||
let state = (&game_state).into();
|
||||
let next_state = (&next_game_state).into();
|
||||
let cur_q_val = self.get_q_value(&state, action);
|
||||
let new_q_val = cur_q_val
|
||||
+ self.learning_rate
|
||||
|
@ -139,6 +93,10 @@ impl QLearningActor for QLearningAgent {
|
|||
.insert(action, new_q_val);
|
||||
}
|
||||
|
||||
fn dbg(&self) {
|
||||
debug!("Total states: {}", self.q_values.len());
|
||||
}
|
||||
|
||||
fn set_learning_rate(&mut self, learning_rate: f64) {
|
||||
self.learning_rate = learning_rate;
|
||||
}
|
||||
|
@ -154,187 +112,126 @@ pub struct ApproximateQLearning {
|
|||
pub learning_rate: f64,
|
||||
pub exploration_prob: f64,
|
||||
pub discount_rate: f64,
|
||||
weights: HashMap<Feature, f64>,
|
||||
weights: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
impl Default for ApproximateQLearning {
|
||||
fn default() -> Self {
|
||||
let mut weights = HashMap::default();
|
||||
weights.insert(Feature::TotalHeight, 1.0);
|
||||
weights.insert(Feature::Bumpiness, 1.0);
|
||||
weights.insert(Feature::LinesCleared, 1.0);
|
||||
weights.insert(Feature::Holes, 1.0);
|
||||
|
||||
Self {
|
||||
learning_rate: 0.0,
|
||||
exploration_prob: 0.0,
|
||||
discount_rate: 0.0,
|
||||
weights,
|
||||
weights: HashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
|
||||
enum Feature {
|
||||
TotalHeight,
|
||||
Bumpiness,
|
||||
LinesCleared,
|
||||
Holes,
|
||||
}
|
||||
|
||||
impl ApproximateQLearning {
|
||||
fn get_features(&self, game: &Game, action: &Action) -> HashMap<Feature, f64> {
|
||||
// let game = game.get_next_state(*action);
|
||||
|
||||
fn get_features(
|
||||
&self,
|
||||
state: &State,
|
||||
_action: &Action,
|
||||
new_state: &State,
|
||||
) -> HashMap<String, f64> {
|
||||
let mut features = HashMap::default();
|
||||
let field = game.playfield().field();
|
||||
let heights = self.get_heights(field);
|
||||
|
||||
let mut heights = [None; PLAYFIELD_WIDTH];
|
||||
for r in 0..PLAYFIELD_HEIGHT {
|
||||
for c in 0..PLAYFIELD_WIDTH {
|
||||
if heights[c].is_none() && state.matrix[r][c].is_some() {
|
||||
heights[c] = Some(PLAYFIELD_HEIGHT - r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
features.insert(
|
||||
Feature::TotalHeight,
|
||||
heights.iter().sum::<usize>() as f64 / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
|
||||
"Total Height".into(),
|
||||
heights
|
||||
.iter()
|
||||
.map(|o| o.unwrap_or_else(|| 0))
|
||||
.sum::<usize>() as f64
|
||||
/ (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
|
||||
);
|
||||
|
||||
features.insert(
|
||||
Feature::Bumpiness,
|
||||
"Bumpiness".into(),
|
||||
heights
|
||||
.iter()
|
||||
.fold((0, 0), |(acc, prev), cur| {
|
||||
(acc + (prev as isize - *cur as isize).abs(), *cur)
|
||||
})
|
||||
.map(|o| o.unwrap_or_else(|| 0) as isize)
|
||||
.fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
|
||||
.0 as f64
|
||||
/ (PLAYFIELD_WIDTH * 40) as f64,
|
||||
);
|
||||
|
||||
features.insert(
|
||||
Feature::LinesCleared,
|
||||
game.playfield()
|
||||
.field()
|
||||
.iter()
|
||||
.map(|r| r.iter().all(Option::is_some))
|
||||
.map(|r| if r { 1 } else { 0 })
|
||||
.sum::<usize>() as f64
|
||||
/ 4.0,
|
||||
"Lines cleared".into(),
|
||||
(new_state.line_clears - state.line_clears) as f64 / 4.0,
|
||||
);
|
||||
|
||||
let mut holes = 0;
|
||||
for r in 1..PLAYFIELD_HEIGHT {
|
||||
for c in 0..PLAYFIELD_WIDTH {
|
||||
if field[r][c].is_none() && field[r - 1][c].is_some() {
|
||||
if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
|
||||
holes += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
features.insert(Feature::Holes, holes as f64);
|
||||
features.insert("Holes".into(), holes as f64);
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
fn get_heights(&self, matrix: &Matrix) -> Vec<usize> {
|
||||
let mut heights = vec![0; matrix[0].len()];
|
||||
for r in 0..matrix.len() {
|
||||
for c in 0..matrix[0].len() {
|
||||
if heights[c] == 0 && matrix[r][c].is_some() {
|
||||
heights[c] = matrix.len() - r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
heights
|
||||
}
|
||||
|
||||
fn get_q_value(&self, game: &Game, action: &Action) -> f64 {
|
||||
self.get_features(game, action)
|
||||
fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
|
||||
self.get_features(state, action, next_state)
|
||||
.iter()
|
||||
.map(|(key, val)| val * *self.weights.get(key).unwrap())
|
||||
.map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0))
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn get_action_from_q_values(&self, game: &Game) -> Action {
|
||||
let legal_actions = game.get_legal_actions();
|
||||
let legal_actions = legal_actions
|
||||
.iter()
|
||||
.map(|action| (action, self.get_q_value(game, action)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let max_val = legal_actions
|
||||
fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
|
||||
*legal_actions
|
||||
.iter()
|
||||
.map(|action| (action, self.get_q_value(&state, action, state)))
|
||||
.max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
|
||||
.expect("Failed to select an action")
|
||||
.1;
|
||||
|
||||
let actions_to_choose = legal_actions
|
||||
.iter()
|
||||
.filter(|(_, v)| max_val == *v)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if actions_to_choose.len() != 1 {
|
||||
trace!(
|
||||
"more than one best option, choosing randomly: {:?}",
|
||||
actions_to_choose
|
||||
);
|
||||
}
|
||||
|
||||
*actions_to_choose
|
||||
.choose(&mut SmallRng::from_entropy())
|
||||
.unwrap()
|
||||
.0
|
||||
}
|
||||
|
||||
fn get_value(&self, game: &Game) -> f64 {
|
||||
game.get_legal_actions()
|
||||
fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
|
||||
legal_actions
|
||||
.iter()
|
||||
.map(|action| self.get_q_value(game, action))
|
||||
.map(|action| self.get_q_value(state, action, state))
|
||||
.max_by_key(|v| (v * 1_000_000.0) as isize)
|
||||
.unwrap_or_else(|| 0.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod aaa {
|
||||
use super::*;
|
||||
use crate::tetromino::TetrominoType;
|
||||
|
||||
#[test]
|
||||
fn test_height() {
|
||||
let agent = ApproximateQLearning::default();
|
||||
let matrix = vec![
|
||||
vec![None, None, None],
|
||||
vec![None, Some(TetrominoType::T), None],
|
||||
vec![None, None, None],
|
||||
];
|
||||
assert_eq!(agent.get_heights(&matrix), vec![0, 2, 0]);
|
||||
}
|
||||
}
|
||||
|
||||
impl Actor for ApproximateQLearning {
|
||||
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
|
||||
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
|
||||
if rng.gen::<f64>() < self.exploration_prob {
|
||||
*legal_actions.choose(rng).unwrap()
|
||||
} else {
|
||||
self.get_action_from_q_values(game)
|
||||
self.get_action_from_q_values(state, legal_actions)
|
||||
}
|
||||
}
|
||||
|
||||
fn dbg(&self) {
|
||||
dbg!(&self.weights);
|
||||
}
|
||||
}
|
||||
|
||||
impl QLearningActor for ApproximateQLearning {
|
||||
fn update(
|
||||
&mut self,
|
||||
game_state: Game,
|
||||
state: State,
|
||||
action: Action,
|
||||
next_game_state: Game,
|
||||
next_state: State,
|
||||
next_legal_actions: &[Action],
|
||||
reward: f64,
|
||||
) {
|
||||
let difference = reward + self.discount_rate * self.get_value(&next_game_state)
|
||||
- self.get_q_value(&game_state, &action);
|
||||
let difference = reward
|
||||
+ self.discount_rate * self.get_value(&next_state, next_legal_actions)
|
||||
- self.get_q_value(&state, &action, &next_state);
|
||||
|
||||
for (feat_key, feat_val) in self.get_features(&game_state, &action) {
|
||||
for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
|
||||
self.weights.insert(
|
||||
feat_key.clone(),
|
||||
*self.weights.get(&feat_key).unwrap() + self.learning_rate * difference * feat_val,
|
||||
*self.weights.get(&feat_key).unwrap_or_else(|| &0.0)
|
||||
+ self.learning_rate * difference * feat_val,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -350,23 +247,15 @@ impl QLearningActor for ApproximateQLearning {
|
|||
fn set_discount_rate(&mut self, discount_rate: f64) {
|
||||
self.discount_rate = discount_rate;
|
||||
}
|
||||
|
||||
fn dbg(&self) {
|
||||
dbg!(&self.weights);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn train_actor<T: 'static + QLearningActor + Actor>(
|
||||
mut actor: T,
|
||||
opts: &Train,
|
||||
) -> Box<dyn Actor> {
|
||||
pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
|
||||
let mut rng = SmallRng::from_entropy();
|
||||
let mut avg = 0.0;
|
||||
let episodes = opts.episodes;
|
||||
|
||||
actor.set_learning_rate(opts.learning_rate);
|
||||
actor.set_discount_rate(opts.discount_rate);
|
||||
actor.set_exploration_prob(opts.exploration_prob);
|
||||
info!(
|
||||
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
|
||||
opts.learning_rate, opts.discount_rate, opts.exploration_prob
|
||||
);
|
||||
|
||||
for i in (0..episodes).progress() {
|
||||
if i != 0 && i % (episodes / 10) == 0 {
|
||||
|
@ -376,17 +265,26 @@ pub fn train_actor<T: 'static + QLearningActor + Actor>(
|
|||
}
|
||||
let mut game = Game::default();
|
||||
while (&game).is_game_over().is_none() {
|
||||
let cur_state = game.clone();
|
||||
let cur_state = (&game).into();
|
||||
let cur_score = game.score();
|
||||
let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
|
||||
|
||||
super::apply_action_to_game(action, &mut game);
|
||||
match action {
|
||||
Action::Nothing => (),
|
||||
Action::MoveLeft => game.move_left(),
|
||||
Action::MoveRight => game.move_right(),
|
||||
Action::SoftDrop => game.move_down(),
|
||||
Action::HardDrop => game.hard_drop(),
|
||||
Action::Hold => game.hold(),
|
||||
Action::RotateLeft => game.rotate_left(),
|
||||
Action::RotateRight => game.rotate_right(),
|
||||
}
|
||||
|
||||
let new_state = game.clone();
|
||||
let new_state = (&game).into();
|
||||
let mut reward = game.score() as f64 - cur_score as f64;
|
||||
// if action != Action::Nothing {
|
||||
// reward -= 0.0;
|
||||
// }
|
||||
if action != Action::Nothing {
|
||||
reward -= 0.0;
|
||||
}
|
||||
|
||||
if game.is_game_over().is_some() {
|
||||
reward = -1.0;
|
||||
|
@ -394,13 +292,7 @@ pub fn train_actor<T: 'static + QLearningActor + Actor>(
|
|||
|
||||
let new_legal_actions = game.get_legal_actions();
|
||||
|
||||
actor.update(
|
||||
cur_state.into(),
|
||||
action,
|
||||
new_state,
|
||||
&new_legal_actions,
|
||||
reward,
|
||||
);
|
||||
actor.update(cur_state, action, new_state, &new_legal_actions, reward);
|
||||
|
||||
game.tick();
|
||||
}
|
||||
|
@ -408,12 +300,5 @@ pub fn train_actor<T: 'static + QLearningActor + Actor>(
|
|||
avg += game.score() as f64 / (episodes / 10) as f64;
|
||||
}
|
||||
|
||||
if opts.no_explore_during_evaluation {
|
||||
actor.set_exploration_prob(0.0);
|
||||
}
|
||||
|
||||
if opts.no_learn_during_evaluation {
|
||||
actor.set_learning_rate(0.0);
|
||||
}
|
||||
Box::new(actor)
|
||||
actor
|
||||
}
|
||||
|
|
|
@ -74,7 +74,6 @@ arg_enum! {
|
|||
pub enum Agent {
|
||||
QLearning,
|
||||
ApproximateQLearning,
|
||||
HeuristicGenetic
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,3 +89,10 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
|
||||
match agent {
|
||||
Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
|
||||
Agent::ApproximateQLearning => Box::new(qlearning::ApproximateQLearning::default()),
|
||||
}
|
||||
}
|
||||
|
|
57
src/game.rs
57
src/game.rs
|
@ -16,8 +16,6 @@ const LINE_CLEAR_DELAY: u64 = TICKS_PER_SECOND as u64 * 41 / 60;
|
|||
pub enum LossReason {
|
||||
TopOut,
|
||||
LockOut,
|
||||
PieceLimitReached,
|
||||
TickLimitReached,
|
||||
BlockOut(Position),
|
||||
}
|
||||
|
||||
|
@ -37,12 +35,6 @@ pub struct Game {
|
|||
/// bonus is needed.
|
||||
last_clear_action: ClearAction,
|
||||
pub line_clears: u32,
|
||||
// used if we set a limit on how long a game can last.
|
||||
pieces_placed: usize,
|
||||
piece_limit: usize,
|
||||
|
||||
// used if we set a limit on how long the game can be played.
|
||||
tick_limit: u64,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Game {
|
||||
|
@ -67,9 +59,6 @@ impl Default for Game {
|
|||
is_game_over: None,
|
||||
last_clear_action: ClearAction::Single, // Doesn't matter what it's initialized to
|
||||
line_clears: 0,
|
||||
pieces_placed: 0,
|
||||
piece_limit: 0,
|
||||
tick_limit: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -84,7 +73,6 @@ impl Tickable for Game {
|
|||
return;
|
||||
}
|
||||
self.tick += 1;
|
||||
|
||||
match self.tick {
|
||||
t if t == self.next_spawn_tick => {
|
||||
trace!("Spawn tick was met, spawning new Tetromino!");
|
||||
|
@ -111,10 +99,6 @@ impl Tickable for Game {
|
|||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
if self.tick == self.tick_limit {
|
||||
self.is_game_over = Some(LossReason::TickLimitReached);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -184,18 +168,10 @@ impl Game {
|
|||
// It's possible that the player moved the piece in the meantime.
|
||||
if !self.playfield.can_active_piece_move_down() {
|
||||
let positions = self.playfield.lock_active_piece();
|
||||
if self.pieces_placed < self.piece_limit {
|
||||
self.pieces_placed += 1;
|
||||
if self.pieces_placed >= self.piece_limit {
|
||||
trace!("Loss due to piece limit!");
|
||||
self.is_game_over = Some(LossReason::PieceLimitReached);
|
||||
}
|
||||
}
|
||||
|
||||
self.is_game_over = self.is_game_over.or_else(|| {
|
||||
if positions.iter().map(|p| p.y).all(|y| y < 20) {
|
||||
trace!("Loss due to lockout! {:?}", positions);
|
||||
Some(LossReason::LockOut)
|
||||
trace!("Loss due to topout! {:?}", positions);
|
||||
Some(LossReason::TopOut)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -207,7 +183,6 @@ impl Game {
|
|||
self.line_clears += cleared_lines as u32;
|
||||
self.score += (cleared_lines * 100 * self.level as usize) as u32;
|
||||
self.level = (self.line_clears / 10) as u8;
|
||||
|
||||
self.playfield.active_piece = None;
|
||||
self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
|
||||
} else {
|
||||
|
@ -220,17 +195,9 @@ impl Game {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn set_piece_limit(&mut self, size: usize) {
|
||||
self.piece_limit = size;
|
||||
}
|
||||
|
||||
pub fn playfield(&self) -> &PlayField {
|
||||
&self.playfield
|
||||
}
|
||||
|
||||
pub fn set_time_limit(&mut self, duration: std::time::Duration) {
|
||||
self.tick_limit = duration.as_secs() * TICKS_PER_SECOND as u64;
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Controllable {
|
||||
|
@ -244,7 +211,7 @@ pub trait Controllable {
|
|||
fn get_legal_actions(&self) -> Vec<Action>;
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)]
|
||||
pub enum Action {
|
||||
Nothing, // Default value
|
||||
MoveLeft,
|
||||
|
@ -262,7 +229,8 @@ impl Controllable for Game {
|
|||
return;
|
||||
}
|
||||
|
||||
if self.playfield.move_offset(-1, 0) && !self.playfield.can_active_piece_move_down() {
|
||||
self.playfield.move_offset(-1, 0);
|
||||
if !self.playfield.can_active_piece_move_down() {
|
||||
self.update_lock_tick();
|
||||
}
|
||||
}
|
||||
|
@ -272,7 +240,8 @@ impl Controllable for Game {
|
|||
return;
|
||||
}
|
||||
|
||||
if self.playfield.move_offset(1, 0) && !self.playfield.can_active_piece_move_down() {
|
||||
self.playfield.move_offset(1, 0);
|
||||
if !self.playfield.can_active_piece_move_down() {
|
||||
self.update_lock_tick();
|
||||
}
|
||||
}
|
||||
|
@ -300,7 +269,7 @@ impl Controllable for Game {
|
|||
active_piece.position = active_piece.position.offset(x, y);
|
||||
active_piece.rotate_left();
|
||||
self.playfield.active_piece = Some(active_piece);
|
||||
self.update_lock_tick();
|
||||
// self.update_lock_tick();
|
||||
}
|
||||
Err(_) => (),
|
||||
}
|
||||
|
@ -317,7 +286,7 @@ impl Controllable for Game {
|
|||
active_piece.position = active_piece.position.offset(x, y);
|
||||
active_piece.rotate_right();
|
||||
self.playfield.active_piece = Some(active_piece);
|
||||
self.update_lock_tick();
|
||||
// self.update_lock_tick();
|
||||
}
|
||||
Err(_) => (),
|
||||
}
|
||||
|
@ -363,13 +332,13 @@ impl Controllable for Game {
|
|||
|
||||
fn get_legal_actions(&self) -> Vec<Action> {
|
||||
let mut legal_actions = vec![
|
||||
Action::RotateLeft,
|
||||
Action::RotateRight,
|
||||
Action::SoftDrop,
|
||||
Action::HardDrop,
|
||||
Action::Nothing,
|
||||
Action::MoveLeft,
|
||||
Action::MoveRight,
|
||||
Action::SoftDrop,
|
||||
Action::HardDrop,
|
||||
Action::RotateLeft,
|
||||
Action::RotateRight,
|
||||
];
|
||||
|
||||
if self.playfield.can_swap() {
|
||||
|
|
285
src/main.rs
285
src/main.rs
|
@ -4,7 +4,9 @@ use cli::*;
|
|||
use game::{Action, Controllable, Game, Tickable};
|
||||
use graphics::standard_renderer;
|
||||
use graphics::COLOR_BACKGROUND;
|
||||
use indicatif::ProgressIterator;
|
||||
use log::{debug, info, trace};
|
||||
use qlearning::train_actor;
|
||||
use rand::SeedableRng;
|
||||
use sdl2::event::Event;
|
||||
use sdl2::keyboard::Keycode;
|
||||
|
@ -22,33 +24,42 @@ mod tetromino;
|
|||
|
||||
const TICKS_PER_SECOND: usize = 60;
|
||||
|
||||
#[tokio::main(core_threads = 16)]
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let opts = crate::cli::Opts::parse();
|
||||
|
||||
init_verbosity(&opts)?;
|
||||
let mut actor = None;
|
||||
|
||||
let agent = match opts.subcmd {
|
||||
SubCommand::Play(_) => None,
|
||||
SubCommand::Train(sub_opts) => Some(match sub_opts.agent {
|
||||
Agent::QLearning => {
|
||||
qlearning::train_actor(qlearning::QLearningAgent::default(), &sub_opts)
|
||||
}
|
||||
Agent::ApproximateQLearning => {
|
||||
let agent =
|
||||
qlearning::train_actor(qlearning::ApproximateQLearning::default(), &sub_opts);
|
||||
agent.dbg();
|
||||
agent
|
||||
}
|
||||
Agent::HeuristicGenetic => {
|
||||
let agent = genetic::train_actor(&sub_opts).await;
|
||||
agent.dbg();
|
||||
agent
|
||||
}
|
||||
}),
|
||||
};
|
||||
match opts.subcmd {
|
||||
SubCommand::Play(sub_opts) => {}
|
||||
SubCommand::Train(sub_opts) => {
|
||||
let mut to_train = get_actor(sub_opts.agent);
|
||||
to_train.set_learning_rate(sub_opts.learning_rate);
|
||||
to_train.set_discount_rate(sub_opts.discount_rate);
|
||||
to_train.set_exploration_prob(sub_opts.exploration_prob);
|
||||
|
||||
play_game(agent).await
|
||||
info!(
|
||||
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
|
||||
sub_opts.learning_rate,
|
||||
sub_opts.discount_rate,
|
||||
sub_opts.exploration_prob
|
||||
);
|
||||
let mut trained_actor = train_actor(sub_opts.episodes, to_train);
|
||||
if sub_opts.no_explore_during_evaluation {
|
||||
trained_actor.set_exploration_prob(0.0);
|
||||
}
|
||||
|
||||
if sub_opts.no_learn_during_evaluation {
|
||||
trained_actor.set_learning_rate(0.0);
|
||||
}
|
||||
|
||||
actor = Some(trained_actor);
|
||||
}
|
||||
}
|
||||
|
||||
play_game(actor).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
@ -61,127 +72,125 @@ async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std:
|
|||
.build()?;
|
||||
let mut canvas = window.into_canvas().build()?;
|
||||
let mut event_pump = sdl_context.event_pump()?;
|
||||
let mut game = Game::default();
|
||||
let mut interval = interval(Duration::from_millis(1000 / TICKS_PER_SECOND as u64));
|
||||
|
||||
'escape: loop {
|
||||
let mut game = Game::default();
|
||||
|
||||
loop {
|
||||
match game.is_game_over() {
|
||||
Some(e) => {
|
||||
println!("Lost due to: {:?}", e);
|
||||
break;
|
||||
}
|
||||
None => (),
|
||||
'running: loop {
|
||||
match game.is_game_over() {
|
||||
Some(e) => {
|
||||
println!("Lost due to: {:?}", e);
|
||||
break;
|
||||
}
|
||||
|
||||
let cur_state = game.clone();
|
||||
|
||||
// If there's an actor, the player action will get overridden. If not,
|
||||
// then then the player action falls through, if there is one. This is
|
||||
// to allow for restarting and quitting the game from the GUI.
|
||||
let mut action = None;
|
||||
for event in event_pump.poll_iter() {
|
||||
match event {
|
||||
Event::Quit { .. }
|
||||
| Event::KeyDown {
|
||||
keycode: Some(Keycode::Escape),
|
||||
..
|
||||
} => {
|
||||
debug!("Escape registered");
|
||||
break 'escape Ok(());
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Left),
|
||||
..
|
||||
} => {
|
||||
debug!("Move left registered");
|
||||
action = Some(Action::MoveLeft);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Right),
|
||||
..
|
||||
} => {
|
||||
debug!("Move right registered");
|
||||
action = Some(Action::MoveRight);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Down),
|
||||
..
|
||||
} => {
|
||||
debug!("Soft drop registered");
|
||||
action = Some(Action::SoftDrop);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Z),
|
||||
..
|
||||
} => {
|
||||
debug!("Rotate left registered");
|
||||
action = Some(Action::RotateLeft);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::X),
|
||||
..
|
||||
} => {
|
||||
debug!("Rotate right registered");
|
||||
action = Some(Action::RotateRight);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Space),
|
||||
..
|
||||
}
|
||||
| Event::KeyDown {
|
||||
keycode: Some(Keycode::Up),
|
||||
..
|
||||
} => {
|
||||
debug!("Hard drop registered");
|
||||
action = Some(Action::HardDrop);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::LShift),
|
||||
..
|
||||
} => {
|
||||
debug!("Hold registered");
|
||||
action = Some(Action::Hold);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::R),
|
||||
..
|
||||
} => {
|
||||
info!("Restarting game");
|
||||
game = Game::default();
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(e), ..
|
||||
} => trace!("Ignoring keycode {}", e),
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
actor.as_mut().map(|actor| {
|
||||
action =
|
||||
Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
|
||||
});
|
||||
|
||||
action.map(|action| match action {
|
||||
Action::Nothing => (),
|
||||
Action::MoveLeft => game.move_left(),
|
||||
Action::MoveRight => game.move_right(),
|
||||
Action::SoftDrop => game.move_down(),
|
||||
Action::HardDrop => game.hard_drop(),
|
||||
Action::Hold => game.hold(),
|
||||
Action::RotateLeft => game.rotate_left(),
|
||||
Action::RotateRight => game.rotate_right(),
|
||||
});
|
||||
|
||||
game.tick();
|
||||
canvas.set_draw_color(COLOR_BACKGROUND);
|
||||
canvas.clear();
|
||||
standard_renderer::render(&mut canvas, &game);
|
||||
canvas.present();
|
||||
interval.tick().await;
|
||||
None => (),
|
||||
}
|
||||
|
||||
info!("Final score: {}", game.score());
|
||||
let cur_state = (&game).into();
|
||||
|
||||
// If there's an actor, the player action will get overridden. If not,
|
||||
// then then the player action falls through, if there is one. This is
|
||||
// to allow for restarting and quitting the game from the GUI.
|
||||
let mut action = None;
|
||||
for event in event_pump.poll_iter() {
|
||||
match event {
|
||||
Event::Quit { .. }
|
||||
| Event::KeyDown {
|
||||
keycode: Some(Keycode::Escape),
|
||||
..
|
||||
} => {
|
||||
debug!("Escape registered");
|
||||
break 'running;
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Left),
|
||||
..
|
||||
} => {
|
||||
debug!("Move left registered");
|
||||
action = Some(Action::MoveLeft);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Right),
|
||||
..
|
||||
} => {
|
||||
debug!("Move right registered");
|
||||
action = Some(Action::MoveRight);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Down),
|
||||
..
|
||||
} => {
|
||||
debug!("Soft drop registered");
|
||||
action = Some(Action::SoftDrop);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Z),
|
||||
..
|
||||
} => {
|
||||
debug!("Rotate left registered");
|
||||
action = Some(Action::RotateLeft);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::X),
|
||||
..
|
||||
} => {
|
||||
debug!("Rotate right registered");
|
||||
action = Some(Action::RotateRight);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::Space),
|
||||
..
|
||||
}
|
||||
| Event::KeyDown {
|
||||
keycode: Some(Keycode::Up),
|
||||
..
|
||||
} => {
|
||||
debug!("Hard drop registered");
|
||||
action = Some(Action::HardDrop);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::LShift),
|
||||
..
|
||||
} => {
|
||||
debug!("Hold registered");
|
||||
action = Some(Action::Hold);
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(Keycode::R),
|
||||
..
|
||||
} => {
|
||||
info!("Restarting game");
|
||||
game = Game::default();
|
||||
}
|
||||
Event::KeyDown {
|
||||
keycode: Some(e), ..
|
||||
} => trace!("Ignoring keycode {}", e),
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
actor.as_mut().map(|actor| {
|
||||
action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
|
||||
});
|
||||
|
||||
action.map(|action| match action {
|
||||
Action::Nothing => (),
|
||||
Action::MoveLeft => game.move_left(),
|
||||
Action::MoveRight => game.move_right(),
|
||||
Action::SoftDrop => game.move_down(),
|
||||
Action::HardDrop => game.hard_drop(),
|
||||
Action::Hold => game.hold(),
|
||||
Action::RotateLeft => game.rotate_left(),
|
||||
Action::RotateRight => game.rotate_right(),
|
||||
});
|
||||
|
||||
game.tick();
|
||||
canvas.set_draw_color(COLOR_BACKGROUND);
|
||||
canvas.clear();
|
||||
standard_renderer::render(&mut canvas, &game);
|
||||
canvas.present();
|
||||
interval.tick().await;
|
||||
}
|
||||
|
||||
info!("Final score: {}", game.score());
|
||||
actor.map(|a| a.dbg());
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue