From 5a9e2538aa08e796614cfb8193d18a16a29b7393 Mon Sep 17 00:00:00 2001 From: Edward Shen Date: Mon, 6 Apr 2020 16:07:30 -0400 Subject: [PATCH] heuristicgenetic --- Cargo.lock | 104 +++++++++++++++++ Cargo.toml | 3 +- src/actors/genetic.rs | 135 +++++++++++++++------- src/actors/qlearning.rs | 201 ++++++++++++++++++++++---------- src/game.rs | 34 ++++-- src/main.rs | 246 +++++++++++++++++++++------------------- 6 files changed, 490 insertions(+), 233 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dde6bd9..26ca0e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -150,11 +150,88 @@ name = "fuchsia-zircon-sys" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "futures" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "futures-channel" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "futures-core" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "futures-executor" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "futures-io" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "futures-macro" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "futures-sink" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "futures-task" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "futures-util" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "futures-sink 0.3.4 
(registry+https://github.com/rust-lang/crates.io-index)", + "futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", + "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", + "slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "getrandom" version = "0.1.13" @@ -345,6 +422,11 @@ name = "pin-project-lite" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "pin-utils" +version = "0.1.0-alpha.4" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "ppv-lite86" version = "0.2.6" @@ -374,6 +456,16 @@ dependencies = [ "version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "proc-macro-hack" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "proc-macro-nested" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "proc-macro2" version = "1.0.9" @@ -549,6 +641,7 @@ name = "tetris" version = "0.1.0" dependencies = [ "clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)", + "futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", "indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", @@ -695,7 +788,15 @@ dependencies = [ "checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3" "checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82" "checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" +"checksum futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5c329ae8753502fb44ae4fc2b622fa2a94652c41e795143765ba0927f92ab780" +"checksum futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f0c77d04ce8edd9cb903932b608268b3fffec4163dc053b3b402bf47eac1f1a8" "checksum futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f25592f769825e89b92358db00d26f965761e094951ac44d3663ef25b7ac464a" +"checksum futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f674f3e1bcb15b37284a90cedf55afdba482ab061c407a9c0ebbd0f3109741ba" +"checksum futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "a638959aa96152c7a4cddf50fcb1e3fede0583b27157c26e67d6f99904090dc6" +"checksum futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "9a5081aa3de1f7542a794a397cde100ed903b0630152d0973479018fd85423a7" +"checksum futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "3466821b4bc114d95b087b850a724c6f83115e929bc88f1fa98a3304a944c8a6" +"checksum futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "7b0a34e53cf6cdcd0178aa573aed466b646eb3db769570841fda0c7ede375a27" +"checksum futures-util 0.3.4 
(registry+https://github.com/rust-lang/crates.io-index)" = "22766cf25d64306bedf0384da004d05c9974ab104fcc4528f1236181c18004c5" "checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407" "checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205" "checksum hermit-abi 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1010591b26bbfe835e9faeabeb11866061cc7dcebffd56ad7d0942d0e61aefd8" @@ -718,9 +819,12 @@ dependencies = [ "checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6" "checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a" "checksum pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae" +"checksum pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5894c618ce612a3fa23881b152b608bafb8c56cfc22f434a3ba3120b40f7b587" "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" "checksum proc-macro-error 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7" "checksum proc-macro-error-attr 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de" +"checksum proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)" = "0d659fe7c6d27f25e9d80a1a094c223f5246f6a6596453e09d7229bf42750b63" +"checksum proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "8e946095f9d3ed29ec38de908c22f95d9ac008e424c7bcae54c75a79c527c694" "checksum proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)" = "6c09721c6781493a2a492a96b5a5bf19b65917fe6728884e7c44dd0c60ca3435" "checksum quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2bdc6c187c65bca4260c9011c9e3132efe4909da44726bad24cf7572ae338d7f" "checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412" diff --git a/Cargo.toml b/Cargo.toml index 456d62e..e695899 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,4 +13,5 @@ log = "0.4" simple_logger = "1.6" sdl2 = { version = "0.33.0", features = ["ttf"] } clap = { git = "https://github.com/clap-rs/clap/", features = ["color"] } -indicatif = "0.14" \ No newline at end of file +indicatif = "0.14" +futures = "0.3" \ No newline at end of file diff --git a/src/actors/genetic.rs b/src/actors/genetic.rs index fce6799..cf23488 100644 --- a/src/actors/genetic.rs +++ b/src/actors/genetic.rs @@ -1,11 +1,13 @@ // https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/ -use super::{Actor, State}; +use super::{Actor, Predictable, State}; use crate::{ cli::Train, - game::{Action, Controllable, Game}, + game::{Action, Controllable, Game, Tickable}, playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH}, }; +use indicatif::ProgressBar; +use log::debug; use rand::rngs::SmallRng; use rand::{seq::SliceRandom, Rng, SeedableRng}; @@ -59,15 +61,21 @@ impl Parameters { } } -#[derive(Clone, Copy)] 
+#[derive(Clone, Copy, Debug)] pub struct GeneticHeuristicAgent { params: Parameters, } impl Default for GeneticHeuristicAgent { fn default() -> Self { + let mut rng = SmallRng::from_entropy(); Self { - params: Parameters::default(), + params: Parameters { + total_height: rng.gen::(), + bumpiness: rng.gen::(), + holes: rng.gen::(), + complete_lines: rng.gen::(), + }, } } } @@ -119,24 +127,31 @@ impl GeneticHeuristicAgent { } fn get_heuristic(&self, game: &Game, action: &Action) -> f64 { - todo!(); + self.params.dot_multiply(&Self::extract_features_from_state( + &game.get_next_state(*action).into(), + )) } pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self { - let weight = (self_fitness + other_fitness) as f64; - let self_weight = self_fitness as f64 / weight; - let other_weight = other_fitness as f64 / weight; - Self { - params: Parameters { - total_height: self.params.total_height * self_weight - + other.params.total_height * other_weight, - bumpiness: self.params.total_height * self_weight - + other.params.total_height * other_weight, - holes: self.params.total_height * self_weight - + other.params.total_height * other_weight, - complete_lines: self.params.total_height * self_weight - + other.params.total_height * other_weight, - }, + let weight = self_fitness + other_fitness; + + if weight != 0 { + let self_weight = self_fitness as f64 / weight as f64; + let other_weight = other_fitness as f64 / weight as f64; + Self { + params: Parameters { + total_height: self.params.total_height * self_weight + + other.params.total_height * other_weight, + bumpiness: self.params.total_height * self_weight + + other.params.total_height * other_weight, + holes: self.params.total_height * self_weight + + other.params.total_height * other_weight, + complete_lines: self.params.total_height * self_weight + + other.params.total_height * other_weight, + }, + } + } else { + Self::default() } } @@ -146,33 +161,62 @@ impl GeneticHeuristicAgent { } impl Actor for GeneticHeuristicAgent { - fn get_action(&self, _: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action { - *legal_actions + fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action { + let actions = legal_actions .iter() - .map(|action| (action, self.get_heuristic(game, action))) - .max_by_key(|(_, heuristic)| (heuristic * 1_000_00.0) as usize) + .map(|action| { + ( + action, + (self.get_heuristic(game, action) * 1_000_000.0) as usize, + ) + }) + .collect::>(); + + let max_val = actions + .iter() + .max_by_key(|(_, heuristic)| heuristic) + .unwrap() + .1; + + *actions + .iter() + .filter(|e| e.1 == max_val) + .collect::>() + .choose(rng) .unwrap() .0 } fn dbg(&self) { - // unimplemented!() + debug!("{:?}", self.params); } } -pub fn train_actor(opts: &Train) -> Box { - let mut rng = SmallRng::from_entropy(); +pub async fn train_actor(opts: &Train) -> Box { + use std::sync::{Arc, Mutex}; + let rng = Arc::new(Mutex::new(SmallRng::from_entropy())); let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000]; + let total_pb = Arc::new(Mutex::new(ProgressBar::new( + population.len() as u64 * opts.episodes as u64, + ))); for _ in 0..opts.episodes { - let mut new_population: Vec<(GeneticHeuristicAgent, u32)> = population - .iter() - .map(|(agent, _)| { + let mut new_pop_futs = Vec::with_capacity(population.len()); + let new_population = Arc::new(Mutex::new(Vec::with_capacity(population.len()))); + let num_rounds = 10; + for i in 0..population.len() { + let rng = rng.clone(); + 
let new_population = new_population.clone(); + let total_pb = total_pb.clone(); + let agent = population[i].0; + new_pop_futs.push(tokio::task::spawn(async move { let mut fitness = 0; - for _ in 0..100 { + for _ in 0..num_rounds { let mut game = Game::default(); game.set_piece_limit(500); + game.set_time_limit(std::time::Duration::from_secs(60)); while (&game).is_game_over().is_none() { + let mut rng = rng.lock().expect("rng failed"); super::apply_action_to_game( agent.get_action( &mut rng, @@ -180,24 +224,33 @@ pub fn train_actor(opts: &Train) -> Box { &game.get_legal_actions(), ), &mut game, - ) + ); + game.tick(); } fitness += game.line_clears; } - (*agent, fitness) - }) - .collect::>(); + new_population.lock().unwrap().push((agent, fitness)); - let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> = Vec::with_capacity(300); + total_pb.lock().expect("progressbar failed").inc(1); + })); + } + futures::future::join_all(new_pop_futs).await; + + let mut rng = SmallRng::from_entropy(); + + let mut new_population = new_population.lock().unwrap().clone(); + let new_pop_size = population.len() * 3 / 10; + let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> = + Vec::with_capacity(new_pop_size); for _ in 0..3 { let mut random_selection = new_population - .choose_multiple(&mut rng, 100) + .choose_multiple(&mut rng, new_pop_size / 3) .collect::>(); random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1)); let best_two = random_selection.iter().rev().take(2).collect::>(); - let parent1 = best_two[0]; - let parent2 = best_two[1]; - for _ in 0..100 { + let parent1 = dbg!(best_two[0]); + let parent2 = dbg!(best_two[1]); + for _ in 0..new_pop_size / 3 { let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1); let mut cloned = breeded.clone(); if rng.gen::() < 0.05 { @@ -208,11 +261,11 @@ pub fn train_actor(opts: &Train) -> Box { } new_population.sort_unstable_by_key(|e| e.1); - new_population.splice(..breeded_population.len(), breeded_population); - population = new_population; } + total_pb.lock().unwrap().finish(); + population.sort_unstable_by_key(|e| e.1); Box::new(population.iter().rev().next().unwrap().0) } diff --git a/src/actors/qlearning.rs b/src/actors/qlearning.rs index b5b2087..540a903 100644 --- a/src/actors/qlearning.rs +++ b/src/actors/qlearning.rs @@ -1,11 +1,12 @@ +use super::Predictable; use crate::actors::{Actor, State}; use crate::{ cli::Train, game::{Action, Controllable, Game, Tickable}, - playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH}, + playfield::{Matrix, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH}, }; use indicatif::ProgressIterator; -use log::{debug, info}; +use log::{debug, info, trace}; use rand::rngs::SmallRng; use rand::seq::SliceRandom; use rand::Rng; @@ -15,9 +16,9 @@ use std::collections::HashMap; pub trait QLearningActor: Actor { fn update( &mut self, - state: State, + game_state: Game, action: Action, - next_state: State, + next_game_state: Game, next_legal_actions: &[Action], reward: f64, ); @@ -54,11 +55,32 @@ impl QLearningAgent { } fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action { - *legal_actions + let legal_actions = legal_actions + .iter() + .map(|action| (action, self.get_q_value(state, *action))) + .collect::>(); + + let max_val = legal_actions .iter() - .map(|action| (action, self.get_q_value(&state, *action))) .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize)) .expect("Failed to select an action") + .1; + + let actions_to_choose = legal_actions + .iter() + .filter(|(_, v)| max_val == *v) + 
.collect::>(); + + if actions_to_choose.len() != 1 { + trace!( + "more than one best option, choosing randomly: {:?}", + actions_to_choose + ); + } + + *actions_to_choose + .choose(&mut SmallRng::from_entropy()) + .unwrap() .0 } @@ -95,12 +117,14 @@ impl Actor for QLearningAgent { impl QLearningActor for QLearningAgent { fn update( &mut self, - state: State, + game_state: Game, action: Action, - next_state: State, + next_game_state: Game, _next_legal_actions: &[Action], reward: f64, ) { + let state = (&game_state).into(); + let next_state = (&next_game_state).into(); let cur_q_val = self.get_q_value(&state, action); let new_q_val = cur_q_val + self.learning_rate @@ -130,106 +154,163 @@ pub struct ApproximateQLearning { pub learning_rate: f64, pub exploration_prob: f64, pub discount_rate: f64, - weights: HashMap, + weights: HashMap, } impl Default for ApproximateQLearning { fn default() -> Self { + let mut weights = HashMap::default(); + weights.insert(Feature::TotalHeight, 1.0); + weights.insert(Feature::Bumpiness, 1.0); + weights.insert(Feature::LinesCleared, 1.0); + weights.insert(Feature::Holes, 1.0); + Self { learning_rate: 0.0, exploration_prob: 0.0, discount_rate: 0.0, - weights: HashMap::default(), + weights, } } } +#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)] +enum Feature { + TotalHeight, + Bumpiness, + LinesCleared, + Holes, +} impl ApproximateQLearning { - fn get_features( - &self, - state: &State, - _action: &Action, - new_state: &State, - ) -> HashMap { + fn get_features(&self, game: &Game, action: &Action) -> HashMap { + // let game = game.get_next_state(*action); + let mut features = HashMap::default(); - - let mut heights = [None; PLAYFIELD_WIDTH]; - for r in 0..PLAYFIELD_HEIGHT { - for c in 0..PLAYFIELD_WIDTH { - if heights[c].is_none() && state.matrix[r][c].is_some() { - heights[c] = Some(PLAYFIELD_HEIGHT - r); - } - } - } - + let field = game.playfield().field(); + let heights = self.get_heights(field); features.insert( - "Total Height".into(), - heights - .iter() - .map(|o| o.unwrap_or_else(|| 0)) - .sum::() as f64 - / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64, + Feature::TotalHeight, + heights.iter().sum::() as f64 / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64, ); features.insert( - "Bumpiness".into(), + Feature::Bumpiness, heights .iter() - .map(|o| o.unwrap_or_else(|| 0) as isize) - .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur)) + .fold((0, 0), |(acc, prev), cur| { + (acc + (prev as isize - *cur as isize).abs(), *cur) + }) .0 as f64 / (PLAYFIELD_WIDTH * 40) as f64, ); features.insert( - "Lines cleared".into(), - (new_state.line_clears - state.line_clears) as f64 / 4.0, + Feature::LinesCleared, + game.playfield() + .field() + .iter() + .map(|r| r.iter().all(Option::is_some)) + .map(|r| if r { 1 } else { 0 }) + .sum::() as f64 + / 4.0, ); let mut holes = 0; for r in 1..PLAYFIELD_HEIGHT { for c in 0..PLAYFIELD_WIDTH { - if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() { + if field[r][c].is_none() && field[r - 1][c].is_some() { holes += 1; } } } - features.insert("Holes".into(), holes as f64); + features.insert(Feature::Holes, holes as f64); features } - fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 { - self.get_features(state, action, next_state) + fn get_heights(&self, matrix: &Matrix) -> Vec { + let mut heights = vec![0; matrix[0].len()]; + for r in 0..matrix.len() { + for c in 0..matrix[0].len() { + if heights[c] == 0 && matrix[r][c].is_some() { + heights[c] = matrix.len() - r; + } + } + 
} + + heights + } + + fn get_q_value(&self, game: &Game, action: &Action) -> f64 { + self.get_features(game, action) .iter() - .map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0)) + .map(|(key, val)| val * *self.weights.get(key).unwrap()) .sum() } - fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action { - *legal_actions + fn get_action_from_q_values(&self, game: &Game) -> Action { + let legal_actions = game.get_legal_actions(); + let legal_actions = legal_actions + .iter() + .map(|action| (action, self.get_q_value(game, action))) + .collect::>(); + + let max_val = legal_actions .iter() - .map(|action| (action, self.get_q_value(&state, action, state))) .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize)) .expect("Failed to select an action") + .1; + + let actions_to_choose = legal_actions + .iter() + .filter(|(_, v)| max_val == *v) + .collect::>(); + + if actions_to_choose.len() != 1 { + trace!( + "more than one best option, choosing randomly: {:?}", + actions_to_choose + ); + } + + *actions_to_choose + .choose(&mut SmallRng::from_entropy()) + .unwrap() .0 } - fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 { - legal_actions + fn get_value(&self, game: &Game) -> f64 { + game.get_legal_actions() .iter() - .map(|action| self.get_q_value(state, action, state)) + .map(|action| self.get_q_value(game, action)) .max_by_key(|v| (v * 1_000_000.0) as isize) .unwrap_or_else(|| 0.0) } } +#[cfg(test)] +mod aaa { + use super::*; + use crate::tetromino::TetrominoType; + + #[test] + fn test_height() { + let agent = ApproximateQLearning::default(); + let matrix = vec![ + vec![None, None, None], + vec![None, Some(TetrominoType::T), None], + vec![None, None, None], + ]; + assert_eq!(agent.get_heights(&matrix), vec![0, 2, 0]); + } +} + impl Actor for ApproximateQLearning { fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action { if rng.gen::() < self.exploration_prob { *legal_actions.choose(rng).unwrap() } else { - self.get_action_from_q_values(&game.into(), legal_actions) + self.get_action_from_q_values(game) } } @@ -241,21 +322,19 @@ impl Actor for ApproximateQLearning { impl QLearningActor for ApproximateQLearning { fn update( &mut self, - state: State, + game_state: Game, action: Action, - next_state: State, + next_game_state: Game, next_legal_actions: &[Action], reward: f64, ) { - let difference = reward - + self.discount_rate * self.get_value(&next_state, next_legal_actions) - - self.get_q_value(&state, &action, &next_state); + let difference = reward + self.discount_rate * self.get_value(&next_game_state) + - self.get_q_value(&game_state, &action); - for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) { + for (feat_key, feat_val) in self.get_features(&game_state, &action) { self.weights.insert( feat_key.clone(), - *self.weights.get(&feat_key).unwrap_or_else(|| &0.0) - + self.learning_rate * difference * feat_val, + *self.weights.get(&feat_key).unwrap() + self.learning_rate * difference * feat_val, ); } } @@ -303,11 +382,11 @@ pub fn train_actor( super::apply_action_to_game(action, &mut game); - let new_state = (&game).into(); + let new_state = game.clone(); let mut reward = game.score() as f64 - cur_score as f64; - if action != Action::Nothing { - reward -= 0.0; - } + // if action != Action::Nothing { + // reward -= 0.0; + // } if game.is_game_over().is_some() { reward = -1.0; diff --git a/src/game.rs b/src/game.rs index 2232e54..ed4ff83 100644 --- a/src/game.rs +++ 
b/src/game.rs @@ -17,6 +17,7 @@ pub enum LossReason { TopOut, LockOut, PieceLimitReached, + TickLimitReached, BlockOut(Position), } @@ -39,6 +40,9 @@ pub struct Game { // used if we set a limit on how long a game can last. pieces_placed: usize, piece_limit: usize, + + // used if we set a limit on how long the game can be played. + tick_limit: u64, } impl fmt::Debug for Game { @@ -65,6 +69,7 @@ impl Default for Game { line_clears: 0, pieces_placed: 0, piece_limit: 0, + tick_limit: 0, } } } @@ -79,6 +84,7 @@ impl Tickable for Game { return; } self.tick += 1; + match self.tick { t if t == self.next_spawn_tick => { trace!("Spawn tick was met, spawning new Tetromino!"); @@ -105,6 +111,10 @@ impl Tickable for Game { } _ => (), } + + if self.tick == self.tick_limit { + self.is_game_over = Some(LossReason::TickLimitReached); + } } } @@ -217,6 +227,10 @@ impl Game { pub fn playfield(&self) -> &PlayField { &self.playfield } + + pub fn set_time_limit(&mut self, duration: std::time::Duration) { + self.tick_limit = duration.as_secs() * TICKS_PER_SECOND as u64; + } } pub trait Controllable { @@ -230,7 +244,7 @@ pub trait Controllable { fn get_legal_actions(&self) -> Vec; } -#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)] +#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)] pub enum Action { Nothing, // Default value MoveLeft, @@ -248,8 +262,7 @@ impl Controllable for Game { return; } - self.playfield.move_offset(-1, 0); - if !self.playfield.can_active_piece_move_down() { + if self.playfield.move_offset(-1, 0) && !self.playfield.can_active_piece_move_down() { self.update_lock_tick(); } } @@ -259,8 +272,7 @@ impl Controllable for Game { return; } - self.playfield.move_offset(1, 0); - if !self.playfield.can_active_piece_move_down() { + if self.playfield.move_offset(1, 0) && !self.playfield.can_active_piece_move_down() { self.update_lock_tick(); } } @@ -288,7 +300,7 @@ impl Controllable for Game { active_piece.position = active_piece.position.offset(x, y); active_piece.rotate_left(); self.playfield.active_piece = Some(active_piece); - // self.update_lock_tick(); + self.update_lock_tick(); } Err(_) => (), } @@ -305,7 +317,7 @@ impl Controllable for Game { active_piece.position = active_piece.position.offset(x, y); active_piece.rotate_right(); self.playfield.active_piece = Some(active_piece); - // self.update_lock_tick(); + self.update_lock_tick(); } Err(_) => (), } @@ -351,13 +363,13 @@ impl Controllable for Game { fn get_legal_actions(&self) -> Vec { let mut legal_actions = vec![ + Action::RotateLeft, + Action::RotateRight, + Action::SoftDrop, + Action::HardDrop, Action::Nothing, Action::MoveLeft, Action::MoveRight, - Action::SoftDrop, - Action::HardDrop, - Action::RotateLeft, - Action::RotateRight, ]; if self.playfield.can_swap() { diff --git a/src/main.rs b/src/main.rs index 8f4e82b..3e1198e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,7 +4,6 @@ use cli::*; use game::{Action, Controllable, Game, Tickable}; use graphics::standard_renderer; use graphics::COLOR_BACKGROUND; -use indicatif::ProgressIterator; use log::{debug, info, trace}; use rand::SeedableRng; use sdl2::event::Event; @@ -23,7 +22,7 @@ mod tetromino; const TICKS_PER_SECOND: usize = 60; -#[tokio::main] +#[tokio::main(core_threads = 16)] async fn main() -> Result<(), Box> { let opts = crate::cli::Opts::parse(); @@ -36,9 +35,16 @@ async fn main() -> Result<(), Box> { qlearning::train_actor(qlearning::QLearningAgent::default(), &sub_opts) } Agent::ApproximateQLearning => { - qlearning::train_actor(qlearning::ApproximateQLearning::default(), 
&sub_opts) + let agent = + qlearning::train_actor(qlearning::ApproximateQLearning::default(), &sub_opts); + agent.dbg(); + agent + } + Agent::HeuristicGenetic => { + let agent = genetic::train_actor(&sub_opts).await; + agent.dbg(); + agent } - Agent::HeuristicGenetic => genetic::train_actor(&sub_opts), }), }; @@ -55,125 +61,127 @@ async fn play_game(mut actor: Option>) -> Result<(), Box { - println!("Lost due to: {:?}", e); - break; + 'escape: loop { + let mut game = Game::default(); + + loop { + match game.is_game_over() { + Some(e) => { + println!("Lost due to: {:?}", e); + break; + } + None => (), } - None => (), + + let cur_state = game.clone(); + + // If there's an actor, the player action will get overridden. If not, + // then then the player action falls through, if there is one. This is + // to allow for restarting and quitting the game from the GUI. + let mut action = None; + for event in event_pump.poll_iter() { + match event { + Event::Quit { .. } + | Event::KeyDown { + keycode: Some(Keycode::Escape), + .. + } => { + debug!("Escape registered"); + break 'escape Ok(()); + } + Event::KeyDown { + keycode: Some(Keycode::Left), + .. + } => { + debug!("Move left registered"); + action = Some(Action::MoveLeft); + } + Event::KeyDown { + keycode: Some(Keycode::Right), + .. + } => { + debug!("Move right registered"); + action = Some(Action::MoveRight); + } + Event::KeyDown { + keycode: Some(Keycode::Down), + .. + } => { + debug!("Soft drop registered"); + action = Some(Action::SoftDrop); + } + Event::KeyDown { + keycode: Some(Keycode::Z), + .. + } => { + debug!("Rotate left registered"); + action = Some(Action::RotateLeft); + } + Event::KeyDown { + keycode: Some(Keycode::X), + .. + } => { + debug!("Rotate right registered"); + action = Some(Action::RotateRight); + } + Event::KeyDown { + keycode: Some(Keycode::Space), + .. + } + | Event::KeyDown { + keycode: Some(Keycode::Up), + .. + } => { + debug!("Hard drop registered"); + action = Some(Action::HardDrop); + } + Event::KeyDown { + keycode: Some(Keycode::LShift), + .. + } => { + debug!("Hold registered"); + action = Some(Action::Hold); + } + Event::KeyDown { + keycode: Some(Keycode::R), + .. + } => { + info!("Restarting game"); + game = Game::default(); + } + Event::KeyDown { + keycode: Some(e), .. + } => trace!("Ignoring keycode {}", e), + _ => (), + } + } + + actor.as_mut().map(|actor| { + action = + Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()))); + }); + + action.map(|action| match action { + Action::Nothing => (), + Action::MoveLeft => game.move_left(), + Action::MoveRight => game.move_right(), + Action::SoftDrop => game.move_down(), + Action::HardDrop => game.hard_drop(), + Action::Hold => game.hold(), + Action::RotateLeft => game.rotate_left(), + Action::RotateRight => game.rotate_right(), + }); + + game.tick(); + canvas.set_draw_color(COLOR_BACKGROUND); + canvas.clear(); + standard_renderer::render(&mut canvas, &game); + canvas.present(); + interval.tick().await; } - let cur_state = game.clone(); - - // If there's an actor, the player action will get overridden. If not, - // then then the player action falls through, if there is one. This is - // to allow for restarting and quitting the game from the GUI. - let mut action = None; - for event in event_pump.poll_iter() { - match event { - Event::Quit { .. } - | Event::KeyDown { - keycode: Some(Keycode::Escape), - .. - } => { - debug!("Escape registered"); - break 'running; - } - Event::KeyDown { - keycode: Some(Keycode::Left), - .. 
- } => { - debug!("Move left registered"); - action = Some(Action::MoveLeft); - } - Event::KeyDown { - keycode: Some(Keycode::Right), - .. - } => { - debug!("Move right registered"); - action = Some(Action::MoveRight); - } - Event::KeyDown { - keycode: Some(Keycode::Down), - .. - } => { - debug!("Soft drop registered"); - action = Some(Action::SoftDrop); - } - Event::KeyDown { - keycode: Some(Keycode::Z), - .. - } => { - debug!("Rotate left registered"); - action = Some(Action::RotateLeft); - } - Event::KeyDown { - keycode: Some(Keycode::X), - .. - } => { - debug!("Rotate right registered"); - action = Some(Action::RotateRight); - } - Event::KeyDown { - keycode: Some(Keycode::Space), - .. - } - | Event::KeyDown { - keycode: Some(Keycode::Up), - .. - } => { - debug!("Hard drop registered"); - action = Some(Action::HardDrop); - } - Event::KeyDown { - keycode: Some(Keycode::LShift), - .. - } => { - debug!("Hold registered"); - action = Some(Action::Hold); - } - Event::KeyDown { - keycode: Some(Keycode::R), - .. - } => { - info!("Restarting game"); - game = Game::default(); - } - Event::KeyDown { - keycode: Some(e), .. - } => trace!("Ignoring keycode {}", e), - _ => (), - } - } - - actor.as_mut().map(|actor| { - action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()))); - }); - - action.map(|action| match action { - Action::Nothing => (), - Action::MoveLeft => game.move_left(), - Action::MoveRight => game.move_right(), - Action::SoftDrop => game.move_down(), - Action::HardDrop => game.hard_drop(), - Action::Hold => game.hold(), - Action::RotateLeft => game.rotate_left(), - Action::RotateRight => game.rotate_right(), - }); - - game.tick(); - canvas.set_draw_color(COLOR_BACKGROUND); - canvas.clear(); - standard_renderer::render(&mut canvas, &game); - canvas.present(); - interval.tick().await; + info!("Final score: {}", game.score()); } - - info!("Final score: {}", game.score()); - actor.map(|a| a.dbg()); - Ok(()) }
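
Note on GeneticHeuristicAgent::breed in the genetic.rs hunk above: every field of the child (total_height, bumpiness, holes, complete_lines) is computed from the parents' total_height, which looks like a copy-paste slip carried over from the pre-patch code. The sketch below shows the presumably intended per-field, fitness-weighted blend. It is a standalone illustration, not part of the commit: the Parameters struct mirrors the fields in the patch, while the helper names, the plain-default fallback, and the example values are made up for demonstration.

// Sketch only: per-field, fitness-weighted crossover, assuming that is what
// breed() was meant to do. Not the patch's code.
#[derive(Clone, Copy, Debug, Default)]
struct Parameters {
    total_height: f64,
    bumpiness: f64,
    holes: f64,
    complete_lines: f64,
}

// Weighted average of one field from each parent.
fn blend(a: f64, b: f64, wa: f64, wb: f64) -> f64 {
    a * wa + b * wb
}

fn breed(p1: Parameters, fitness1: u32, p2: Parameters, fitness2: u32) -> Parameters {
    let total = fitness1 + fitness2;
    if total == 0 {
        // Neither parent cleared any lines; the patch falls back to Self::default()
        // (a freshly randomized agent). Plain zeroed defaults stand in for that here.
        return Parameters::default();
    }
    let w1 = fitness1 as f64 / total as f64;
    let w2 = fitness2 as f64 / total as f64;
    Parameters {
        total_height: blend(p1.total_height, p2.total_height, w1, w2),
        bumpiness: blend(p1.bumpiness, p2.bumpiness, w1, w2),
        holes: blend(p1.holes, p2.holes, w1, w2),
        complete_lines: blend(p1.complete_lines, p2.complete_lines, w1, w2),
    }
}

fn main() {
    let strong = Parameters { total_height: 0.9, bumpiness: 0.2, holes: 0.8, complete_lines: 0.7 };
    let weak = Parameters { total_height: 0.1, bumpiness: 0.9, holes: 0.3, complete_lines: 0.2 };
    // The fitter parent (30 line clears vs. 10) contributes 75% of each blended weight.
    println!("{:?}", breed(strong, 30, weak, 10));
}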