Compare commits

...

3 commits

Author SHA1 Message Date
5a9e2538aa
heuristicgenetic 2020-04-06 16:07:30 -04:00
deb74da552
much refactor 2020-04-05 23:39:19 -04:00
710a7dbebb
make qlearning train_agent specific 2020-04-05 20:18:48 -04:00
9 changed files with 732 additions and 294 deletions

2
.gitignore vendored
View file

@ -1,2 +1,4 @@
/target
**/*.rs.bk
flamegraph.svg
perf.*

104
Cargo.lock generated
View file

@ -150,11 +150,88 @@ name = "fuchsia-zircon-sys"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "futures-channel"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "futures-core"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures-executor"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "futures-io"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures-macro"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "futures-sink"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures-task"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures-util"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "getrandom"
version = "0.1.13"
@ -345,6 +422,11 @@ name = "pin-project-lite"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "pin-utils"
version = "0.1.0-alpha.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "ppv-lite86"
version = "0.2.6"
@ -374,6 +456,16 @@ dependencies = [
"version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "proc-macro-nested"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "proc-macro2"
version = "1.0.9"
@ -549,6 +641,7 @@ name = "tetris"
version = "0.1.0"
dependencies = [
"clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)",
"futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
@ -695,7 +788,15 @@ dependencies = [
"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
"checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82"
"checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
"checksum futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5c329ae8753502fb44ae4fc2b622fa2a94652c41e795143765ba0927f92ab780"
"checksum futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f0c77d04ce8edd9cb903932b608268b3fffec4163dc053b3b402bf47eac1f1a8"
"checksum futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f25592f769825e89b92358db00d26f965761e094951ac44d3663ef25b7ac464a"
"checksum futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f674f3e1bcb15b37284a90cedf55afdba482ab061c407a9c0ebbd0f3109741ba"
"checksum futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "a638959aa96152c7a4cddf50fcb1e3fede0583b27157c26e67d6f99904090dc6"
"checksum futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "9a5081aa3de1f7542a794a397cde100ed903b0630152d0973479018fd85423a7"
"checksum futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "3466821b4bc114d95b087b850a724c6f83115e929bc88f1fa98a3304a944c8a6"
"checksum futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "7b0a34e53cf6cdcd0178aa573aed466b646eb3db769570841fda0c7ede375a27"
"checksum futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "22766cf25d64306bedf0384da004d05c9974ab104fcc4528f1236181c18004c5"
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
"checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
"checksum hermit-abi 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1010591b26bbfe835e9faeabeb11866061cc7dcebffd56ad7d0942d0e61aefd8"
@ -718,9 +819,12 @@ dependencies = [
"checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6"
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
"checksum pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae"
"checksum pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5894c618ce612a3fa23881b152b608bafb8c56cfc22f434a3ba3120b40f7b587"
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
"checksum proc-macro-error 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7"
"checksum proc-macro-error-attr 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de"
"checksum proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)" = "0d659fe7c6d27f25e9d80a1a094c223f5246f6a6596453e09d7229bf42750b63"
"checksum proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "8e946095f9d3ed29ec38de908c22f95d9ac008e424c7bcae54c75a79c527c694"
"checksum proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)" = "6c09721c6781493a2a492a96b5a5bf19b65917fe6728884e7c44dd0c60ca3435"
"checksum quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2bdc6c187c65bca4260c9011c9e3132efe4909da44726bad24cf7572ae338d7f"
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"

View file

@ -14,3 +14,4 @@ simple_logger = "1.6"
sdl2 = { version = "0.33.0", features = ["ttf"] }
clap = { git = "https://github.com/clap-rs/clap/", features = ["color"] }
indicatif = "0.14"
futures = "0.3"

View file

@ -1,9 +1,17 @@
// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
use super::Actor;
use super::{Actor, Predictable, State};
use crate::{
cli::Train,
game::{Action, Controllable, Game, Tickable},
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use indicatif::ProgressBar;
use log::debug;
use rand::rngs::SmallRng;
use rand::Rng;
use rand::{seq::SliceRandom, Rng, SeedableRng};
#[derive(Copy, Clone, Debug)]
pub struct Parameters {
total_height: f64,
bumpiness: f64,
@ -11,6 +19,17 @@ pub struct Parameters {
complete_lines: f64,
}
impl Default for Parameters {
fn default() -> Self {
Self {
total_height: 1.0,
bumpiness: 1.0,
holes: 1.0,
complete_lines: 1.0,
}
}
}
impl Parameters {
fn mutate(mut self, rng: &mut SmallRng) {
let mutation_amt = rng.gen_range(-0.2, 0.2);
@ -33,39 +52,220 @@ impl Parameters {
self.holes /= normalization_factor;
self.complete_lines /= normalization_factor;
}
fn dot_multiply(&self, other: &Self) -> f64 {
self.total_height * other.total_height
+ self.bumpiness * other.bumpiness
+ self.holes * other.holes
+ self.complete_lines * other.complete_lines
}
}
pub struct GeneticHeuristicAgent {}
#[derive(Clone, Copy, Debug)]
pub struct GeneticHeuristicAgent {
params: Parameters,
}
impl Default for GeneticHeuristicAgent {
fn default() -> Self {
let mut rng = SmallRng::from_entropy();
Self {
params: Parameters {
total_height: rng.gen::<f64>(),
bumpiness: rng.gen::<f64>(),
holes: rng.gen::<f64>(),
complete_lines: rng.gen::<f64>(),
},
}
}
}
impl GeneticHeuristicAgent {
fn extract_features_from_state(state: &State) -> Parameters {
let mut heights = [None; PLAYFIELD_WIDTH];
for r in 0..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if heights[c].is_none() && state.matrix[r][c].is_some() {
heights[c] = Some(PLAYFIELD_HEIGHT - r);
}
}
}
let total_height = heights
.iter()
.map(|o| o.unwrap_or_else(|| 0))
.sum::<usize>() as f64;
let bumpiness = heights
.iter()
.map(|o| o.unwrap_or_else(|| 0) as isize)
.fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
.0 as f64;
let complete_lines = state
.matrix
.iter()
.map(|row| row.iter().all(Option::is_some))
.map(|c| if c { 1.0 } else { 0.0 })
.sum::<f64>();
let mut holes = 0;
for r in 1..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
holes += 1;
}
}
}
Parameters {
total_height,
bumpiness,
complete_lines,
holes: holes as f64,
}
}
fn get_heuristic(&self, game: &Game, action: &Action) -> f64 {
self.params.dot_multiply(&Self::extract_features_from_state(
&game.get_next_state(*action).into(),
))
}
pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self {
let weight = self_fitness + other_fitness;
if weight != 0 {
let self_weight = self_fitness as f64 / weight as f64;
let other_weight = other_fitness as f64 / weight as f64;
Self {
params: Parameters {
total_height: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
bumpiness: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
holes: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
complete_lines: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
},
}
} else {
Self::default()
}
}
fn mutate(&mut self, rng: &mut SmallRng) {
self.params.mutate(rng);
}
}
impl Actor for GeneticHeuristicAgent {
fn get_action(
&self,
rng: &mut SmallRng,
state: &super::State,
legal_actions: &[crate::game::Action],
) -> crate::game::Action {
unimplemented!()
}
fn update(
&mut self,
state: super::State,
action: crate::game::Action,
next_state: super::State,
next_legal_actions: &[crate::game::Action],
reward: f64,
) {
unimplemented!()
}
fn set_learning_rate(&mut self, learning_rate: f64) {
unimplemented!()
}
fn set_exploration_prob(&mut self, exploration_prob: f64) {
unimplemented!()
}
fn set_discount_rate(&mut self, discount_rate: f64) {
unimplemented!()
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
let actions = legal_actions
.iter()
.map(|action| {
(
action,
(self.get_heuristic(game, action) * 1_000_000.0) as usize,
)
})
.collect::<Vec<_>>();
let max_val = actions
.iter()
.max_by_key(|(_, heuristic)| heuristic)
.unwrap()
.1;
*actions
.iter()
.filter(|e| e.1 == max_val)
.collect::<Vec<_>>()
.choose(rng)
.unwrap()
.0
}
fn dbg(&self) {
unimplemented!()
debug!("{:?}", self.params);
}
}
pub async fn train_actor(opts: &Train) -> Box<dyn Actor> {
use std::sync::{Arc, Mutex};
let rng = Arc::new(Mutex::new(SmallRng::from_entropy()));
let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000];
let total_pb = Arc::new(Mutex::new(ProgressBar::new(
population.len() as u64 * opts.episodes as u64,
)));
for _ in 0..opts.episodes {
let mut new_pop_futs = Vec::with_capacity(population.len());
let new_population = Arc::new(Mutex::new(Vec::with_capacity(population.len())));
let num_rounds = 10;
for i in 0..population.len() {
let rng = rng.clone();
let new_population = new_population.clone();
let total_pb = total_pb.clone();
let agent = population[i].0;
new_pop_futs.push(tokio::task::spawn(async move {
let mut fitness = 0;
for _ in 0..num_rounds {
let mut game = Game::default();
game.set_piece_limit(500);
game.set_time_limit(std::time::Duration::from_secs(60));
while (&game).is_game_over().is_none() {
let mut rng = rng.lock().expect("rng failed");
super::apply_action_to_game(
agent.get_action(
&mut rng,
&game.clone().into(),
&game.get_legal_actions(),
),
&mut game,
);
game.tick();
}
fitness += game.line_clears;
}
new_population.lock().unwrap().push((agent, fitness));
total_pb.lock().expect("progressbar failed").inc(1);
}));
}
futures::future::join_all(new_pop_futs).await;
let mut rng = SmallRng::from_entropy();
let mut new_population = new_population.lock().unwrap().clone();
let new_pop_size = population.len() * 3 / 10;
let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> =
Vec::with_capacity(new_pop_size);
for _ in 0..3 {
let mut random_selection = new_population
.choose_multiple(&mut rng, new_pop_size / 3)
.collect::<Vec<_>>();
random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1));
let best_two = random_selection.iter().rev().take(2).collect::<Vec<_>>();
let parent1 = dbg!(best_two[0]);
let parent2 = dbg!(best_two[1]);
for _ in 0..new_pop_size / 3 {
let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
let mut cloned = breeded.clone();
if rng.gen::<f64>() < 0.05 {
cloned.mutate(&mut rng);
}
breeded_population.push((cloned, 0));
}
}
new_population.sort_unstable_by_key(|e| e.1);
new_population.splice(..breeded_population.len(), breeded_population);
population = new_population;
}
total_pb.lock().unwrap().finish();
population.sort_unstable_by_key(|e| e.1);
Box::new(population.iter().rev().next().unwrap().0)
}

View file

@ -40,20 +40,7 @@ impl From<PlayField> for State {
}
pub trait Actor {
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
fn update(
&mut self,
state: State,
action: Action,
next_state: State,
next_legal_actions: &[Action],
reward: f64,
);
fn set_learning_rate(&mut self, learning_rate: f64);
fn set_exploration_prob(&mut self, exploration_prob: f64);
fn set_discount_rate(&mut self, discount_rate: f64);
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action;
fn dbg(&self);
}
@ -80,3 +67,16 @@ impl Predictable for Game {
game
}
}
pub fn apply_action_to_game(action: Action, game: &mut Game) {
match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
}
}

View file

@ -1,16 +1,33 @@
use super::Predictable;
use crate::actors::{Actor, State};
use crate::{
cli::Train,
game::{Action, Controllable, Game, Tickable},
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
playfield::{Matrix, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use indicatif::ProgressIterator;
use log::{debug, info};
use log::{debug, info, trace};
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use std::collections::HashMap;
pub trait QLearningActor: Actor {
fn update(
&mut self,
game_state: Game,
action: Action,
next_game_state: Game,
next_legal_actions: &[Action],
reward: f64,
);
fn set_learning_rate(&mut self, learning_rate: f64);
fn set_exploration_prob(&mut self, exploration_prob: f64);
fn set_discount_rate(&mut self, discount_rate: f64);
}
pub struct QLearningAgent {
pub learning_rate: f64,
pub exploration_prob: f64,
@ -38,11 +55,32 @@ impl QLearningAgent {
}
fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
*legal_actions
let legal_actions = legal_actions
.iter()
.map(|action| (action, self.get_q_value(state, *action)))
.collect::<Vec<_>>();
let max_val = legal_actions
.iter()
.map(|action| (action, self.get_q_value(&state, *action)))
.max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
.expect("Failed to select an action")
.1;
let actions_to_choose = legal_actions
.iter()
.filter(|(_, v)| max_val == *v)
.collect::<Vec<_>>();
if actions_to_choose.len() != 1 {
trace!(
"more than one best option, choosing randomly: {:?}",
actions_to_choose
);
}
*actions_to_choose
.choose(&mut SmallRng::from_entropy())
.unwrap()
.0
}
@ -63,22 +101,30 @@ impl QLearningAgent {
impl Actor for QLearningAgent {
// Because doing (Nothing) is in the set of legal actions, this will never
// be empty
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
if rng.gen::<f64>() < self.exploration_prob {
*legal_actions.choose(rng).unwrap()
} else {
self.get_action_from_q_values(state, legal_actions)
self.get_action_from_q_values(&game.into(), legal_actions)
}
}
fn dbg(&self) {
debug!("Total states: {}", self.q_values.len());
}
}
impl QLearningActor for QLearningAgent {
fn update(
&mut self,
state: State,
game_state: Game,
action: Action,
next_state: State,
next_game_state: Game,
_next_legal_actions: &[Action],
reward: f64,
) {
let state = (&game_state).into();
let next_state = (&next_game_state).into();
let cur_q_val = self.get_q_value(&state, action);
let new_q_val = cur_q_val
+ self.learning_rate
@ -93,10 +139,6 @@ impl Actor for QLearningAgent {
.insert(action, new_q_val);
}
fn dbg(&self) {
debug!("Total states: {}", self.q_values.len());
}
fn set_learning_rate(&mut self, learning_rate: f64) {
self.learning_rate = learning_rate;
}
@ -112,126 +154,187 @@ pub struct ApproximateQLearning {
pub learning_rate: f64,
pub exploration_prob: f64,
pub discount_rate: f64,
weights: HashMap<String, f64>,
weights: HashMap<Feature, f64>,
}
impl Default for ApproximateQLearning {
fn default() -> Self {
let mut weights = HashMap::default();
weights.insert(Feature::TotalHeight, 1.0);
weights.insert(Feature::Bumpiness, 1.0);
weights.insert(Feature::LinesCleared, 1.0);
weights.insert(Feature::Holes, 1.0);
Self {
learning_rate: 0.0,
exploration_prob: 0.0,
discount_rate: 0.0,
weights: HashMap::default(),
weights,
}
}
}
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
enum Feature {
TotalHeight,
Bumpiness,
LinesCleared,
Holes,
}
impl ApproximateQLearning {
fn get_features(
&self,
state: &State,
_action: &Action,
new_state: &State,
) -> HashMap<String, f64> {
fn get_features(&self, game: &Game, action: &Action) -> HashMap<Feature, f64> {
// let game = game.get_next_state(*action);
let mut features = HashMap::default();
let mut heights = [None; PLAYFIELD_WIDTH];
for r in 0..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if heights[c].is_none() && state.matrix[r][c].is_some() {
heights[c] = Some(PLAYFIELD_HEIGHT - r);
}
}
}
let field = game.playfield().field();
let heights = self.get_heights(field);
features.insert(
"Total Height".into(),
heights
.iter()
.map(|o| o.unwrap_or_else(|| 0))
.sum::<usize>() as f64
/ (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
Feature::TotalHeight,
heights.iter().sum::<usize>() as f64 / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
);
features.insert(
"Bumpiness".into(),
Feature::Bumpiness,
heights
.iter()
.map(|o| o.unwrap_or_else(|| 0) as isize)
.fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
.fold((0, 0), |(acc, prev), cur| {
(acc + (prev as isize - *cur as isize).abs(), *cur)
})
.0 as f64
/ (PLAYFIELD_WIDTH * 40) as f64,
);
features.insert(
"Lines cleared".into(),
(new_state.line_clears - state.line_clears) as f64 / 4.0,
Feature::LinesCleared,
game.playfield()
.field()
.iter()
.map(|r| r.iter().all(Option::is_some))
.map(|r| if r { 1 } else { 0 })
.sum::<usize>() as f64
/ 4.0,
);
let mut holes = 0;
for r in 1..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
if field[r][c].is_none() && field[r - 1][c].is_some() {
holes += 1;
}
}
}
features.insert("Holes".into(), holes as f64);
features.insert(Feature::Holes, holes as f64);
features
}
fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
self.get_features(state, action, next_state)
fn get_heights(&self, matrix: &Matrix) -> Vec<usize> {
let mut heights = vec![0; matrix[0].len()];
for r in 0..matrix.len() {
for c in 0..matrix[0].len() {
if heights[c] == 0 && matrix[r][c].is_some() {
heights[c] = matrix.len() - r;
}
}
}
heights
}
fn get_q_value(&self, game: &Game, action: &Action) -> f64 {
self.get_features(game, action)
.iter()
.map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0))
.map(|(key, val)| val * *self.weights.get(key).unwrap())
.sum()
}
fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
*legal_actions
fn get_action_from_q_values(&self, game: &Game) -> Action {
let legal_actions = game.get_legal_actions();
let legal_actions = legal_actions
.iter()
.map(|action| (action, self.get_q_value(game, action)))
.collect::<Vec<_>>();
let max_val = legal_actions
.iter()
.map(|action| (action, self.get_q_value(&state, action, state)))
.max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
.expect("Failed to select an action")
.1;
let actions_to_choose = legal_actions
.iter()
.filter(|(_, v)| max_val == *v)
.collect::<Vec<_>>();
if actions_to_choose.len() != 1 {
trace!(
"more than one best option, choosing randomly: {:?}",
actions_to_choose
);
}
*actions_to_choose
.choose(&mut SmallRng::from_entropy())
.unwrap()
.0
}
fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
legal_actions
fn get_value(&self, game: &Game) -> f64 {
game.get_legal_actions()
.iter()
.map(|action| self.get_q_value(state, action, state))
.map(|action| self.get_q_value(game, action))
.max_by_key(|v| (v * 1_000_000.0) as isize)
.unwrap_or_else(|| 0.0)
}
}
#[cfg(test)]
mod aaa {
use super::*;
use crate::tetromino::TetrominoType;
#[test]
fn test_height() {
let agent = ApproximateQLearning::default();
let matrix = vec![
vec![None, None, None],
vec![None, Some(TetrominoType::T), None],
vec![None, None, None],
];
assert_eq!(agent.get_heights(&matrix), vec![0, 2, 0]);
}
}
impl Actor for ApproximateQLearning {
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
if rng.gen::<f64>() < self.exploration_prob {
*legal_actions.choose(rng).unwrap()
} else {
self.get_action_from_q_values(state, legal_actions)
self.get_action_from_q_values(game)
}
}
fn dbg(&self) {
dbg!(&self.weights);
}
}
impl QLearningActor for ApproximateQLearning {
fn update(
&mut self,
state: State,
game_state: Game,
action: Action,
next_state: State,
next_game_state: Game,
next_legal_actions: &[Action],
reward: f64,
) {
let difference = reward
+ self.discount_rate * self.get_value(&next_state, next_legal_actions)
- self.get_q_value(&state, &action, &next_state);
let difference = reward + self.discount_rate * self.get_value(&next_game_state)
- self.get_q_value(&game_state, &action);
for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
for (feat_key, feat_val) in self.get_features(&game_state, &action) {
self.weights.insert(
feat_key.clone(),
*self.weights.get(&feat_key).unwrap_or_else(|| &0.0)
+ self.learning_rate * difference * feat_val,
*self.weights.get(&feat_key).unwrap() + self.learning_rate * difference * feat_val,
);
}
}
@ -247,15 +350,23 @@ impl Actor for ApproximateQLearning {
fn set_discount_rate(&mut self, discount_rate: f64) {
self.discount_rate = discount_rate;
}
fn dbg(&self) {
dbg!(&self.weights);
}
}
pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
pub fn train_actor<T: 'static + QLearningActor + Actor>(
mut actor: T,
opts: &Train,
) -> Box<dyn Actor> {
let mut rng = SmallRng::from_entropy();
let mut avg = 0.0;
let episodes = opts.episodes;
actor.set_learning_rate(opts.learning_rate);
actor.set_discount_rate(opts.discount_rate);
actor.set_exploration_prob(opts.exploration_prob);
info!(
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
opts.learning_rate, opts.discount_rate, opts.exploration_prob
);
for i in (0..episodes).progress() {
if i != 0 && i % (episodes / 10) == 0 {
@ -265,26 +376,17 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor>
}
let mut game = Game::default();
while (&game).is_game_over().is_none() {
let cur_state = (&game).into();
let cur_state = game.clone();
let cur_score = game.score();
let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
}
super::apply_action_to_game(action, &mut game);
let new_state = (&game).into();
let new_state = game.clone();
let mut reward = game.score() as f64 - cur_score as f64;
if action != Action::Nothing {
reward -= 0.0;
}
// if action != Action::Nothing {
// reward -= 0.0;
// }
if game.is_game_over().is_some() {
reward = -1.0;
@ -292,7 +394,13 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor>
let new_legal_actions = game.get_legal_actions();
actor.update(cur_state, action, new_state, &new_legal_actions, reward);
actor.update(
cur_state.into(),
action,
new_state,
&new_legal_actions,
reward,
);
game.tick();
}
@ -300,5 +408,12 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor>
avg += game.score() as f64 / (episodes / 10) as f64;
}
actor
if opts.no_explore_during_evaluation {
actor.set_exploration_prob(0.0);
}
if opts.no_learn_during_evaluation {
actor.set_learning_rate(0.0);
}
Box::new(actor)
}

View file

@ -74,6 +74,7 @@ arg_enum! {
pub enum Agent {
QLearning,
ApproximateQLearning,
HeuristicGenetic
}
}
@ -89,10 +90,3 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
match agent {
Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
Agent::ApproximateQLearning => Box::new(qlearning::ApproximateQLearning::default()),
}
}

View file

@ -16,6 +16,8 @@ const LINE_CLEAR_DELAY: u64 = TICKS_PER_SECOND as u64 * 41 / 60;
pub enum LossReason {
TopOut,
LockOut,
PieceLimitReached,
TickLimitReached,
BlockOut(Position),
}
@ -35,6 +37,12 @@ pub struct Game {
/// bonus is needed.
last_clear_action: ClearAction,
pub line_clears: u32,
// used if we set a limit on how long a game can last.
pieces_placed: usize,
piece_limit: usize,
// used if we set a limit on how long the game can be played.
tick_limit: u64,
}
impl fmt::Debug for Game {
@ -59,6 +67,9 @@ impl Default for Game {
is_game_over: None,
last_clear_action: ClearAction::Single, // Doesn't matter what it's initialized to
line_clears: 0,
pieces_placed: 0,
piece_limit: 0,
tick_limit: 0,
}
}
}
@ -73,6 +84,7 @@ impl Tickable for Game {
return;
}
self.tick += 1;
match self.tick {
t if t == self.next_spawn_tick => {
trace!("Spawn tick was met, spawning new Tetromino!");
@ -99,6 +111,10 @@ impl Tickable for Game {
}
_ => (),
}
if self.tick == self.tick_limit {
self.is_game_over = Some(LossReason::TickLimitReached);
}
}
}
@ -168,10 +184,18 @@ impl Game {
// It's possible that the player moved the piece in the meantime.
if !self.playfield.can_active_piece_move_down() {
let positions = self.playfield.lock_active_piece();
if self.pieces_placed < self.piece_limit {
self.pieces_placed += 1;
if self.pieces_placed >= self.piece_limit {
trace!("Loss due to piece limit!");
self.is_game_over = Some(LossReason::PieceLimitReached);
}
}
self.is_game_over = self.is_game_over.or_else(|| {
if positions.iter().map(|p| p.y).all(|y| y < 20) {
trace!("Loss due to topout! {:?}", positions);
Some(LossReason::TopOut)
trace!("Loss due to lockout! {:?}", positions);
Some(LossReason::LockOut)
} else {
None
}
@ -183,6 +207,7 @@ impl Game {
self.line_clears += cleared_lines as u32;
self.score += (cleared_lines * 100 * self.level as usize) as u32;
self.level = (self.line_clears / 10) as u8;
self.playfield.active_piece = None;
self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
} else {
@ -195,9 +220,17 @@ impl Game {
}
}
pub fn set_piece_limit(&mut self, size: usize) {
self.piece_limit = size;
}
pub fn playfield(&self) -> &PlayField {
&self.playfield
}
pub fn set_time_limit(&mut self, duration: std::time::Duration) {
self.tick_limit = duration.as_secs() * TICKS_PER_SECOND as u64;
}
}
pub trait Controllable {
@ -211,7 +244,7 @@ pub trait Controllable {
fn get_legal_actions(&self) -> Vec<Action>;
}
#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum Action {
Nothing, // Default value
MoveLeft,
@ -229,8 +262,7 @@ impl Controllable for Game {
return;
}
self.playfield.move_offset(-1, 0);
if !self.playfield.can_active_piece_move_down() {
if self.playfield.move_offset(-1, 0) && !self.playfield.can_active_piece_move_down() {
self.update_lock_tick();
}
}
@ -240,8 +272,7 @@ impl Controllable for Game {
return;
}
self.playfield.move_offset(1, 0);
if !self.playfield.can_active_piece_move_down() {
if self.playfield.move_offset(1, 0) && !self.playfield.can_active_piece_move_down() {
self.update_lock_tick();
}
}
@ -269,7 +300,7 @@ impl Controllable for Game {
active_piece.position = active_piece.position.offset(x, y);
active_piece.rotate_left();
self.playfield.active_piece = Some(active_piece);
// self.update_lock_tick();
self.update_lock_tick();
}
Err(_) => (),
}
@ -286,7 +317,7 @@ impl Controllable for Game {
active_piece.position = active_piece.position.offset(x, y);
active_piece.rotate_right();
self.playfield.active_piece = Some(active_piece);
// self.update_lock_tick();
self.update_lock_tick();
}
Err(_) => (),
}
@ -332,13 +363,13 @@ impl Controllable for Game {
fn get_legal_actions(&self) -> Vec<Action> {
let mut legal_actions = vec![
Action::RotateLeft,
Action::RotateRight,
Action::SoftDrop,
Action::HardDrop,
Action::Nothing,
Action::MoveLeft,
Action::MoveRight,
Action::SoftDrop,
Action::HardDrop,
Action::RotateLeft,
Action::RotateRight,
];
if self.playfield.can_swap() {

View file

@ -4,9 +4,7 @@ use cli::*;
use game::{Action, Controllable, Game, Tickable};
use graphics::standard_renderer;
use graphics::COLOR_BACKGROUND;
use indicatif::ProgressIterator;
use log::{debug, info, trace};
use qlearning::train_actor;
use rand::SeedableRng;
use sdl2::event::Event;
use sdl2::keyboard::Keycode;
@ -24,42 +22,33 @@ mod tetromino;
const TICKS_PER_SECOND: usize = 60;
#[tokio::main]
#[tokio::main(core_threads = 16)]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let opts = crate::cli::Opts::parse();
init_verbosity(&opts)?;
let mut actor = None;
match opts.subcmd {
SubCommand::Play(sub_opts) => {}
SubCommand::Train(sub_opts) => {
let mut to_train = get_actor(sub_opts.agent);
to_train.set_learning_rate(sub_opts.learning_rate);
to_train.set_discount_rate(sub_opts.discount_rate);
to_train.set_exploration_prob(sub_opts.exploration_prob);
info!(
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
sub_opts.learning_rate,
sub_opts.discount_rate,
sub_opts.exploration_prob
);
let mut trained_actor = train_actor(sub_opts.episodes, to_train);
if sub_opts.no_explore_during_evaluation {
trained_actor.set_exploration_prob(0.0);
let agent = match opts.subcmd {
SubCommand::Play(_) => None,
SubCommand::Train(sub_opts) => Some(match sub_opts.agent {
Agent::QLearning => {
qlearning::train_actor(qlearning::QLearningAgent::default(), &sub_opts)
}
if sub_opts.no_learn_during_evaluation {
trained_actor.set_learning_rate(0.0);
Agent::ApproximateQLearning => {
let agent =
qlearning::train_actor(qlearning::ApproximateQLearning::default(), &sub_opts);
agent.dbg();
agent
}
Agent::HeuristicGenetic => {
let agent = genetic::train_actor(&sub_opts).await;
agent.dbg();
agent
}
}),
};
actor = Some(trained_actor);
}
}
play_game(actor).await?;
Ok(())
play_game(agent).await
}
async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
@ -72,125 +61,127 @@ async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std:
.build()?;
let mut canvas = window.into_canvas().build()?;
let mut event_pump = sdl_context.event_pump()?;
let mut game = Game::default();
let mut interval = interval(Duration::from_millis(1000 / TICKS_PER_SECOND as u64));
'running: loop {
match game.is_game_over() {
Some(e) => {
println!("Lost due to: {:?}", e);
break;
'escape: loop {
let mut game = Game::default();
loop {
match game.is_game_over() {
Some(e) => {
println!("Lost due to: {:?}", e);
break;
}
None => (),
}
None => (),
let cur_state = game.clone();
// If there's an actor, the player action will get overridden. If not,
// then then the player action falls through, if there is one. This is
// to allow for restarting and quitting the game from the GUI.
let mut action = None;
for event in event_pump.poll_iter() {
match event {
Event::Quit { .. }
| Event::KeyDown {
keycode: Some(Keycode::Escape),
..
} => {
debug!("Escape registered");
break 'escape Ok(());
}
Event::KeyDown {
keycode: Some(Keycode::Left),
..
} => {
debug!("Move left registered");
action = Some(Action::MoveLeft);
}
Event::KeyDown {
keycode: Some(Keycode::Right),
..
} => {
debug!("Move right registered");
action = Some(Action::MoveRight);
}
Event::KeyDown {
keycode: Some(Keycode::Down),
..
} => {
debug!("Soft drop registered");
action = Some(Action::SoftDrop);
}
Event::KeyDown {
keycode: Some(Keycode::Z),
..
} => {
debug!("Rotate left registered");
action = Some(Action::RotateLeft);
}
Event::KeyDown {
keycode: Some(Keycode::X),
..
} => {
debug!("Rotate right registered");
action = Some(Action::RotateRight);
}
Event::KeyDown {
keycode: Some(Keycode::Space),
..
}
| Event::KeyDown {
keycode: Some(Keycode::Up),
..
} => {
debug!("Hard drop registered");
action = Some(Action::HardDrop);
}
Event::KeyDown {
keycode: Some(Keycode::LShift),
..
} => {
debug!("Hold registered");
action = Some(Action::Hold);
}
Event::KeyDown {
keycode: Some(Keycode::R),
..
} => {
info!("Restarting game");
game = Game::default();
}
Event::KeyDown {
keycode: Some(e), ..
} => trace!("Ignoring keycode {}", e),
_ => (),
}
}
actor.as_mut().map(|actor| {
action =
Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
});
action.map(|action| match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
});
game.tick();
canvas.set_draw_color(COLOR_BACKGROUND);
canvas.clear();
standard_renderer::render(&mut canvas, &game);
canvas.present();
interval.tick().await;
}
let cur_state = (&game).into();
// If there's an actor, the player action will get overridden. If not,
// then then the player action falls through, if there is one. This is
// to allow for restarting and quitting the game from the GUI.
let mut action = None;
for event in event_pump.poll_iter() {
match event {
Event::Quit { .. }
| Event::KeyDown {
keycode: Some(Keycode::Escape),
..
} => {
debug!("Escape registered");
break 'running;
}
Event::KeyDown {
keycode: Some(Keycode::Left),
..
} => {
debug!("Move left registered");
action = Some(Action::MoveLeft);
}
Event::KeyDown {
keycode: Some(Keycode::Right),
..
} => {
debug!("Move right registered");
action = Some(Action::MoveRight);
}
Event::KeyDown {
keycode: Some(Keycode::Down),
..
} => {
debug!("Soft drop registered");
action = Some(Action::SoftDrop);
}
Event::KeyDown {
keycode: Some(Keycode::Z),
..
} => {
debug!("Rotate left registered");
action = Some(Action::RotateLeft);
}
Event::KeyDown {
keycode: Some(Keycode::X),
..
} => {
debug!("Rotate right registered");
action = Some(Action::RotateRight);
}
Event::KeyDown {
keycode: Some(Keycode::Space),
..
}
| Event::KeyDown {
keycode: Some(Keycode::Up),
..
} => {
debug!("Hard drop registered");
action = Some(Action::HardDrop);
}
Event::KeyDown {
keycode: Some(Keycode::LShift),
..
} => {
debug!("Hold registered");
action = Some(Action::Hold);
}
Event::KeyDown {
keycode: Some(Keycode::R),
..
} => {
info!("Restarting game");
game = Game::default();
}
Event::KeyDown {
keycode: Some(e), ..
} => trace!("Ignoring keycode {}", e),
_ => (),
}
}
actor.as_mut().map(|actor| {
action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
});
action.map(|action| match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
});
game.tick();
canvas.set_draw_color(COLOR_BACKGROUND);
canvas.clear();
standard_renderer::render(&mut canvas, &game);
canvas.present();
interval.tick().await;
info!("Final score: {}", game.score());
}
info!("Final score: {}", game.score());
actor.map(|a| a.dbg());
Ok(())
}