genetic algo base

master
Edward Shen 2020-04-05 19:36:00 -04:00
parent f3b48fbc85
commit a65f48f585
Signed by: edward
GPG Key ID: 19182661E818369F
6 changed files with 157 additions and 55 deletions

src/actors/genetic.rs Normal file

@@ -0,0 +1,71 @@
// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
use super::Actor;
use rand::rngs::SmallRng;
use rand::Rng;

pub struct Parameters {
    total_height: f64,
    bumpiness: f64,
    holes: f64,
    complete_lines: f64,
}

impl Parameters {
    // Takes &mut self so the mutation actually persists; by-value `mut self`
    // would mutate and then drop a copy.
    fn mutate(&mut self, rng: &mut SmallRng) {
        // Perturb one randomly chosen weight by an amount in [-0.2, 0.2).
        let mutation_amt = rng.gen_range(-0.2, 0.2);
        match rng.gen_range(0, 4) {
            0 => self.total_height += mutation_amt,
            1 => self.bumpiness += mutation_amt,
            2 => self.holes += mutation_amt,
            3 => self.complete_lines += mutation_amt,
            _ => unreachable!(),
        }

        // Renormalize to a unit vector so only the direction of the weight
        // vector matters, not its magnitude.
        let normalization_factor = (self.total_height.powi(2)
            + self.bumpiness.powi(2)
            + self.holes.powi(2)
            + self.complete_lines.powi(2))
        .sqrt();
        self.total_height /= normalization_factor;
        self.bumpiness /= normalization_factor;
        self.holes /= normalization_factor;
        self.complete_lines /= normalization_factor;
    }
}

pub struct GeneticHeuristicAgent {}

impl Actor for GeneticHeuristicAgent {
    fn get_action(
        &self,
        rng: &mut SmallRng,
        state: &super::State,
        legal_actions: &[crate::game::Action],
    ) -> crate::game::Action {
        unimplemented!()
    }

    fn update(
        &mut self,
        state: super::State,
        action: crate::game::Action,
        next_state: super::State,
        next_legal_actions: &[crate::game::Action],
        reward: f64,
    ) {
        unimplemented!()
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        unimplemented!()
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        unimplemented!()
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        unimplemented!()
    }

    fn dbg(&self) {
        unimplemented!()
    }
}
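The four weights mirror the linked article's board evaluation, which scores a candidate placement as a weighted sum of aggregate height, complete lines, holes, and bumpiness. As a minimal sketch (not part of this commit), assuming the four feature values have already been extracted from the playfield, scoring would look like:

    impl Parameters {
        // Hypothetical heuristic from the article: a weighted sum of the four
        // board features. Weights may evolve negative, so "penalty" features
        // like holes need no special casing.
        fn score(&self, total_height: f64, bumpiness: f64, holes: f64, complete_lines: f64) -> f64 {
            self.total_height * total_height
                + self.bumpiness * bumpiness
                + self.holes * holes
                + self.complete_lines * complete_lines
        }
    }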

src/actors/mod.rs

@@ -1,8 +1,9 @@
-use crate::game::{Action, Game};
+use crate::game::{Action, Controllable, Game};
 use crate::playfield::{Matrix, PlayField};
 use crate::tetromino::{Tetromino, TetrominoType};
 use rand::rngs::SmallRng;
+pub mod genetic;
 pub mod qlearning;
 #[derive(Hash, PartialEq, Eq, Clone, Debug)]
@@ -56,3 +57,26 @@ pub trait Actor {
     fn dbg(&self);
 }
+
+pub trait Predictable {
+    fn get_next_state(&self, action: Action) -> Self;
+}
+
+impl Predictable for Game {
+    /// Expensive, performs a full clone.
+    fn get_next_state(&self, action: Action) -> Self {
+        let mut game = self.clone();
+        match action {
+            Action::Nothing => (),
+            Action::MoveLeft => game.move_left(),
+            Action::MoveRight => game.move_right(),
+            Action::SoftDrop => game.move_down(),
+            Action::HardDrop => game.hard_drop(),
+            Action::Hold => game.hold(),
+            Action::RotateLeft => game.rotate_left(),
+            Action::RotateRight => game.rotate_right(),
+        };
+        game
+    }
+}
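The new Predictable trait gives agents one-step lookahead by cloning the whole game. A sketch of how a heuristic agent might use it (hypothetical: `score` is an assumed evaluation function, and `Action: Copy` is inferred from its by-value use in train_actor):

    use crate::actors::Predictable;
    use crate::game::{Action, Game};

    // Greedy one-step lookahead: clone the game once per legal action and
    // keep the action whose successor state evaluates best.
    fn best_action(game: &Game, legal_actions: &[Action], score: impl Fn(&Game) -> f64) -> Action {
        legal_actions
            .iter()
            .copied()
            .max_by(|&a, &b| {
                score(&game.get_next_state(a))
                    .partial_cmp(&score(&game.get_next_state(b)))
                    .expect("heuristic scores should not be NaN")
            })
            .unwrap_or(Action::Nothing)
    }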

src/actors/qlearning.rs

@@ -1,12 +1,14 @@
 use crate::actors::{Actor, State};
 use crate::{
-    game::Action,
+    game::{Action, Controllable, Game, Tickable},
     playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
 };
-use log::debug;
+use indicatif::ProgressIterator;
+use log::{debug, info};
 use rand::rngs::SmallRng;
 use rand::seq::SliceRandom;
 use rand::Rng;
+use rand::SeedableRng;
 use std::collections::HashMap;
 pub struct QLearningAgent {
@@ -250,3 +252,53 @@ impl Actor for ApproximateQLearning {
         dbg!(&self.weights);
     }
 }
+
+pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
+    let mut rng = SmallRng::from_entropy();
+    let mut avg = 0.0;
+    for i in (0..episodes).progress() {
+        if i != 0 && i % (episodes / 10) == 0 {
+            info!("Last {} scores avg: {}", (episodes / 10), avg);
+            println!();
+            avg = 0.0;
+        }
+        let mut game = Game::default();
+        while (&game).is_game_over().is_none() {
+            let cur_state = (&game).into();
+            let cur_score = game.score();
+            let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
+            match action {
+                Action::Nothing => (),
+                Action::MoveLeft => game.move_left(),
+                Action::MoveRight => game.move_right(),
+                Action::SoftDrop => game.move_down(),
+                Action::HardDrop => game.hard_drop(),
+                Action::Hold => game.hold(),
+                Action::RotateLeft => game.rotate_left(),
+                Action::RotateRight => game.rotate_right(),
+            }
+            let new_state = (&game).into();
+            let mut reward = game.score() as f64 - cur_score as f64;
+            if action != Action::Nothing {
+                reward -= 0.0;
+            }
+            if game.is_game_over().is_some() {
+                reward = -1.0;
+            }
+            let new_legal_actions = game.get_legal_actions();
+            actor.update(cur_state, action, new_state, &new_legal_actions, reward);
+            game.tick();
+        }
+        avg += game.score() as f64 / (episodes / 10) as f64;
+    }
+    actor
+}
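train_actor moves here unchanged from main.rs (deleted below). Two details worth noting: `reward -= 0.0` is currently a no-op, presumably a stub for a per-move penalty, and losing overwrites the reward with -1.0. A hypothetical call site, assuming `QLearningAgent` has a `Default` impl and picking an arbitrary episode budget:

    // Hypothetical driver; the agent constructor and the 10_000-episode
    // budget are assumptions, not part of this commit.
    fn train_default_agent() -> Box<dyn Actor> {
        train_actor(10_000, Box::new(QLearningAgent::default()))
    }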

src/game.rs

@@ -18,6 +18,8 @@ pub enum LossReason {
     LockOut,
     BlockOut(Position),
 }
+
+#[derive(Clone)]
 // Logic is based on 60 ticks / second
 pub struct Game {
     playfield: PlayField,
@@ -100,6 +102,7 @@ impl Tickable for Game {
     }
 }
+#[derive(Clone, Copy)]
 enum ClearAction {
     Single,
     Double,
@@ -266,7 +269,7 @@ impl Controllable for Game {
                 active_piece.position = active_piece.position.offset(x, y);
                 active_piece.rotate_left();
                 self.playfield.active_piece = Some(active_piece);
-                self.update_lock_tick();
+                // self.update_lock_tick();
             }
             Err(_) => (),
         }
@@ -283,7 +286,7 @@ impl Controllable for Game {
                 active_piece.position = active_piece.position.offset(x, y);
                 active_piece.rotate_right();
                 self.playfield.active_piece = Some(active_piece);
-                self.update_lock_tick();
+                // self.update_lock_tick();
             }
             Err(_) => (),
         }

src/main.rs

@@ -6,6 +6,7 @@ use graphics::standard_renderer;
 use graphics::COLOR_BACKGROUND;
 use indicatif::ProgressIterator;
 use log::{debug, info, trace};
+use qlearning::train_actor;
 use rand::SeedableRng;
 use sdl2::event::Event;
 use sdl2::keyboard::Keycode;
@@ -61,56 +62,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     Ok(())
 }
-
-fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
-    let mut rng = rand::rngs::SmallRng::from_entropy();
-    let mut avg = 0.0;
-    for i in (0..episodes).progress() {
-        if i != 0 && i % (episodes / 10) == 0 {
-            info!("Last {} scores avg: {}", (episodes / 10), avg);
-            println!();
-            avg = 0.0;
-        }
-        let mut game = Game::default();
-        while (&game).is_game_over().is_none() {
-            let cur_state = (&game).into();
-            let cur_score = game.score();
-            let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
-            match action {
-                Action::Nothing => (),
-                Action::MoveLeft => game.move_left(),
-                Action::MoveRight => game.move_right(),
-                Action::SoftDrop => game.move_down(),
-                Action::HardDrop => game.hard_drop(),
-                Action::Hold => game.hold(),
-                Action::RotateLeft => game.rotate_left(),
-                Action::RotateRight => game.rotate_right(),
-            }
-            let new_state = (&game).into();
-            let mut reward = game.score() as f64 - cur_score as f64;
-            if action != Action::Nothing {
-                reward -= 0.0;
-            }
-            if game.is_game_over().is_some() {
-                reward = -1.0;
-            }
-            let new_legal_actions = game.get_legal_actions();
-            actor.update(cur_state, action, new_state, &new_legal_actions, reward);
-            game.tick();
-        }
-        avg += game.score() as f64 / (episodes / 10) as f64;
-    }
-    actor
-}
 async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
     let mut rng = rand::rngs::SmallRng::from_entropy();
     let sdl_context = sdl2::init()?;

(rotation system module; filename not captured)

@@ -10,6 +10,7 @@ pub trait RotationSystem {
     fn rotate_right(&self, playfield: &PlayField) -> Result<Position, ()>;
 }
+#[derive(Clone)]
 pub struct SRS {
     jlstz_offset_data: HashMap<RotationState, Vec<Position>>,
     i_offset_data: HashMap<RotationState, Vec<Position>>,