From a65f48f5856d58803cee67369c0e63cf2b7623f9 Mon Sep 17 00:00:00 2001 From: Edward Shen Date: Sun, 5 Apr 2020 19:36:00 -0400 Subject: [PATCH] genetic algo base --- src/actors/genetic.rs | 71 +++++++++++++++++++++++++++++++++++++++++ src/actors/mod.rs | 26 ++++++++++++++- src/actors/qlearning.rs | 56 ++++++++++++++++++++++++++++++-- src/game.rs | 7 ++-- src/main.rs | 51 +---------------------------- src/srs.rs | 1 + 6 files changed, 157 insertions(+), 55 deletions(-) create mode 100644 src/actors/genetic.rs diff --git a/src/actors/genetic.rs b/src/actors/genetic.rs new file mode 100644 index 0000000..2a5965a --- /dev/null +++ b/src/actors/genetic.rs @@ -0,0 +1,71 @@ +// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/ + +use super::Actor; +use rand::rngs::SmallRng; +use rand::Rng; + +pub struct Parameters { + total_height: f64, + bumpiness: f64, + holes: f64, + complete_lines: f64, +} + +impl Parameters { + fn mutate(mut self, rng: &mut SmallRng) { + let mutation_amt = rng.gen_range(-0.2, 0.2); + match rng.gen_range(0, 4) { + 0 => self.total_height += mutation_amt, + 1 => self.bumpiness += mutation_amt, + 2 => self.holes += mutation_amt, + 3 => self.complete_lines += mutation_amt, + _ => unreachable!(), + } + + let normalization_factor = (self.total_height.powi(2) + + self.bumpiness.powi(2) + + self.holes.powi(2) + + self.complete_lines.powi(2)) + .sqrt(); + + self.total_height /= normalization_factor; + self.bumpiness /= normalization_factor; + self.holes /= normalization_factor; + self.complete_lines /= normalization_factor; + } +} + +pub struct GeneticHeuristicAgent {} + +impl Actor for GeneticHeuristicAgent { + fn get_action( + &self, + rng: &mut SmallRng, + state: &super::State, + legal_actions: &[crate::game::Action], + ) -> crate::game::Action { + unimplemented!() + } + fn update( + &mut self, + state: super::State, + action: crate::game::Action, + next_state: super::State, + next_legal_actions: &[crate::game::Action], + reward: f64, + ) { + unimplemented!() + } + fn set_learning_rate(&mut self, learning_rate: f64) { + unimplemented!() + } + fn set_exploration_prob(&mut self, exploration_prob: f64) { + unimplemented!() + } + fn set_discount_rate(&mut self, discount_rate: f64) { + unimplemented!() + } + fn dbg(&self) { + unimplemented!() + } +} diff --git a/src/actors/mod.rs b/src/actors/mod.rs index d25e7b9..867abda 100644 --- a/src/actors/mod.rs +++ b/src/actors/mod.rs @@ -1,8 +1,9 @@ -use crate::game::{Action, Game}; +use crate::game::{Action, Controllable, Game}; use crate::playfield::{Matrix, PlayField}; use crate::tetromino::{Tetromino, TetrominoType}; use rand::rngs::SmallRng; +pub mod genetic; pub mod qlearning; #[derive(Hash, PartialEq, Eq, Clone, Debug)] @@ -56,3 +57,26 @@ pub trait Actor { fn dbg(&self); } + +pub trait Predictable { + fn get_next_state(&self, action: Action) -> Self; +} + +impl Predictable for Game { + /// Expensive, performs a full clone. + fn get_next_state(&self, action: Action) -> Self { + let mut game = self.clone(); + match action { + Action::Nothing => (), + Action::MoveLeft => game.move_left(), + Action::MoveRight => game.move_right(), + Action::SoftDrop => game.move_down(), + Action::HardDrop => game.hard_drop(), + Action::Hold => game.hold(), + Action::RotateLeft => game.rotate_left(), + Action::RotateRight => game.rotate_right(), + }; + + game + } +} diff --git a/src/actors/qlearning.rs b/src/actors/qlearning.rs index 425da80..bc3ccd5 100644 --- a/src/actors/qlearning.rs +++ b/src/actors/qlearning.rs @@ -1,12 +1,14 @@ use crate::actors::{Actor, State}; use crate::{ - game::Action, + game::{Action, Controllable, Game, Tickable}, playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH}, }; -use log::debug; +use indicatif::ProgressIterator; +use log::{debug, info}; use rand::rngs::SmallRng; use rand::seq::SliceRandom; use rand::Rng; +use rand::SeedableRng; use std::collections::HashMap; pub struct QLearningAgent { @@ -250,3 +252,53 @@ impl Actor for ApproximateQLearning { dbg!(&self.weights); } } + +pub fn train_actor(episodes: usize, mut actor: Box) -> Box { + let mut rng = SmallRng::from_entropy(); + let mut avg = 0.0; + + for i in (0..episodes).progress() { + if i != 0 && i % (episodes / 10) == 0 { + info!("Last {} scores avg: {}", (episodes / 10), avg); + println!(); + avg = 0.0; + } + let mut game = Game::default(); + while (&game).is_game_over().is_none() { + let cur_state = (&game).into(); + let cur_score = game.score(); + let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())); + + match action { + Action::Nothing => (), + Action::MoveLeft => game.move_left(), + Action::MoveRight => game.move_right(), + Action::SoftDrop => game.move_down(), + Action::HardDrop => game.hard_drop(), + Action::Hold => game.hold(), + Action::RotateLeft => game.rotate_left(), + Action::RotateRight => game.rotate_right(), + } + + let new_state = (&game).into(); + let mut reward = game.score() as f64 - cur_score as f64; + if action != Action::Nothing { + reward -= 0.0; + } + + if game.is_game_over().is_some() { + reward = -1.0; + } + + let new_legal_actions = game.get_legal_actions(); + + actor.update(cur_state, action, new_state, &new_legal_actions, reward); + + game.tick(); + } + + avg += game.score() as f64 / (episodes / 10) as f64; + } + + actor +} diff --git a/src/game.rs b/src/game.rs index 1610df3..35ae74b 100644 --- a/src/game.rs +++ b/src/game.rs @@ -18,6 +18,8 @@ pub enum LossReason { LockOut, BlockOut(Position), } + +#[derive(Clone)] // Logic is based on 60 ticks / second pub struct Game { playfield: PlayField, @@ -100,6 +102,7 @@ impl Tickable for Game { } } +#[derive(Clone, Copy)] enum ClearAction { Single, Double, @@ -266,7 +269,7 @@ impl Controllable for Game { active_piece.position = active_piece.position.offset(x, y); active_piece.rotate_left(); self.playfield.active_piece = Some(active_piece); - self.update_lock_tick(); + // self.update_lock_tick(); } Err(_) => (), } @@ -283,7 +286,7 @@ impl Controllable for Game { active_piece.position = active_piece.position.offset(x, y); active_piece.rotate_right(); self.playfield.active_piece = Some(active_piece); - self.update_lock_tick(); + // self.update_lock_tick(); } Err(_) => (), } diff --git a/src/main.rs b/src/main.rs index ece01dc..2df9b44 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ use graphics::standard_renderer; use graphics::COLOR_BACKGROUND; use indicatif::ProgressIterator; use log::{debug, info, trace}; +use qlearning::train_actor; use rand::SeedableRng; use sdl2::event::Event; use sdl2::keyboard::Keycode; @@ -61,56 +62,6 @@ async fn main() -> Result<(), Box> { Ok(()) } -fn train_actor(episodes: usize, mut actor: Box) -> Box { - let mut rng = rand::rngs::SmallRng::from_entropy(); - let mut avg = 0.0; - - for i in (0..episodes).progress() { - if i != 0 && i % (episodes / 10) == 0 { - info!("Last {} scores avg: {}", (episodes / 10), avg); - println!(); - avg = 0.0; - } - let mut game = Game::default(); - while (&game).is_game_over().is_none() { - let cur_state = (&game).into(); - let cur_score = game.score(); - let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions())); - - match action { - Action::Nothing => (), - Action::MoveLeft => game.move_left(), - Action::MoveRight => game.move_right(), - Action::SoftDrop => game.move_down(), - Action::HardDrop => game.hard_drop(), - Action::Hold => game.hold(), - Action::RotateLeft => game.rotate_left(), - Action::RotateRight => game.rotate_right(), - } - - let new_state = (&game).into(); - let mut reward = game.score() as f64 - cur_score as f64; - if action != Action::Nothing { - reward -= 0.0; - } - - if game.is_game_over().is_some() { - reward = -1.0; - } - - let new_legal_actions = game.get_legal_actions(); - - actor.update(cur_state, action, new_state, &new_legal_actions, reward); - - game.tick(); - } - - avg += game.score() as f64 / (episodes / 10) as f64; - } - - actor -} - async fn play_game(mut actor: Option>) -> Result<(), Box> { let mut rng = rand::rngs::SmallRng::from_entropy(); let sdl_context = sdl2::init()?; diff --git a/src/srs.rs b/src/srs.rs index 73a2841..892a99e 100644 --- a/src/srs.rs +++ b/src/srs.rs @@ -10,6 +10,7 @@ pub trait RotationSystem { fn rotate_right(&self, playfield: &PlayField) -> Result; } +#[derive(Clone)] pub struct SRS { jlstz_offset_data: HashMap>, i_offset_data: HashMap>,