genetic algo base

Edward Shen 2020-04-05 19:36:00 -04:00
parent f3b48fbc85
commit a65f48f585
Signed by: edward
GPG key ID: 19182661E818369F
6 changed files with 157 additions and 55 deletions

src/actors/genetic.rs (new file, 71 lines added)

@@ -0,0 +1,71 @@
// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
use super::Actor;
use rand::rngs::SmallRng;
use rand::Rng;
pub struct Parameters {
total_height: f64,
bumpiness: f64,
holes: f64,
complete_lines: f64,
}
impl Parameters {
/// Randomly perturbs one weight, then renormalizes the weight vector to unit length.
fn mutate(&mut self, rng: &mut SmallRng) {
let mutation_amt = rng.gen_range(-0.2, 0.2);
match rng.gen_range(0, 4) {
0 => self.total_height += mutation_amt,
1 => self.bumpiness += mutation_amt,
2 => self.holes += mutation_amt,
3 => self.complete_lines += mutation_amt,
_ => unreachable!(),
}
let normalization_factor = (self.total_height.powi(2)
+ self.bumpiness.powi(2)
+ self.holes.powi(2)
+ self.complete_lines.powi(2))
.sqrt();
self.total_height /= normalization_factor;
self.bumpiness /= normalization_factor;
self.holes /= normalization_factor;
self.complete_lines /= normalization_factor;
}
}
pub struct GeneticHeuristicAgent {}
impl Actor for GeneticHeuristicAgent {
fn get_action(
&self,
rng: &mut SmallRng,
state: &super::State,
legal_actions: &[crate::game::Action],
) -> crate::game::Action {
unimplemented!()
}
fn update(
&mut self,
state: super::State,
action: crate::game::Action,
next_state: super::State,
next_legal_actions: &[crate::game::Action],
reward: f64,
) {
unimplemented!()
}
fn set_learning_rate(&mut self, learning_rate: f64) {
unimplemented!()
}
fn set_exploration_prob(&mut self, exploration_prob: f64) {
unimplemented!()
}
fn set_discount_rate(&mut self, discount_rate: f64) {
unimplemented!()
}
fn dbg(&self) {
unimplemented!()
}
}
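For context, the linked article scores each candidate placement as a weighted sum of four board features: aggregate column height, completed lines, holes, and bumpiness. Those are the quantities the Parameters weights above are meant to multiply. A minimal sketch of such an evaluation, assuming the caller has already computed the column heights, hole count, and completed-line count (no helper in this diff produces them yet):

// Sketch only: the feature values are assumed to be computed elsewhere;
// nothing in this commit provides them.
fn evaluate(params: &Parameters, heights: &[usize], holes: usize, complete_lines: usize) -> f64 {
    // Aggregate height: sum of every column's stack height.
    let total_height: usize = heights.iter().sum();
    // Bumpiness: sum of absolute height differences between adjacent columns.
    let bumpiness: i64 = heights
        .windows(2)
        .map(|w| (w[0] as i64 - w[1] as i64).abs())
        .sum();
    params.total_height * total_height as f64
        + params.bumpiness * bumpiness as f64
        + params.holes * holes as f64
        + params.complete_lines * complete_lines as f64
}

In the article's trained parameters, the weights for height, holes, and bumpiness come out negative while completed lines is positive, so the agent simply picks the placement with the highest score.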


@@ -1,8 +1,9 @@
use crate::game::{Action, Game};
use crate::game::{Action, Controllable, Game};
use crate::playfield::{Matrix, PlayField};
use crate::tetromino::{Tetromino, TetrominoType};
use rand::rngs::SmallRng;
pub mod genetic;
pub mod qlearning;
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
@@ -56,3 +57,26 @@ pub trait Actor {
fn dbg(&self);
}
pub trait Predictable {
fn get_next_state(&self, action: Action) -> Self;
}
impl Predictable for Game {
/// Expensive, performs a full clone.
fn get_next_state(&self, action: Action) -> Self {
let mut game = self.clone();
match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
};
game
}
}
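Predictable gives an agent one step of lookahead without touching the live game. A sketch of how a heuristic agent could use it to choose a move, assuming some scoring function evaluate (hypothetical, not part of this commit) and that Action derives Copy:

use crate::actors::Predictable;
use crate::game::{Action, Game};

// Return the legal action whose successor state scores highest under `evaluate`.
// `evaluate` stands in for whatever board-scoring function the agent uses.
fn best_action(game: &Game, legal_actions: &[Action], evaluate: impl Fn(&Game) -> f64) -> Action {
    legal_actions
        .iter()
        .copied()
        .max_by(|&a, &b| {
            let score_a = evaluate(&game.get_next_state(a));
            let score_b = evaluate(&game.get_next_state(b));
            score_a
                .partial_cmp(&score_b)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .unwrap_or(Action::Nothing)
}

Because get_next_state clones the whole Game, this costs one full clone per legal action, which is acceptable for the handful of actions available each tick.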


@@ -1,12 +1,14 @@
use crate::actors::{Actor, State};
use crate::{
game::Action,
game::{Action, Controllable, Game, Tickable},
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use log::debug;
use indicatif::ProgressIterator;
use log::{debug, info};
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use std::collections::HashMap;
pub struct QLearningAgent {
@@ -250,3 +252,53 @@ impl Actor for ApproximateQLearning {
dbg!(&self.weights);
}
}
pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
let mut rng = SmallRng::from_entropy();
let mut avg = 0.0;
for i in (0..episodes).progress() {
if i != 0 && i % (episodes / 10) == 0 {
info!("Last {} scores avg: {}", (episodes / 10), avg);
println!();
avg = 0.0;
}
let mut game = Game::default();
while (&game).is_game_over().is_none() {
let cur_state = (&game).into();
let cur_score = game.score();
let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
}
let new_state = (&game).into();
let mut reward = game.score() as f64 - cur_score as f64;
if action != Action::Nothing {
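// Note: this currently subtracts 0.0, i.e. no per-action penalty is applied yet.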
reward -= 0.0;
}
if game.is_game_over().is_some() {
reward = -1.0;
}
let new_legal_actions = game.get_legal_actions();
actor.update(cur_state, action, new_state, &new_legal_actions, reward);
game.tick();
}
avg += game.score() as f64 / (episodes / 10) as f64;
}
actor
}
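With the training loop living here, main.rs only has to hand train_actor an episode count and a boxed agent. A usage sketch, where the QLearningAgent::default() constructor is assumed purely for illustration (this diff does not show how the agent is built):

use crate::actors::qlearning::{train_actor, QLearningAgent};
use crate::actors::Actor;

fn train_example() -> Box<dyn Actor> {
    // 10_000 episodes is an arbitrary example count.
    // QLearningAgent::default() is assumed here; substitute the real constructor.
    train_actor(10_000, Box::new(QLearningAgent::default()))
}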


@@ -18,6 +18,8 @@ pub enum LossReason {
LockOut,
BlockOut(Position),
}
#[derive(Clone)]
// Logic is based on 60 ticks / second
pub struct Game {
playfield: PlayField,
@@ -100,6 +102,7 @@ impl Tickable for Game {
}
}
#[derive(Clone, Copy)]
enum ClearAction {
Single,
Double,
@@ -266,7 +269,7 @@ impl Controllable for Game {
active_piece.position = active_piece.position.offset(x, y);
active_piece.rotate_left();
self.playfield.active_piece = Some(active_piece);
self.update_lock_tick();
// self.update_lock_tick();
}
Err(_) => (),
}
@@ -283,7 +286,7 @@ impl Controllable for Game {
active_piece.position = active_piece.position.offset(x, y);
active_piece.rotate_right();
self.playfield.active_piece = Some(active_piece);
self.update_lock_tick();
// self.update_lock_tick();
}
Err(_) => (),
}


@@ -6,6 +6,7 @@ use graphics::standard_renderer;
use graphics::COLOR_BACKGROUND;
use indicatif::ProgressIterator;
use log::{debug, info, trace};
use qlearning::train_actor;
use rand::SeedableRng;
use sdl2::event::Event;
use sdl2::keyboard::Keycode;
@@ -61,56 +62,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
let mut rng = rand::rngs::SmallRng::from_entropy();
let mut avg = 0.0;
for i in (0..episodes).progress() {
if i != 0 && i % (episodes / 10) == 0 {
info!("Last {} scores avg: {}", (episodes / 10), avg);
println!();
avg = 0.0;
}
let mut game = Game::default();
while (&game).is_game_over().is_none() {
let cur_state = (&game).into();
let cur_score = game.score();
let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
}
let new_state = (&game).into();
let mut reward = game.score() as f64 - cur_score as f64;
if action != Action::Nothing {
reward -= 0.0;
}
if game.is_game_over().is_some() {
reward = -1.0;
}
let new_legal_actions = game.get_legal_actions();
actor.update(cur_state, action, new_state, &new_legal_actions, reward);
game.tick();
}
avg += game.score() as f64 / (episodes / 10) as f64;
}
actor
}
async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
let mut rng = rand::rngs::SmallRng::from_entropy();
let sdl_context = sdl2::init()?;


@@ -10,6 +10,7 @@ pub trait RotationSystem {
fn rotate_right(&self, playfield: &PlayField) -> Result<Position, ()>;
}
#[derive(Clone)]
pub struct SRS {
jlstz_offset_data: HashMap<RotationState, Vec<Position>>,
i_offset_data: HashMap<RotationState, Vec<Position>>,