heuristicgenetic

This commit is contained in:
Edward Shen 2020-04-06 16:07:30 -04:00
parent deb74da552
commit 5a9e2538aa
Signed by: edward
GPG key ID: 19182661E818369F
6 changed files with 490 additions and 233 deletions

104
Cargo.lock generated
View file

@ -150,11 +150,88 @@ name = "fuchsia-zircon-sys"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "futures-channel"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "futures-core"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures-executor"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "futures-io"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures-macro"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)",
"quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)",
"syn 1.0.16 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "futures-sink"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures-task"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "futures-util"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "getrandom"
version = "0.1.13"
@ -345,6 +422,11 @@ name = "pin-project-lite"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "pin-utils"
version = "0.1.0-alpha.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "ppv-lite86"
version = "0.2.6"
@ -374,6 +456,16 @@ dependencies = [
"version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "proc-macro-nested"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "proc-macro2"
version = "1.0.9"
@ -549,6 +641,7 @@ name = "tetris"
version = "0.1.0"
dependencies = [
"clap 3.0.0-beta.1 (git+https://github.com/clap-rs/clap/)",
"futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
"indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
@ -695,7 +788,15 @@ dependencies = [
"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
"checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82"
"checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
"checksum futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5c329ae8753502fb44ae4fc2b622fa2a94652c41e795143765ba0927f92ab780"
"checksum futures-channel 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f0c77d04ce8edd9cb903932b608268b3fffec4163dc053b3b402bf47eac1f1a8"
"checksum futures-core 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f25592f769825e89b92358db00d26f965761e094951ac44d3663ef25b7ac464a"
"checksum futures-executor 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "f674f3e1bcb15b37284a90cedf55afdba482ab061c407a9c0ebbd0f3109741ba"
"checksum futures-io 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "a638959aa96152c7a4cddf50fcb1e3fede0583b27157c26e67d6f99904090dc6"
"checksum futures-macro 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "9a5081aa3de1f7542a794a397cde100ed903b0630152d0973479018fd85423a7"
"checksum futures-sink 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "3466821b4bc114d95b087b850a724c6f83115e929bc88f1fa98a3304a944c8a6"
"checksum futures-task 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "7b0a34e53cf6cdcd0178aa573aed466b646eb3db769570841fda0c7ede375a27"
"checksum futures-util 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "22766cf25d64306bedf0384da004d05c9974ab104fcc4528f1236181c18004c5"
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
"checksum heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205"
"checksum hermit-abi 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "1010591b26bbfe835e9faeabeb11866061cc7dcebffd56ad7d0942d0e61aefd8"
@ -718,9 +819,12 @@ dependencies = [
"checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6"
"checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
"checksum pin-project-lite 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae"
"checksum pin-utils 0.1.0-alpha.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5894c618ce612a3fa23881b152b608bafb8c56cfc22f434a3ba3120b40f7b587"
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
"checksum proc-macro-error 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7"
"checksum proc-macro-error-attr 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de"
"checksum proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)" = "0d659fe7c6d27f25e9d80a1a094c223f5246f6a6596453e09d7229bf42750b63"
"checksum proc-macro-nested 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "8e946095f9d3ed29ec38de908c22f95d9ac008e424c7bcae54c75a79c527c694"
"checksum proc-macro2 1.0.9 (registry+https://github.com/rust-lang/crates.io-index)" = "6c09721c6781493a2a492a96b5a5bf19b65917fe6728884e7c44dd0c60ca3435"
"checksum quote 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2bdc6c187c65bca4260c9011c9e3132efe4909da44726bad24cf7572ae338d7f"
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"

View file

@ -13,4 +13,5 @@ log = "0.4"
simple_logger = "1.6"
sdl2 = { version = "0.33.0", features = ["ttf"] }
clap = { git = "https://github.com/clap-rs/clap/", features = ["color"] }
indicatif = "0.14"
indicatif = "0.14"
futures = "0.3"

View file

@ -1,11 +1,13 @@
// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
use super::{Actor, State};
use super::{Actor, Predictable, State};
use crate::{
cli::Train,
game::{Action, Controllable, Game},
game::{Action, Controllable, Game, Tickable},
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use indicatif::ProgressBar;
use log::debug;
use rand::rngs::SmallRng;
use rand::{seq::SliceRandom, Rng, SeedableRng};
@ -59,15 +61,21 @@ impl Parameters {
}
}
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Debug)]
pub struct GeneticHeuristicAgent {
params: Parameters,
}
impl Default for GeneticHeuristicAgent {
fn default() -> Self {
let mut rng = SmallRng::from_entropy();
Self {
params: Parameters::default(),
params: Parameters {
total_height: rng.gen::<f64>(),
bumpiness: rng.gen::<f64>(),
holes: rng.gen::<f64>(),
complete_lines: rng.gen::<f64>(),
},
}
}
}
@ -119,24 +127,31 @@ impl GeneticHeuristicAgent {
}
fn get_heuristic(&self, game: &Game, action: &Action) -> f64 {
todo!();
self.params.dot_multiply(&Self::extract_features_from_state(
&game.get_next_state(*action).into(),
))
}
pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self {
let weight = (self_fitness + other_fitness) as f64;
let self_weight = self_fitness as f64 / weight;
let other_weight = other_fitness as f64 / weight;
Self {
params: Parameters {
total_height: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
bumpiness: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
holes: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
complete_lines: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
},
let weight = self_fitness + other_fitness;
if weight != 0 {
let self_weight = self_fitness as f64 / weight as f64;
let other_weight = other_fitness as f64 / weight as f64;
Self {
params: Parameters {
total_height: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
bumpiness: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
holes: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
complete_lines: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
},
}
} else {
Self::default()
}
}
@ -146,33 +161,62 @@ impl GeneticHeuristicAgent {
}
impl Actor for GeneticHeuristicAgent {
fn get_action(&self, _: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
*legal_actions
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
let actions = legal_actions
.iter()
.map(|action| (action, self.get_heuristic(game, action)))
.max_by_key(|(_, heuristic)| (heuristic * 1_000_00.0) as usize)
.map(|action| {
(
action,
(self.get_heuristic(game, action) * 1_000_000.0) as usize,
)
})
.collect::<Vec<_>>();
let max_val = actions
.iter()
.max_by_key(|(_, heuristic)| heuristic)
.unwrap()
.1;
*actions
.iter()
.filter(|e| e.1 == max_val)
.collect::<Vec<_>>()
.choose(rng)
.unwrap()
.0
}
fn dbg(&self) {
// unimplemented!()
debug!("{:?}", self.params);
}
}
pub fn train_actor(opts: &Train) -> Box<dyn Actor> {
let mut rng = SmallRng::from_entropy();
pub async fn train_actor(opts: &Train) -> Box<dyn Actor> {
use std::sync::{Arc, Mutex};
let rng = Arc::new(Mutex::new(SmallRng::from_entropy()));
let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000];
let total_pb = Arc::new(Mutex::new(ProgressBar::new(
population.len() as u64 * opts.episodes as u64,
)));
for _ in 0..opts.episodes {
let mut new_population: Vec<(GeneticHeuristicAgent, u32)> = population
.iter()
.map(|(agent, _)| {
let mut new_pop_futs = Vec::with_capacity(population.len());
let new_population = Arc::new(Mutex::new(Vec::with_capacity(population.len())));
let num_rounds = 10;
for i in 0..population.len() {
let rng = rng.clone();
let new_population = new_population.clone();
let total_pb = total_pb.clone();
let agent = population[i].0;
new_pop_futs.push(tokio::task::spawn(async move {
let mut fitness = 0;
for _ in 0..100 {
for _ in 0..num_rounds {
let mut game = Game::default();
game.set_piece_limit(500);
game.set_time_limit(std::time::Duration::from_secs(60));
while (&game).is_game_over().is_none() {
let mut rng = rng.lock().expect("rng failed");
super::apply_action_to_game(
agent.get_action(
&mut rng,
@ -180,24 +224,33 @@ pub fn train_actor(opts: &Train) -> Box<dyn Actor> {
&game.get_legal_actions(),
),
&mut game,
)
);
game.tick();
}
fitness += game.line_clears;
}
(*agent, fitness)
})
.collect::<Vec<_>>();
new_population.lock().unwrap().push((agent, fitness));
let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> = Vec::with_capacity(300);
total_pb.lock().expect("progressbar failed").inc(1);
}));
}
futures::future::join_all(new_pop_futs).await;
let mut rng = SmallRng::from_entropy();
let mut new_population = new_population.lock().unwrap().clone();
let new_pop_size = population.len() * 3 / 10;
let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> =
Vec::with_capacity(new_pop_size);
for _ in 0..3 {
let mut random_selection = new_population
.choose_multiple(&mut rng, 100)
.choose_multiple(&mut rng, new_pop_size / 3)
.collect::<Vec<_>>();
random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1));
let best_two = random_selection.iter().rev().take(2).collect::<Vec<_>>();
let parent1 = best_two[0];
let parent2 = best_two[1];
for _ in 0..100 {
let parent1 = dbg!(best_two[0]);
let parent2 = dbg!(best_two[1]);
for _ in 0..new_pop_size / 3 {
let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
let mut cloned = breeded.clone();
if rng.gen::<f64>() < 0.05 {
@ -208,11 +261,11 @@ pub fn train_actor(opts: &Train) -> Box<dyn Actor> {
}
new_population.sort_unstable_by_key(|e| e.1);
new_population.splice(..breeded_population.len(), breeded_population);
population = new_population;
}
total_pb.lock().unwrap().finish();
population.sort_unstable_by_key(|e| e.1);
Box::new(population.iter().rev().next().unwrap().0)
}

View file

@ -1,11 +1,12 @@
use super::Predictable;
use crate::actors::{Actor, State};
use crate::{
cli::Train,
game::{Action, Controllable, Game, Tickable},
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
playfield::{Matrix, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use indicatif::ProgressIterator;
use log::{debug, info};
use log::{debug, info, trace};
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::Rng;
@ -15,9 +16,9 @@ use std::collections::HashMap;
pub trait QLearningActor: Actor {
fn update(
&mut self,
state: State,
game_state: Game,
action: Action,
next_state: State,
next_game_state: Game,
next_legal_actions: &[Action],
reward: f64,
);
@ -54,11 +55,32 @@ impl QLearningAgent {
}
fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
*legal_actions
let legal_actions = legal_actions
.iter()
.map(|action| (action, self.get_q_value(state, *action)))
.collect::<Vec<_>>();
let max_val = legal_actions
.iter()
.map(|action| (action, self.get_q_value(&state, *action)))
.max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
.expect("Failed to select an action")
.1;
let actions_to_choose = legal_actions
.iter()
.filter(|(_, v)| max_val == *v)
.collect::<Vec<_>>();
if actions_to_choose.len() != 1 {
trace!(
"more than one best option, choosing randomly: {:?}",
actions_to_choose
);
}
*actions_to_choose
.choose(&mut SmallRng::from_entropy())
.unwrap()
.0
}
@ -95,12 +117,14 @@ impl Actor for QLearningAgent {
impl QLearningActor for QLearningAgent {
fn update(
&mut self,
state: State,
game_state: Game,
action: Action,
next_state: State,
next_game_state: Game,
_next_legal_actions: &[Action],
reward: f64,
) {
let state = (&game_state).into();
let next_state = (&next_game_state).into();
let cur_q_val = self.get_q_value(&state, action);
let new_q_val = cur_q_val
+ self.learning_rate
@ -130,106 +154,163 @@ pub struct ApproximateQLearning {
pub learning_rate: f64,
pub exploration_prob: f64,
pub discount_rate: f64,
weights: HashMap<String, f64>,
weights: HashMap<Feature, f64>,
}
impl Default for ApproximateQLearning {
fn default() -> Self {
let mut weights = HashMap::default();
weights.insert(Feature::TotalHeight, 1.0);
weights.insert(Feature::Bumpiness, 1.0);
weights.insert(Feature::LinesCleared, 1.0);
weights.insert(Feature::Holes, 1.0);
Self {
learning_rate: 0.0,
exploration_prob: 0.0,
discount_rate: 0.0,
weights: HashMap::default(),
weights,
}
}
}
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
enum Feature {
TotalHeight,
Bumpiness,
LinesCleared,
Holes,
}
impl ApproximateQLearning {
fn get_features(
&self,
state: &State,
_action: &Action,
new_state: &State,
) -> HashMap<String, f64> {
fn get_features(&self, game: &Game, action: &Action) -> HashMap<Feature, f64> {
// let game = game.get_next_state(*action);
let mut features = HashMap::default();
let mut heights = [None; PLAYFIELD_WIDTH];
for r in 0..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if heights[c].is_none() && state.matrix[r][c].is_some() {
heights[c] = Some(PLAYFIELD_HEIGHT - r);
}
}
}
let field = game.playfield().field();
let heights = self.get_heights(field);
features.insert(
"Total Height".into(),
heights
.iter()
.map(|o| o.unwrap_or_else(|| 0))
.sum::<usize>() as f64
/ (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
Feature::TotalHeight,
heights.iter().sum::<usize>() as f64 / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
);
features.insert(
"Bumpiness".into(),
Feature::Bumpiness,
heights
.iter()
.map(|o| o.unwrap_or_else(|| 0) as isize)
.fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
.fold((0, 0), |(acc, prev), cur| {
(acc + (prev as isize - *cur as isize).abs(), *cur)
})
.0 as f64
/ (PLAYFIELD_WIDTH * 40) as f64,
);
features.insert(
"Lines cleared".into(),
(new_state.line_clears - state.line_clears) as f64 / 4.0,
Feature::LinesCleared,
game.playfield()
.field()
.iter()
.map(|r| r.iter().all(Option::is_some))
.map(|r| if r { 1 } else { 0 })
.sum::<usize>() as f64
/ 4.0,
);
let mut holes = 0;
for r in 1..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
if field[r][c].is_none() && field[r - 1][c].is_some() {
holes += 1;
}
}
}
features.insert("Holes".into(), holes as f64);
features.insert(Feature::Holes, holes as f64);
features
}
fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
self.get_features(state, action, next_state)
fn get_heights(&self, matrix: &Matrix) -> Vec<usize> {
let mut heights = vec![0; matrix[0].len()];
for r in 0..matrix.len() {
for c in 0..matrix[0].len() {
if heights[c] == 0 && matrix[r][c].is_some() {
heights[c] = matrix.len() - r;
}
}
}
heights
}
fn get_q_value(&self, game: &Game, action: &Action) -> f64 {
self.get_features(game, action)
.iter()
.map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0))
.map(|(key, val)| val * *self.weights.get(key).unwrap())
.sum()
}
fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
*legal_actions
fn get_action_from_q_values(&self, game: &Game) -> Action {
let legal_actions = game.get_legal_actions();
let legal_actions = legal_actions
.iter()
.map(|action| (action, self.get_q_value(game, action)))
.collect::<Vec<_>>();
let max_val = legal_actions
.iter()
.map(|action| (action, self.get_q_value(&state, action, state)))
.max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
.expect("Failed to select an action")
.1;
let actions_to_choose = legal_actions
.iter()
.filter(|(_, v)| max_val == *v)
.collect::<Vec<_>>();
if actions_to_choose.len() != 1 {
trace!(
"more than one best option, choosing randomly: {:?}",
actions_to_choose
);
}
*actions_to_choose
.choose(&mut SmallRng::from_entropy())
.unwrap()
.0
}
fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
legal_actions
fn get_value(&self, game: &Game) -> f64 {
game.get_legal_actions()
.iter()
.map(|action| self.get_q_value(state, action, state))
.map(|action| self.get_q_value(game, action))
.max_by_key(|v| (v * 1_000_000.0) as isize)
.unwrap_or_else(|| 0.0)
}
}
#[cfg(test)]
mod aaa {
use super::*;
use crate::tetromino::TetrominoType;
#[test]
fn test_height() {
let agent = ApproximateQLearning::default();
let matrix = vec![
vec![None, None, None],
vec![None, Some(TetrominoType::T), None],
vec![None, None, None],
];
assert_eq!(agent.get_heights(&matrix), vec![0, 2, 0]);
}
}
impl Actor for ApproximateQLearning {
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
if rng.gen::<f64>() < self.exploration_prob {
*legal_actions.choose(rng).unwrap()
} else {
self.get_action_from_q_values(&game.into(), legal_actions)
self.get_action_from_q_values(game)
}
}
@ -241,21 +322,19 @@ impl Actor for ApproximateQLearning {
impl QLearningActor for ApproximateQLearning {
fn update(
&mut self,
state: State,
game_state: Game,
action: Action,
next_state: State,
next_game_state: Game,
next_legal_actions: &[Action],
reward: f64,
) {
let difference = reward
+ self.discount_rate * self.get_value(&next_state, next_legal_actions)
- self.get_q_value(&state, &action, &next_state);
let difference = reward + self.discount_rate * self.get_value(&next_game_state)
- self.get_q_value(&game_state, &action);
for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
for (feat_key, feat_val) in self.get_features(&game_state, &action) {
self.weights.insert(
feat_key.clone(),
*self.weights.get(&feat_key).unwrap_or_else(|| &0.0)
+ self.learning_rate * difference * feat_val,
*self.weights.get(&feat_key).unwrap() + self.learning_rate * difference * feat_val,
);
}
}
@ -303,11 +382,11 @@ pub fn train_actor<T: 'static + QLearningActor + Actor>(
super::apply_action_to_game(action, &mut game);
let new_state = (&game).into();
let new_state = game.clone();
let mut reward = game.score() as f64 - cur_score as f64;
if action != Action::Nothing {
reward -= 0.0;
}
// if action != Action::Nothing {
// reward -= 0.0;
// }
if game.is_game_over().is_some() {
reward = -1.0;

View file

@ -17,6 +17,7 @@ pub enum LossReason {
TopOut,
LockOut,
PieceLimitReached,
TickLimitReached,
BlockOut(Position),
}
@ -39,6 +40,9 @@ pub struct Game {
// used if we set a limit on how long a game can last.
pieces_placed: usize,
piece_limit: usize,
// used if we set a limit on how long the game can be played.
tick_limit: u64,
}
impl fmt::Debug for Game {
@ -65,6 +69,7 @@ impl Default for Game {
line_clears: 0,
pieces_placed: 0,
piece_limit: 0,
tick_limit: 0,
}
}
}
@ -79,6 +84,7 @@ impl Tickable for Game {
return;
}
self.tick += 1;
match self.tick {
t if t == self.next_spawn_tick => {
trace!("Spawn tick was met, spawning new Tetromino!");
@ -105,6 +111,10 @@ impl Tickable for Game {
}
_ => (),
}
if self.tick == self.tick_limit {
self.is_game_over = Some(LossReason::TickLimitReached);
}
}
}
@ -217,6 +227,10 @@ impl Game {
pub fn playfield(&self) -> &PlayField {
&self.playfield
}
pub fn set_time_limit(&mut self, duration: std::time::Duration) {
self.tick_limit = duration.as_secs() * TICKS_PER_SECOND as u64;
}
}
pub trait Controllable {
@ -230,7 +244,7 @@ pub trait Controllable {
fn get_legal_actions(&self) -> Vec<Action>;
}
#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)]
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum Action {
Nothing, // Default value
MoveLeft,
@ -248,8 +262,7 @@ impl Controllable for Game {
return;
}
self.playfield.move_offset(-1, 0);
if !self.playfield.can_active_piece_move_down() {
if self.playfield.move_offset(-1, 0) && !self.playfield.can_active_piece_move_down() {
self.update_lock_tick();
}
}
@ -259,8 +272,7 @@ impl Controllable for Game {
return;
}
self.playfield.move_offset(1, 0);
if !self.playfield.can_active_piece_move_down() {
if self.playfield.move_offset(1, 0) && !self.playfield.can_active_piece_move_down() {
self.update_lock_tick();
}
}
@ -288,7 +300,7 @@ impl Controllable for Game {
active_piece.position = active_piece.position.offset(x, y);
active_piece.rotate_left();
self.playfield.active_piece = Some(active_piece);
// self.update_lock_tick();
self.update_lock_tick();
}
Err(_) => (),
}
@ -305,7 +317,7 @@ impl Controllable for Game {
active_piece.position = active_piece.position.offset(x, y);
active_piece.rotate_right();
self.playfield.active_piece = Some(active_piece);
// self.update_lock_tick();
self.update_lock_tick();
}
Err(_) => (),
}
@ -351,13 +363,13 @@ impl Controllable for Game {
fn get_legal_actions(&self) -> Vec<Action> {
let mut legal_actions = vec![
Action::RotateLeft,
Action::RotateRight,
Action::SoftDrop,
Action::HardDrop,
Action::Nothing,
Action::MoveLeft,
Action::MoveRight,
Action::SoftDrop,
Action::HardDrop,
Action::RotateLeft,
Action::RotateRight,
];
if self.playfield.can_swap() {

View file

@ -4,7 +4,6 @@ use cli::*;
use game::{Action, Controllable, Game, Tickable};
use graphics::standard_renderer;
use graphics::COLOR_BACKGROUND;
use indicatif::ProgressIterator;
use log::{debug, info, trace};
use rand::SeedableRng;
use sdl2::event::Event;
@ -23,7 +22,7 @@ mod tetromino;
const TICKS_PER_SECOND: usize = 60;
#[tokio::main]
#[tokio::main(core_threads = 16)]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let opts = crate::cli::Opts::parse();
@ -36,9 +35,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
qlearning::train_actor(qlearning::QLearningAgent::default(), &sub_opts)
}
Agent::ApproximateQLearning => {
qlearning::train_actor(qlearning::ApproximateQLearning::default(), &sub_opts)
let agent =
qlearning::train_actor(qlearning::ApproximateQLearning::default(), &sub_opts);
agent.dbg();
agent
}
Agent::HeuristicGenetic => {
let agent = genetic::train_actor(&sub_opts).await;
agent.dbg();
agent
}
Agent::HeuristicGenetic => genetic::train_actor(&sub_opts),
}),
};
@ -55,125 +61,127 @@ async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std:
.build()?;
let mut canvas = window.into_canvas().build()?;
let mut event_pump = sdl_context.event_pump()?;
let mut game = Game::default();
let mut interval = interval(Duration::from_millis(1000 / TICKS_PER_SECOND as u64));
'running: loop {
match game.is_game_over() {
Some(e) => {
println!("Lost due to: {:?}", e);
break;
'escape: loop {
let mut game = Game::default();
loop {
match game.is_game_over() {
Some(e) => {
println!("Lost due to: {:?}", e);
break;
}
None => (),
}
None => (),
let cur_state = game.clone();
// If there's an actor, the player action will get overridden. If not,
// then then the player action falls through, if there is one. This is
// to allow for restarting and quitting the game from the GUI.
let mut action = None;
for event in event_pump.poll_iter() {
match event {
Event::Quit { .. }
| Event::KeyDown {
keycode: Some(Keycode::Escape),
..
} => {
debug!("Escape registered");
break 'escape Ok(());
}
Event::KeyDown {
keycode: Some(Keycode::Left),
..
} => {
debug!("Move left registered");
action = Some(Action::MoveLeft);
}
Event::KeyDown {
keycode: Some(Keycode::Right),
..
} => {
debug!("Move right registered");
action = Some(Action::MoveRight);
}
Event::KeyDown {
keycode: Some(Keycode::Down),
..
} => {
debug!("Soft drop registered");
action = Some(Action::SoftDrop);
}
Event::KeyDown {
keycode: Some(Keycode::Z),
..
} => {
debug!("Rotate left registered");
action = Some(Action::RotateLeft);
}
Event::KeyDown {
keycode: Some(Keycode::X),
..
} => {
debug!("Rotate right registered");
action = Some(Action::RotateRight);
}
Event::KeyDown {
keycode: Some(Keycode::Space),
..
}
| Event::KeyDown {
keycode: Some(Keycode::Up),
..
} => {
debug!("Hard drop registered");
action = Some(Action::HardDrop);
}
Event::KeyDown {
keycode: Some(Keycode::LShift),
..
} => {
debug!("Hold registered");
action = Some(Action::Hold);
}
Event::KeyDown {
keycode: Some(Keycode::R),
..
} => {
info!("Restarting game");
game = Game::default();
}
Event::KeyDown {
keycode: Some(e), ..
} => trace!("Ignoring keycode {}", e),
_ => (),
}
}
actor.as_mut().map(|actor| {
action =
Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
});
action.map(|action| match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
});
game.tick();
canvas.set_draw_color(COLOR_BACKGROUND);
canvas.clear();
standard_renderer::render(&mut canvas, &game);
canvas.present();
interval.tick().await;
}
let cur_state = game.clone();
// If there's an actor, the player action will get overridden. If not,
// then then the player action falls through, if there is one. This is
// to allow for restarting and quitting the game from the GUI.
let mut action = None;
for event in event_pump.poll_iter() {
match event {
Event::Quit { .. }
| Event::KeyDown {
keycode: Some(Keycode::Escape),
..
} => {
debug!("Escape registered");
break 'running;
}
Event::KeyDown {
keycode: Some(Keycode::Left),
..
} => {
debug!("Move left registered");
action = Some(Action::MoveLeft);
}
Event::KeyDown {
keycode: Some(Keycode::Right),
..
} => {
debug!("Move right registered");
action = Some(Action::MoveRight);
}
Event::KeyDown {
keycode: Some(Keycode::Down),
..
} => {
debug!("Soft drop registered");
action = Some(Action::SoftDrop);
}
Event::KeyDown {
keycode: Some(Keycode::Z),
..
} => {
debug!("Rotate left registered");
action = Some(Action::RotateLeft);
}
Event::KeyDown {
keycode: Some(Keycode::X),
..
} => {
debug!("Rotate right registered");
action = Some(Action::RotateRight);
}
Event::KeyDown {
keycode: Some(Keycode::Space),
..
}
| Event::KeyDown {
keycode: Some(Keycode::Up),
..
} => {
debug!("Hard drop registered");
action = Some(Action::HardDrop);
}
Event::KeyDown {
keycode: Some(Keycode::LShift),
..
} => {
debug!("Hold registered");
action = Some(Action::Hold);
}
Event::KeyDown {
keycode: Some(Keycode::R),
..
} => {
info!("Restarting game");
game = Game::default();
}
Event::KeyDown {
keycode: Some(e), ..
} => trace!("Ignoring keycode {}", e),
_ => (),
}
}
actor.as_mut().map(|actor| {
action = Some(actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())));
});
action.map(|action| match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
});
game.tick();
canvas.set_draw_color(COLOR_BACKGROUND);
canvas.clear();
standard_renderer::render(&mut canvas, &game);
canvas.present();
interval.tick().await;
info!("Final score: {}", game.score());
}
info!("Final score: {}", game.score());
actor.map(|a| a.dbg());
Ok(())
}