Compare commits: a65f48f585...5a9e2538aa

Commits (newest first): 5a9e2538aa, deb74da552, 710a7dbebb
9 changed files with 732 additions and 294 deletions
.gitignore (vendored): 2 changes

@@ -1,2 +1,4 @@
 /target
 **/*.rs.bk
+flamegraph.svg
+perf.*
Cargo.lock (generated): 104 changes

@@ -150,11 +150,88 @@ name = "fuchsia-zircon-sys"
 version = "0.3.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 
+[[package]]
+name = "futures"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
+name = "futures-channel"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
 [[package]]
 name = "futures-core"
 version = "0.3.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
+
+[[package]]
+name = "futures-executor"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
+[[package]]
+name = "futures-macro"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)",
+ "proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)",
+ "quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
+name = "futures-sink"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
+[[package]]
+name = "futures-task"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
+[[package]]
+name = "futures-util"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)",
+ "proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
 [[package]]
 name = "getrandom"
 version = "0.1.13"

@@ -345,6 +422,11 @@ name = "pin-project-lite"
 version = "0.1.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 
+[[package]]
+name = "pin-utils"
+version = "0.1.0-alpha.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.6"

@@ -374,6 +456,16 @@ dependencies = [
 "version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
+[[package]]
+name = "proc-macro-hack"
+version = "0.5.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
+[[package]]
+name = "proc-macro-nested"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
 [[package]]
 name = "proc-macro2"
 version = "1.0.9"

@@ -549,6 +641,7 @@ name = "tetris"
 version = "0.1.0"
 dependencies = [
  "clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)",
+ "futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
  "indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
  "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",

@@ -695,7 +788,15 @@ dependencies = [
 "checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
 "checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82"
 "checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
+"checksum futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5c329ae8753502fb44ae4fc2b622fa2a94652c41e795143765ba0927f92ab780"
+"checksum futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f0c77d04ce8edd9cb903932b608268b3fffec4163dc053b3b402bf47eac1f1a8"
 "checksum futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f25592f769825e89b92358db00d26f965761e094951ac44d3663ef25b7ac464a"
+"checksum futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f674f3e1bcb15b37284a90cedf55afdba482ab061c407a9c0ebbd0f3109741ba"
+"checksum futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "a638959aa96152c7a4cddf50fcb1e3fede0583b27157c26e67d6f99904090dc6"
+"checksum futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "9a5081aa3de1f7542a794a397cde100ed903b0630152d0973479018fd85423a7"
+"checksum futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "3466821b4bc114d95b087b850a724c6f83115e929bc88f1fa98a3304a944c8a6"
+"checksum futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "7b0a34e53cf6cdcd0178aa573aed466b646eb3db769570841fda0c7ede375a27"
+"checksum futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "22766cf25d64306bedf0384da004d05c9974ab104fcc4528f1236181c18004c5"
 "checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
 "checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
 "checksum hermit-abi 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1010591b26bbfe835e9faeabeb11866061cc7dcebffd56ad7d0942d0e61aefd8"

@@ -718,9 +819,12 @@ dependencies = [
 "checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6"
 "checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
 "checksum pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae"
+"checksum pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5894c618ce612a3fa23881b152b608bafb8c56cfc22f434a3ba3120b40f7b587"
 "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
 "checksum proc-macro-error 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7"
 "checksum proc-macro-error-attr 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de"
+"checksum proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)" = "0d659fe7c6d27f25e9d80a1a094c223f5246f6a6596453e09d7229bf42750b63"
+"checksum proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "8e946095f9d3ed29ec38de908c22f95d9ac008e424c7bcae54c75a79c527c694"
 "checksum proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)" = "6c09721c6781493a2a492a96b5a5bf19b65917fe6728884e7c44dd0c60ca3435"
 "checksum quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2bdc6c187c65bca4260c9011c9e3132efe4909da44726bad24cf7572ae338d7f"
 "checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
Cargo.toml

@@ -13,4 +13,5 @@ log = "0.4"
 simple_logger = "1.6"
 sdl2 = { version = "0.33.0", features = ["ttf"] }
 clap = { git = "https://github.com/clap-rs/clap/", features = ["color"] }
 indicatif = "0.14"
+futures = "0.3"
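The new `futures` dependency is there to back the `futures::future::join_all` call in the genetic trainer further down, which waits on a batch of `tokio::task::spawn` handles. A minimal, self-contained sketch of that fan-out/join pattern (the squaring task is a placeholder, not project code):

```rust
use futures::future::join_all;

#[tokio::main]
async fn main() {
    // Spawn independent tasks and collect their JoinHandles...
    let handles: Vec<_> = (0..4u64)
        .map(|i| tokio::task::spawn(async move { i * i }))
        .collect();

    // ...then resolve them all at once; join_all completes when every task has.
    let results: Vec<u64> = join_all(handles)
        .await
        .into_iter()
        .map(|r| r.expect("task panicked"))
        .collect();

    assert_eq!(results, vec![0, 1, 4, 9]);
}
```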
@@ -1,9 +1,17 @@
 // https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
 
-use super::Actor;
+use super::{Actor, Predictable, State};
+use crate::{
+    cli::Train,
+    game::{Action, Controllable, Game, Tickable},
+    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
+};
+use indicatif::ProgressBar;
+use log::debug;
 use rand::rngs::SmallRng;
-use rand::Rng;
+use rand::{seq::SliceRandom, Rng, SeedableRng};
 
+#[derive(Copy, Clone, Debug)]
 pub struct Parameters {
     total_height: f64,
     bumpiness: f64,

@@ -11,6 +19,17 @@ pub struct Parameters {
     complete_lines: f64,
 }
 
+impl Default for Parameters {
+    fn default() -> Self {
+        Self {
+            total_height: 1.0,
+            bumpiness: 1.0,
+            holes: 1.0,
+            complete_lines: 1.0,
+        }
+    }
+}
+
 impl Parameters {
     fn mutate(mut self, rng: &mut SmallRng) {
         let mutation_amt = rng.gen_range(-0.2, 0.2);

@@ -33,39 +52,220 @@ impl Parameters {
         self.holes /= normalization_factor;
         self.complete_lines /= normalization_factor;
     }
+
+    fn dot_multiply(&self, other: &Self) -> f64 {
+        self.total_height * other.total_height
+            + self.bumpiness * other.bumpiness
+            + self.holes * other.holes
+            + self.complete_lines * other.complete_lines
+    }
 }
 
-pub struct GeneticHeuristicAgent {}
+#[derive(Clone, Copy, Debug)]
+pub struct GeneticHeuristicAgent {
+    params: Parameters,
+}
+
+impl Default for GeneticHeuristicAgent {
+    fn default() -> Self {
+        let mut rng = SmallRng::from_entropy();
+        Self {
+            params: Parameters {
+                total_height: rng.gen::<f64>(),
+                bumpiness: rng.gen::<f64>(),
+                holes: rng.gen::<f64>(),
+                complete_lines: rng.gen::<f64>(),
+            },
+        }
+    }
+}
+
+impl GeneticHeuristicAgent {
+    fn extract_features_from_state(state: &State) -> Parameters {
+        let mut heights = [None; PLAYFIELD_WIDTH];
+        for r in 0..PLAYFIELD_HEIGHT {
+            for c in 0..PLAYFIELD_WIDTH {
+                if heights[c].is_none() && state.matrix[r][c].is_some() {
+                    heights[c] = Some(PLAYFIELD_HEIGHT - r);
+                }
+            }
+        }
+
+        let total_height = heights
+            .iter()
+            .map(|o| o.unwrap_or_else(|| 0))
+            .sum::<usize>() as f64;
+
+        let bumpiness = heights
+            .iter()
+            .map(|o| o.unwrap_or_else(|| 0) as isize)
+            .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
+            .0 as f64;
+
+        let complete_lines = state
+            .matrix
+            .iter()
+            .map(|row| row.iter().all(Option::is_some))
+            .map(|c| if c { 1.0 } else { 0.0 })
+            .sum::<f64>();
+
+        let mut holes = 0;
+        for r in 1..PLAYFIELD_HEIGHT {
+            for c in 0..PLAYFIELD_WIDTH {
+                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
+                    holes += 1;
+                }
+            }
+        }
+
+        Parameters {
+            total_height,
+            bumpiness,
+            complete_lines,
+            holes: holes as f64,
+        }
+    }
+
+    fn get_heuristic(&self, game: &Game, action: &Action) -> f64 {
+        self.params.dot_multiply(&Self::extract_features_from_state(
+            &game.get_next_state(*action).into(),
+        ))
+    }
+
+    pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self {
+        let weight = self_fitness + other_fitness;
+
+        if weight != 0 {
+            let self_weight = self_fitness as f64 / weight as f64;
+            let other_weight = other_fitness as f64 / weight as f64;
+            Self {
+                params: Parameters {
+                    total_height: self.params.total_height * self_weight
+                        + other.params.total_height * other_weight,
+                    bumpiness: self.params.total_height * self_weight
+                        + other.params.total_height * other_weight,
+                    holes: self.params.total_height * self_weight
+                        + other.params.total_height * other_weight,
+                    complete_lines: self.params.total_height * self_weight
+                        + other.params.total_height * other_weight,
+                },
+            }
+        } else {
+            Self::default()
+        }
+    }
+
+    fn mutate(&mut self, rng: &mut SmallRng) {
+        self.params.mutate(rng);
+    }
+}
 
 impl Actor for GeneticHeuristicAgent {
-    fn get_action(
-        &self,
-        rng: &mut SmallRng,
-        state: &super::State,
-        legal_actions: &[crate::game::Action],
-    ) -> crate::game::Action {
-        unimplemented!()
-    }
-    fn update(
-        &mut self,
-        state: super::State,
-        action: crate::game::Action,
-        next_state: super::State,
-        next_legal_actions: &[crate::game::Action],
-        reward: f64,
-    ) {
-        unimplemented!()
-    }
-    fn set_learning_rate(&mut self, learning_rate: f64) {
-        unimplemented!()
-    }
-    fn set_exploration_prob(&mut self, exploration_prob: f64) {
-        unimplemented!()
-    }
-    fn set_discount_rate(&mut self, discount_rate: f64) {
-        unimplemented!()
+    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
+        let actions = legal_actions
+            .iter()
+            .map(|action| {
+                (
+                    action,
+                    (self.get_heuristic(game, action) * 1_000_000.0) as usize,
+                )
+            })
+            .collect::<Vec<_>>();
+
+        let max_val = actions
+            .iter()
+            .max_by_key(|(_, heuristic)| heuristic)
+            .unwrap()
+            .1;
+
+        *actions
+            .iter()
+            .filter(|e| e.1 == max_val)
+            .collect::<Vec<_>>()
+            .choose(rng)
+            .unwrap()
+            .0
     }
 
     fn dbg(&self) {
-        unimplemented!()
+        debug!("{:?}", self.params);
     }
 }
+
+pub async fn train_actor(opts: &Train) -> Box<dyn Actor> {
+    use std::sync::{Arc, Mutex};
+    let rng = Arc::new(Mutex::new(SmallRng::from_entropy()));
+    let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000];
+    let total_pb = Arc::new(Mutex::new(ProgressBar::new(
+        population.len() as u64 * opts.episodes as u64,
+    )));
+
+    for _ in 0..opts.episodes {
+        let mut new_pop_futs = Vec::with_capacity(population.len());
+        let new_population = Arc::new(Mutex::new(Vec::with_capacity(population.len())));
+        let num_rounds = 10;
+        for i in 0..population.len() {
+            let rng = rng.clone();
+            let new_population = new_population.clone();
+            let total_pb = total_pb.clone();
+            let agent = population[i].0;
+            new_pop_futs.push(tokio::task::spawn(async move {
+                let mut fitness = 0;
+                for _ in 0..num_rounds {
+                    let mut game = Game::default();
+                    game.set_piece_limit(500);
+                    game.set_time_limit(std::time::Duration::from_secs(60));
+                    while (&game).is_game_over().is_none() {
+                        let mut rng = rng.lock().expect("rng failed");
+                        super::apply_action_to_game(
+                            agent.get_action(
+                                &mut rng,
+                                &game.clone().into(),
+                                &game.get_legal_actions(),
+                            ),
+                            &mut game,
+                        );
+                        game.tick();
+                    }
+                    fitness += game.line_clears;
+                }
+                new_population.lock().unwrap().push((agent, fitness));
+
+                total_pb.lock().expect("progressbar failed").inc(1);
+            }));
+        }
+        futures::future::join_all(new_pop_futs).await;
+
+        let mut rng = SmallRng::from_entropy();
+
+        let mut new_population = new_population.lock().unwrap().clone();
+        let new_pop_size = population.len() * 3 / 10;
+        let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> =
+            Vec::with_capacity(new_pop_size);
+        for _ in 0..3 {
+            let mut random_selection = new_population
+                .choose_multiple(&mut rng, new_pop_size / 3)
+                .collect::<Vec<_>>();
+            random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1));
+            let best_two = random_selection.iter().rev().take(2).collect::<Vec<_>>();
+            let parent1 = dbg!(best_two[0]);
+            let parent2 = dbg!(best_two[1]);
+            for _ in 0..new_pop_size / 3 {
+                let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
+                let mut cloned = breeded.clone();
+                if rng.gen::<f64>() < 0.05 {
+                    cloned.mutate(&mut rng);
+                }
+                breeded_population.push((cloned, 0));
+            }
+        }
+
+        new_population.sort_unstable_by_key(|e| e.1);
+        new_population.splice(..breeded_population.len(), breeded_population);
+        population = new_population;
+    }
+    total_pb.lock().unwrap().finish();
+
+    population.sort_unstable_by_key(|e| e.1);
+    Box::new(population.iter().rev().next().unwrap().0)
+}
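Two things worth flagging in the genetic-agent hunk above: `Parameters::mutate` takes `mut self` by value (and `Parameters` is `Copy`), so the mutation never reaches the caller's copy, and every field produced by `breed` is blended from `total_height` rather than from its own counterpart. A small self-contained sketch of the fitness-weighted, per-field crossover the surrounding code appears to intend (the `Parameters` struct is copied from this diff; `blend` is a hypothetical helper):

```rust
// Mirrors the Parameters struct introduced in this diff.
#[derive(Copy, Clone, Debug)]
struct Parameters {
    total_height: f64,
    bumpiness: f64,
    holes: f64,
    complete_lines: f64,
}

// Hypothetical helper: fitness-weighted blend of one pair of values.
fn blend(a: f64, b: f64, wa: f64, wb: f64) -> f64 {
    a * wa + b * wb
}

// Per-field crossover weighted by each parent's fitness. The committed code
// instead falls back to a fresh random agent when both fitnesses are zero;
// this sketch just splits the weights evenly in that case.
fn crossover(p1: &Parameters, f1: u32, p2: &Parameters, f2: u32) -> Parameters {
    let total = (f1 + f2) as f64;
    let (w1, w2) = if total > 0.0 {
        (f1 as f64 / total, f2 as f64 / total)
    } else {
        (0.5, 0.5)
    };
    Parameters {
        total_height: blend(p1.total_height, p2.total_height, w1, w2),
        bumpiness: blend(p1.bumpiness, p2.bumpiness, w1, w2),
        holes: blend(p1.holes, p2.holes, w1, w2),
        complete_lines: blend(p1.complete_lines, p2.complete_lines, w1, w2),
    }
}
```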
@@ -40,20 +40,7 @@ impl From<PlayField> for State {
 }
 
 pub trait Actor {
-    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
-
-    fn update(
-        &mut self,
-        state: State,
-        action: Action,
-        next_state: State,
-        next_legal_actions: &[Action],
-        reward: f64,
-    );
-
-    fn set_learning_rate(&mut self, learning_rate: f64);
-    fn set_exploration_prob(&mut self, exploration_prob: f64);
-    fn set_discount_rate(&mut self, discount_rate: f64);
+    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action;
 
     fn dbg(&self);
 }

@@ -80,3 +67,16 @@ impl Predictable for Game {
         game
     }
 }
+
+pub fn apply_action_to_game(action: Action, game: &mut Game) {
+    match action {
+        Action::Nothing => (),
+        Action::MoveLeft => game.move_left(),
+        Action::MoveRight => game.move_right(),
+        Action::SoftDrop => game.move_down(),
+        Action::HardDrop => game.hard_drop(),
+        Action::Hold => game.hold(),
+        Action::RotateLeft => game.rotate_left(),
+        Action::RotateRight => game.rotate_right(),
+    }
+}
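With the slimmed-down `Actor` trait, both trainers now drive an agent the same way: ask `get_action` with the whole `Game`, apply the result through `apply_action_to_game`, then `tick`. A condensed sketch of that loop, assuming the crate-local `Game`, `Actor`, and `apply_action_to_game` items from this diff are in scope (so it is not standalone-compilable):

```rust
use rand::{rngs::SmallRng, SeedableRng};

// Assumes Game, Actor, and apply_action_to_game from this diff are in scope.
fn run_one_episode(actor: &dyn Actor) -> u32 {
    let mut rng = SmallRng::from_entropy();
    let mut game = Game::default();

    while game.is_game_over().is_none() {
        // The trait now receives the full Game instead of a pre-extracted State.
        let action = actor.get_action(&mut rng, &game, &game.get_legal_actions());
        apply_action_to_game(action, &mut game);
        game.tick();
    }

    // The genetic trainer uses line clears as its fitness signal.
    game.line_clears
}
```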
@@ -1,16 +1,33 @@
+use super::Predictable;
 use crate::actors::{Actor, State};
 use crate::{
+    cli::Train,
     game::{Action, Controllable, Game, Tickable},
-    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
+    playfield::{Matrix, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
 };
 use indicatif::ProgressIterator;
-use log::{debug, info};
+use log::{debug, info, trace};
 use rand::rngs::SmallRng;
 use rand::seq::SliceRandom;
 use rand::Rng;
 use rand::SeedableRng;
 use std::collections::HashMap;
 
+pub trait QLearningActor: Actor {
+    fn update(
+        &mut self,
+        game_state: Game,
+        action: Action,
+        next_game_state: Game,
+        next_legal_actions: &[Action],
+        reward: f64,
+    );
+
+    fn set_learning_rate(&mut self, learning_rate: f64);
+    fn set_exploration_prob(&mut self, exploration_prob: f64);
+    fn set_discount_rate(&mut self, discount_rate: f64);
+}
+
 pub struct QLearningAgent {
     pub learning_rate: f64,
     pub exploration_prob: f64,

@@ -38,11 +55,32 @@ impl QLearningAgent {
     }
 
     fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
-        *legal_actions
+        let legal_actions = legal_actions
             .iter()
-            .map(|action| (action, self.get_q_value(&state, *action)))
+            .map(|action| (action, self.get_q_value(state, *action)))
+            .collect::<Vec<_>>();
+
+        let max_val = legal_actions
+            .iter()
             .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
             .expect("Failed to select an action")
+            .1;
+
+        let actions_to_choose = legal_actions
+            .iter()
+            .filter(|(_, v)| max_val == *v)
+            .collect::<Vec<_>>();
+
+        if actions_to_choose.len() != 1 {
+            trace!(
+                "more than one best option, choosing randomly: {:?}",
+                actions_to_choose
+            );
+        }
+
+        *actions_to_choose
+            .choose(&mut SmallRng::from_entropy())
+            .unwrap()
             .0
     }

@@ -63,22 +101,30 @@ impl QLearningAgent {
 impl Actor for QLearningAgent {
     // Because doing (Nothing) is in the set of legal actions, this will never
     // be empty
-    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
+    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
         if rng.gen::<f64>() < self.exploration_prob {
             *legal_actions.choose(rng).unwrap()
         } else {
-            self.get_action_from_q_values(state, legal_actions)
+            self.get_action_from_q_values(&game.into(), legal_actions)
         }
     }
 
+    fn dbg(&self) {
+        debug!("Total states: {}", self.q_values.len());
+    }
+}
+
+impl QLearningActor for QLearningAgent {
     fn update(
         &mut self,
-        state: State,
+        game_state: Game,
         action: Action,
-        next_state: State,
+        next_game_state: Game,
         _next_legal_actions: &[Action],
         reward: f64,
     ) {
+        let state = (&game_state).into();
+        let next_state = (&next_game_state).into();
         let cur_q_val = self.get_q_value(&state, action);
         let new_q_val = cur_q_val
             + self.learning_rate

@@ -93,10 +139,6 @@ impl Actor for QLearningAgent {
             .insert(action, new_q_val);
     }
 
-    fn dbg(&self) {
-        debug!("Total states: {}", self.q_values.len());
-    }
-
     fn set_learning_rate(&mut self, learning_rate: f64) {
         self.learning_rate = learning_rate;
     }

@@ -112,126 +154,187 @@ pub struct ApproximateQLearning {
     pub learning_rate: f64,
     pub exploration_prob: f64,
     pub discount_rate: f64,
-    weights: HashMap<String, f64>,
+    weights: HashMap<Feature, f64>,
 }
 
 impl Default for ApproximateQLearning {
     fn default() -> Self {
+        let mut weights = HashMap::default();
+        weights.insert(Feature::TotalHeight, 1.0);
+        weights.insert(Feature::Bumpiness, 1.0);
+        weights.insert(Feature::LinesCleared, 1.0);
+        weights.insert(Feature::Holes, 1.0);
+
         Self {
             learning_rate: 0.0,
             exploration_prob: 0.0,
             discount_rate: 0.0,
-            weights: HashMap::default(),
+            weights,
         }
     }
 }
 
+#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
+enum Feature {
+    TotalHeight,
+    Bumpiness,
+    LinesCleared,
+    Holes,
+}
+
 impl ApproximateQLearning {
-    fn get_features(
-        &self,
-        state: &State,
-        _action: &Action,
-        new_state: &State,
-    ) -> HashMap<String, f64> {
+    fn get_features(&self, game: &Game, action: &Action) -> HashMap<Feature, f64> {
+        // let game = game.get_next_state(*action);
+
         let mut features = HashMap::default();
+        let field = game.playfield().field();
+        let heights = self.get_heights(field);
 
-        let mut heights = [None; PLAYFIELD_WIDTH];
-        for r in 0..PLAYFIELD_HEIGHT {
-            for c in 0..PLAYFIELD_WIDTH {
-                if heights[c].is_none() && state.matrix[r][c].is_some() {
-                    heights[c] = Some(PLAYFIELD_HEIGHT - r);
-                }
-            }
-        }
-
         features.insert(
-            "Total Height".into(),
-            heights
-                .iter()
-                .map(|o| o.unwrap_or_else(|| 0))
-                .sum::<usize>() as f64
-                / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
+            Feature::TotalHeight,
+            heights.iter().sum::<usize>() as f64 / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
         );
 
         features.insert(
-            "Bumpiness".into(),
+            Feature::Bumpiness,
             heights
                 .iter()
-                .map(|o| o.unwrap_or_else(|| 0) as isize)
-                .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
+                .fold((0, 0), |(acc, prev), cur| {
+                    (acc + (prev as isize - *cur as isize).abs(), *cur)
+                })
                 .0 as f64
                 / (PLAYFIELD_WIDTH * 40) as f64,
         );
 
         features.insert(
-            "Lines cleared".into(),
-            (new_state.line_clears - state.line_clears) as f64 / 4.0,
+            Feature::LinesCleared,
+            game.playfield()
+                .field()
+                .iter()
+                .map(|r| r.iter().all(Option::is_some))
+                .map(|r| if r { 1 } else { 0 })
+                .sum::<usize>() as f64
+                / 4.0,
         );
 
         let mut holes = 0;
         for r in 1..PLAYFIELD_HEIGHT {
             for c in 0..PLAYFIELD_WIDTH {
-                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
+                if field[r][c].is_none() && field[r - 1][c].is_some() {
                     holes += 1;
                 }
             }
         }
-        features.insert("Holes".into(), holes as f64);
+        features.insert(Feature::Holes, holes as f64);
 
         features
     }
 
-    fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
-        self.get_features(state, action, next_state)
+    fn get_heights(&self, matrix: &Matrix) -> Vec<usize> {
+        let mut heights = vec![0; matrix[0].len()];
+        for r in 0..matrix.len() {
+            for c in 0..matrix[0].len() {
+                if heights[c] == 0 && matrix[r][c].is_some() {
+                    heights[c] = matrix.len() - r;
+                }
+            }
+        }
+
+        heights
+    }
+
+    fn get_q_value(&self, game: &Game, action: &Action) -> f64 {
+        self.get_features(game, action)
             .iter()
-            .map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0))
+            .map(|(key, val)| val * *self.weights.get(key).unwrap())
             .sum()
     }
 
-    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
-        *legal_actions
+    fn get_action_from_q_values(&self, game: &Game) -> Action {
+        let legal_actions = game.get_legal_actions();
+        let legal_actions = legal_actions
             .iter()
-            .map(|action| (action, self.get_q_value(&state, action, state)))
+            .map(|action| (action, self.get_q_value(game, action)))
+            .collect::<Vec<_>>();
+
+        let max_val = legal_actions
+            .iter()
             .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
             .expect("Failed to select an action")
+            .1;
+
+        let actions_to_choose = legal_actions
+            .iter()
+            .filter(|(_, v)| max_val == *v)
+            .collect::<Vec<_>>();
+
+        if actions_to_choose.len() != 1 {
+            trace!(
+                "more than one best option, choosing randomly: {:?}",
+                actions_to_choose
+            );
+        }
+
+        *actions_to_choose
+            .choose(&mut SmallRng::from_entropy())
+            .unwrap()
             .0
     }
 
-    fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
-        legal_actions
+    fn get_value(&self, game: &Game) -> f64 {
+        game.get_legal_actions()
             .iter()
-            .map(|action| self.get_q_value(state, action, state))
+            .map(|action| self.get_q_value(game, action))
             .max_by_key(|v| (v * 1_000_000.0) as isize)
             .unwrap_or_else(|| 0.0)
     }
 }
 
+#[cfg(test)]
+mod aaa {
+    use super::*;
+    use crate::tetromino::TetrominoType;
+
+    #[test]
+    fn test_height() {
+        let agent = ApproximateQLearning::default();
+        let matrix = vec![
+            vec![None, None, None],
+            vec![None, Some(TetrominoType::T), None],
+            vec![None, None, None],
+        ];
+        assert_eq!(agent.get_heights(&matrix), vec![0, 2, 0]);
+    }
+}
+
 impl Actor for ApproximateQLearning {
-    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
+    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
         if rng.gen::<f64>() < self.exploration_prob {
             *legal_actions.choose(rng).unwrap()
         } else {
-            self.get_action_from_q_values(state, legal_actions)
+            self.get_action_from_q_values(game)
         }
     }
 
+    fn dbg(&self) {
+        dbg!(&self.weights);
+    }
+}
+
+impl QLearningActor for ApproximateQLearning {
     fn update(
         &mut self,
-        state: State,
+        game_state: Game,
         action: Action,
-        next_state: State,
+        next_game_state: Game,
         next_legal_actions: &[Action],
         reward: f64,
     ) {
-        let difference = reward
-            + self.discount_rate * self.get_value(&next_state, next_legal_actions)
-            - self.get_q_value(&state, &action, &next_state);
+        let difference = reward + self.discount_rate * self.get_value(&next_game_state)
+            - self.get_q_value(&game_state, &action);
 
-        for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
+        for (feat_key, feat_val) in self.get_features(&game_state, &action) {
             self.weights.insert(
                 feat_key.clone(),
-                *self.weights.get(&feat_key).unwrap_or_else(|| &0.0)
-                    + self.learning_rate * difference * feat_val,
+                *self.weights.get(&feat_key).unwrap() + self.learning_rate * difference * feat_val,
             );
         }
     }

@@ -247,15 +350,23 @@ impl Actor for ApproximateQLearning {
     fn set_discount_rate(&mut self, discount_rate: f64) {
         self.discount_rate = discount_rate;
     }
-
-    fn dbg(&self) {
-        dbg!(&self.weights);
-    }
 }
 
-pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
+pub fn train_actor<T: 'static + QLearningActor + Actor>(
+    mut actor: T,
+    opts: &Train,
+) -> Box<dyn Actor> {
     let mut rng = SmallRng::from_entropy();
     let mut avg = 0.0;
+    let episodes = opts.episodes;
+
+    actor.set_learning_rate(opts.learning_rate);
+    actor.set_discount_rate(opts.discount_rate);
+    actor.set_exploration_prob(opts.exploration_prob);
+    info!(
+        "Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
+        opts.learning_rate, opts.discount_rate, opts.exploration_prob
+    );
+
     for i in (0..episodes).progress() {
         if i != 0 && i % (episodes / 10) == 0 {

@@ -265,26 +376,17 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor>
         }
         let mut game = Game::default();
         while (&game).is_game_over().is_none() {
-            let cur_state = (&game).into();
+            let cur_state = game.clone();
             let cur_score = game.score();
             let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
 
-            match action {
-                Action::Nothing => (),
-                Action::MoveLeft => game.move_left(),
-                Action::MoveRight => game.move_right(),
-                Action::SoftDrop => game.move_down(),
-                Action::HardDrop => game.hard_drop(),
-                Action::Hold => game.hold(),
-                Action::RotateLeft => game.rotate_left(),
-                Action::RotateRight => game.rotate_right(),
-            }
+            super::apply_action_to_game(action, &mut game);
 
-            let new_state = (&game).into();
+            let new_state = game.clone();
             let mut reward = game.score() as f64 - cur_score as f64;
-            if action != Action::Nothing {
-                reward -= 0.0;
-            }
+            // if action != Action::Nothing {
+            //     reward -= 0.0;
+            // }
 
             if game.is_game_over().is_some() {
                 reward = -1.0;

@@ -292,7 +394,13 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor>
 
             let new_legal_actions = game.get_legal_actions();
 
-            actor.update(cur_state, action, new_state, &new_legal_actions, reward);
+            actor.update(
+                cur_state.into(),
+                action,
+                new_state,
+                &new_legal_actions,
+                reward,
+            );
 
             game.tick();
         }

@@ -300,5 +408,12 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor>
         avg += game.score() as f64 / (episodes / 10) as f64;
     }
 
-    actor
+    if opts.no_explore_during_evaluation {
+        actor.set_exploration_prob(0.0);
+    }
+
+    if opts.no_learn_during_evaluation {
+        actor.set_learning_rate(0.0);
+    }
+    Box::new(actor)
 }
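The weight update in `ApproximateQLearning::update` above is the usual linear function-approximation rule: form the TD error delta = reward + gamma * max_a' Q(s', a') - Q(s, a), then move each feature weight by alpha * delta * f(s, a). A standalone sketch of just that arithmetic, reusing the `Feature` keys from this diff (everything else here is illustrative):

```rust
use std::collections::HashMap;

#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
enum Feature {
    TotalHeight,
    Bumpiness,
    LinesCleared,
    Holes,
}

// Q(s, a) under linear function approximation: a dot product of weights and features.
fn q_value(weights: &HashMap<Feature, f64>, features: &HashMap<Feature, f64>) -> f64 {
    features
        .iter()
        .map(|(k, v)| v * weights.get(k).copied().unwrap_or(0.0))
        .sum()
}

// One TD update step: w_i <- w_i + alpha * delta * f_i(s, a).
fn td_update(
    weights: &mut HashMap<Feature, f64>,
    features: &HashMap<Feature, f64>, // f(s, a)
    reward: f64,
    next_state_value: f64, // max_a' Q(s', a')
    alpha: f64,            // learning rate
    gamma: f64,            // discount rate
) {
    let delta = reward + gamma * next_state_value - q_value(weights, features);
    for (k, f) in features {
        *weights.entry(*k).or_insert(0.0) += alpha * delta * f;
    }
}
```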
@@ -74,6 +74,7 @@ arg_enum! {
     pub enum Agent {
         QLearning,
         ApproximateQLearning,
+        HeuristicGenetic
     }
 }

@@ -89,10 +90,3 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
 
     Ok(())
 }
-
-pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
-    match agent {
-        Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
-        Agent::ApproximateQLearning => Box::new(qlearning::ApproximateQLearning::default()),
-    }
-}
src/game.rs: 57 changes

@@ -16,6 +16,8 @@ const LINE_CLEAR_DELAY: u64 = TICKS_PER_SECOND as u64 * 41 / 60;
 pub enum LossReason {
     TopOut,
     LockOut,
+    PieceLimitReached,
+    TickLimitReached,
     BlockOut(Position),
 }
 

@@ -35,6 +37,12 @@ pub struct Game {
     /// bonus is needed.
     last_clear_action: ClearAction,
     pub line_clears: u32,
+    // used if we set a limit on how long a game can last.
+    pieces_placed: usize,
+    piece_limit: usize,
+
+    // used if we set a limit on how long the game can be played.
+    tick_limit: u64,
 }
 
 impl fmt::Debug for Game {

@@ -59,6 +67,9 @@ impl Default for Game {
             is_game_over: None,
             last_clear_action: ClearAction::Single, // Doesn't matter what it's initialized to
             line_clears: 0,
+            pieces_placed: 0,
+            piece_limit: 0,
+            tick_limit: 0,
         }
     }
 }

@@ -73,6 +84,7 @@ impl Tickable for Game {
             return;
         }
         self.tick += 1;
+
         match self.tick {
             t if t == self.next_spawn_tick => {
                 trace!("Spawn tick was met, spawning new Tetromino!");

@@ -99,6 +111,10 @@ impl Tickable for Game {
             }
             _ => (),
         }
+
+        if self.tick == self.tick_limit {
+            self.is_game_over = Some(LossReason::TickLimitReached);
+        }
     }
 }
 

@@ -168,10 +184,18 @@ impl Game {
         // It's possible that the player moved the piece in the meantime.
         if !self.playfield.can_active_piece_move_down() {
             let positions = self.playfield.lock_active_piece();
+            if self.pieces_placed < self.piece_limit {
+                self.pieces_placed += 1;
+                if self.pieces_placed >= self.piece_limit {
+                    trace!("Loss due to piece limit!");
+                    self.is_game_over = Some(LossReason::PieceLimitReached);
+                }
+            }
+
             self.is_game_over = self.is_game_over.or_else(|| {
                 if positions.iter().map(|p| p.y).all(|y| y < 20) {
-                    trace!("Loss due to topout! {:?}", positions);
-                    Some(LossReason::TopOut)
+                    trace!("Loss due to lockout! {:?}", positions);
+                    Some(LossReason::LockOut)
                 } else {
                     None
                 }

@@ -183,6 +207,7 @@ impl Game {
             self.line_clears += cleared_lines as u32;
             self.score += (cleared_lines * 100 * self.level as usize) as u32;
             self.level = (self.line_clears / 10) as u8;
+
             self.playfield.active_piece = None;
             self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
         } else {

@@ -195,9 +220,17 @@ impl Game {
         }
     }
 
+    pub fn set_piece_limit(&mut self, size: usize) {
+        self.piece_limit = size;
+    }
+
     pub fn playfield(&self) -> &PlayField {
         &self.playfield
     }
+
+    pub fn set_time_limit(&mut self, duration: std::time::Duration) {
+        self.tick_limit = duration.as_secs() * TICKS_PER_SECOND as u64;
+    }
 }
 
 pub trait Controllable {

@@ -211,7 +244,7 @@ pub trait Controllable {
     fn get_legal_actions(&self) -> Vec<Action>;
 }
 
-#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)]
+#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
 pub enum Action {
     Nothing, // Default value
     MoveLeft,

@@ -229,8 +262,7 @@ impl Controllable for Game {
             return;
         }
 
-        self.playfield.move_offset(-1, 0);
-        if !self.playfield.can_active_piece_move_down() {
+        if self.playfield.move_offset(-1, 0) && !self.playfield.can_active_piece_move_down() {
             self.update_lock_tick();
         }
     }

@@ -240,8 +272,7 @@ impl Controllable for Game {
             return;
         }
 
-        self.playfield.move_offset(1, 0);
-        if !self.playfield.can_active_piece_move_down() {
+        if self.playfield.move_offset(1, 0) && !self.playfield.can_active_piece_move_down() {
             self.update_lock_tick();
         }
     }

@@ -269,7 +300,7 @@ impl Controllable for Game {
                 active_piece.position = active_piece.position.offset(x, y);
                 active_piece.rotate_left();
                 self.playfield.active_piece = Some(active_piece);
-                // self.update_lock_tick();
+                self.update_lock_tick();
             }
             Err(_) => (),
         }

@@ -286,7 +317,7 @@ impl Controllable for Game {
                 active_piece.position = active_piece.position.offset(x, y);
                 active_piece.rotate_right();
                 self.playfield.active_piece = Some(active_piece);
-                // self.update_lock_tick();
+                self.update_lock_tick();
             }
             Err(_) => (),
         }

@@ -332,13 +363,13 @@ impl Controllable for Game {
 
     fn get_legal_actions(&self) -> Vec<Action> {
         let mut legal_actions = vec![
+            Action::RotateLeft,
+            Action::RotateRight,
+            Action::SoftDrop,
+            Action::HardDrop,
             Action::Nothing,
             Action::MoveLeft,
             Action::MoveRight,
-            Action::SoftDrop,
-            Action::HardDrop,
-            Action::RotateLeft,
-            Action::RotateRight,
         ];
 
         if self.playfield.can_swap() {
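`set_piece_limit` and `set_time_limit` exist so the genetic trainer's evaluation games always terminate; a limit of 0 (the default) effectively disables either check. A short usage sketch, assuming the `Game` type from src/game.rs above (so not standalone-compilable):

```rust
use std::time::Duration;

// Assumes Game (with its Tickable impl) from src/game.rs in this diff.
fn bounded_evaluation_game() -> u32 {
    let mut game = Game::default();
    // Stop after 500 locked pieces or 60 seconds' worth of ticks, whichever
    // comes first; both surface through is_game_over() as a LossReason.
    game.set_piece_limit(500);
    game.set_time_limit(Duration::from_secs(60));

    while game.is_game_over().is_none() {
        // A real caller would pick and apply an action here; ticking alone
        // still ends the game once the tick limit is reached.
        game.tick();
    }
    game.line_clears
}
```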
src/main.rs: 281 changes

@@ -4,9 +4,7 @@ use cli::*;
 use game::{Action, Controllable, Game, Tickable};
 use graphics::standard_renderer;
 use graphics::COLOR_BACKGROUND;
-use indicatif::ProgressIterator;
 use log::{debug, info, trace};
-use qlearning::train_actor;
 use rand::SeedableRng;
 use sdl2::event::Event;
 use sdl2::keyboard::Keycode;

@@ -24,42 +22,33 @@ mod tetromino;
 
 const TICKS_PER_SECOND: usize = 60;
 
-#[tokio::main]
+#[tokio::main(core_threads = 16)]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let opts = crate::cli::Opts::parse();
 
     init_verbosity(&opts)?;
-    let mut actor = None;
 
-    match opts.subcmd {
-        SubCommand::Play(sub_opts) => {}
-        SubCommand::Train(sub_opts) => {
-            let mut to_train = get_actor(sub_opts.agent);
-            to_train.set_learning_rate(sub_opts.learning_rate);
-            to_train.set_discount_rate(sub_opts.discount_rate);
-            to_train.set_exploration_prob(sub_opts.exploration_prob);
-
-            info!(
-                "Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
-                sub_opts.learning_rate,
-                sub_opts.discount_rate,
-                sub_opts.exploration_prob
-            );
-            let mut trained_actor = train_actor(sub_opts.episodes, to_train);
-            if sub_opts.no_explore_during_evaluation {
-                trained_actor.set_exploration_prob(0.0);
+    let agent = match opts.subcmd {
+        SubCommand::Play(_) => None,
+        SubCommand::Train(sub_opts) => Some(match sub_opts.agent {
+            Agent::QLearning => {
+                qlearning::train_actor(qlearning::QLearningAgent::default(), &sub_opts)
             }
-
-            if sub_opts.no_learn_during_evaluation {
-                trained_actor.set_learning_rate(0.0);
+            Agent::ApproximateQLearning => {
+                let agent =
+                    qlearning::train_actor(qlearning::ApproximateQLearning::default(), &sub_opts);
+                agent.dbg();
+                agent
             }
+            Agent::HeuristicGenetic => {
+                let agent = genetic::train_actor(&sub_opts).await;
+                agent.dbg();
+                agent
+            }
+        }),
+    };
 
-            actor = Some(trained_actor);
-        }
-    }
-
-    play_game(actor).await?;
-    Ok(())
+    play_game(agent).await
 }

@@ -72,125 +61,127 @@ async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std:
         .build()?;
     let mut canvas = window.into_canvas().build()?;
     let mut event_pump = sdl_context.event_pump()?;
-    let mut game = Game::default();
     let mut interval = interval(Duration::from_millis(1000 / TICKS_PER_SECOND as u64));
 
-    'running: loop {
-        match game.is_game_over() {
-            Some(e) => {
-                println!("Lost due to: {:?}", e);
-                break;
-            }
-            None => (),
-        }
-
-        let cur_state = (&game).into();
-
-        // If there's an actor, the player action will get overridden. If not,
-        // then then the player action falls through, if there is one. This is
-        // to allow for restarting and quitting the game from the GUI.
-        let mut action = None;
-        for event in event_pump.poll_iter() {
-            match event {
-                Event::Quit { .. }
-                | Event::KeyDown {
-                    keycode: Some(Keycode::Escape),
-                    ..
-                } => {
-                    debug!("Escape registered");
-                    break 'running;
-                }
-                Event::KeyDown {
-                    keycode: Some(Keycode::Left),
-                    ..
-                } => {
-                    debug!("Move left registered");
-                    action = Some(Action::MoveLeft);
-                }
-                Event::KeyDown {
-                    keycode: Some(Keycode::Right),
-                    ..
-                } => {
-                    debug!("Move right registered");
-                    action = Some(Action::MoveRight);
-                }
-                Event::KeyDown {
-                    keycode: Some(Keycode::Down),
-                    ..
-                } => {
-                    debug!("Soft drop registered");
-                    action = Some(Action::SoftDrop);
-                }
-                Event::KeyDown {
-                    keycode: Some(Keycode::Z),
-                    ..
-                } => {
-                    debug!("Rotate left registered");
-                    action = Some(Action::RotateLeft);
-                }
-                Event::KeyDown {
-                    keycode: Some(Keycode::X),
-                    ..
-                } => {
-                    debug!("Rotate right registered");
-                    action = Some(Action::RotateRight);
-                }
-                Event::KeyDown {
-                    keycode: Some(Keycode::Space),
-                    ..
-                }
-                | Event::KeyDown {
-                    keycode: Some(Keycode::Up),
-                    ..
-                } => {
-                    debug!("Hard drop registered");
-                    action = Some(Action::HardDrop);
-                }
-                Event::KeyDown {
-                    keycode: Some(Keycode::LShift),
-                    ..
-                } => {
-                    debug!("Hold registered");
-                    action = Some(Action::Hold);
-                }
-                Event::KeyDown {
-                    keycode: Some(Keycode::R),
-                    ..
-                } => {
-                    info!("Restarting game");
-                    game = Game::default();
-                }
-                Event::KeyDown {
-                    keycode: Some(e), ..
-                } => trace!("Ignoring keycode {}", e),
-                _ => (),
-            }
-        }
-
-        actor.as_mut().map(|actor| {
-            action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
-        });
-
-        action.map(|action| match action {
-            Action::Nothing => (),
-            Action::MoveLeft => game.move_left(),
-            Action::MoveRight => game.move_right(),
-            Action::SoftDrop => game.move_down(),
-            Action::HardDrop => game.hard_drop(),
-            Action::Hold => game.hold(),
-            Action::RotateLeft => game.rotate_left(),
-            Action::RotateRight => game.rotate_right(),
-        });
-
-        game.tick();
-        canvas.set_draw_color(COLOR_BACKGROUND);
-        canvas.clear();
-        standard_renderer::render(&mut canvas, &game);
-        canvas.present();
-        interval.tick().await;
+    'escape: loop {
+        let mut game = Game::default();
+
+        loop {
+            match game.is_game_over() {
+                Some(e) => {
+                    println!("Lost due to: {:?}", e);
+                    break;
+                }
+                None => (),
+            }
+
+            let cur_state = game.clone();
+
+            // If there's an actor, the player action will get overridden. If not,
+            // then then the player action falls through, if there is one. This is
+            // to allow for restarting and quitting the game from the GUI.
+            let mut action = None;
+            for event in event_pump.poll_iter() {
+                match event {
+                    Event::Quit { .. }
+                    | Event::KeyDown {
+                        keycode: Some(Keycode::Escape),
+                        ..
+                    } => {
+                        debug!("Escape registered");
+                        break 'escape Ok(());
+                    }
+                    Event::KeyDown {
+                        keycode: Some(Keycode::Left),
+                        ..
+                    } => {
+                        debug!("Move left registered");
+                        action = Some(Action::MoveLeft);
+                    }
+                    Event::KeyDown {
+                        keycode: Some(Keycode::Right),
+                        ..
+                    } => {
+                        debug!("Move right registered");
+                        action = Some(Action::MoveRight);
+                    }
+                    Event::KeyDown {
+                        keycode: Some(Keycode::Down),
+                        ..
+                    } => {
+                        debug!("Soft drop registered");
+                        action = Some(Action::SoftDrop);
+                    }
+                    Event::KeyDown {
+                        keycode: Some(Keycode::Z),
+                        ..
+                    } => {
+                        debug!("Rotate left registered");
+                        action = Some(Action::RotateLeft);
+                    }
+                    Event::KeyDown {
+                        keycode: Some(Keycode::X),
+                        ..
+                    } => {
+                        debug!("Rotate right registered");
+                        action = Some(Action::RotateRight);
+                    }
+                    Event::KeyDown {
+                        keycode: Some(Keycode::Space),
+                        ..
+                    }
+                    | Event::KeyDown {
+                        keycode: Some(Keycode::Up),
+                        ..
+                    } => {
+                        debug!("Hard drop registered");
+                        action = Some(Action::HardDrop);
+                    }
+                    Event::KeyDown {
+                        keycode: Some(Keycode::LShift),
+                        ..
+                    } => {
+                        debug!("Hold registered");
+                        action = Some(Action::Hold);
+                    }
+                    Event::KeyDown {
+                        keycode: Some(Keycode::R),
+                        ..
+                    } => {
+                        info!("Restarting game");
+                        game = Game::default();
+                    }
+                    Event::KeyDown {
+                        keycode: Some(e), ..
+                    } => trace!("Ignoring keycode {}", e),
+                    _ => (),
+                }
+            }
+
+            actor.as_mut().map(|actor| {
+                action =
+                    Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
+            });
+
+            action.map(|action| match action {
+                Action::Nothing => (),
+                Action::MoveLeft => game.move_left(),
+                Action::MoveRight => game.move_right(),
+                Action::SoftDrop => game.move_down(),
+                Action::HardDrop => game.hard_drop(),
+                Action::Hold => game.hold(),
+                Action::RotateLeft => game.rotate_left(),
+                Action::RotateRight => game.rotate_right(),
+            });
+
+            game.tick();
+            canvas.set_draw_color(COLOR_BACKGROUND);
+            canvas.clear();
+            standard_renderer::render(&mut canvas, &game);
+            canvas.present();
+            interval.tick().await;
+        }
+
+        info!("Final score: {}", game.score());
     }
-
-    info!("Final score: {}", game.score());
-    actor.map(|a| a.dbg());
-    Ok(())
 }