genetic algo base

parent f3b48fbc85
commit a65f48f585
6 changed files with 157 additions and 55 deletions
src/actors/genetic.rs  (new file, 71 lines)
@@ -0,0 +1,71 @@
// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/

use super::Actor;
use rand::rngs::SmallRng;
use rand::Rng;

pub struct Parameters {
    total_height: f64,
    bumpiness: f64,
    holes: f64,
    complete_lines: f64,
}

impl Parameters {
    fn mutate(&mut self, rng: &mut SmallRng) {
        let mutation_amt = rng.gen_range(-0.2, 0.2);
        match rng.gen_range(0, 4) {
            0 => self.total_height += mutation_amt,
            1 => self.bumpiness += mutation_amt,
            2 => self.holes += mutation_amt,
            3 => self.complete_lines += mutation_amt,
            _ => unreachable!(),
        }

        let normalization_factor = (self.total_height.powi(2)
            + self.bumpiness.powi(2)
            + self.holes.powi(2)
            + self.complete_lines.powi(2))
        .sqrt();

        self.total_height /= normalization_factor;
        self.bumpiness /= normalization_factor;
        self.holes /= normalization_factor;
        self.complete_lines /= normalization_factor;
    }
}

pub struct GeneticHeuristicAgent {}

impl Actor for GeneticHeuristicAgent {
    fn get_action(
        &self,
        rng: &mut SmallRng,
        state: &super::State,
        legal_actions: &[crate::game::Action],
    ) -> crate::game::Action {
        unimplemented!()
    }

    fn update(
        &mut self,
        state: super::State,
        action: crate::game::Action,
        next_state: super::State,
        next_legal_actions: &[crate::game::Action],
        reward: f64,
    ) {
        unimplemented!()
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        unimplemented!()
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        unimplemented!()
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        unimplemented!()
    }

    fn dbg(&self) {
        unimplemented!()
    }
}
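The four weights in Parameters mirror the linear board-evaluation heuristic from the linked article: a candidate placement is scored as a weighted sum of aggregate height, bumpiness, holes, and completed lines. A minimal sketch of how the weights could be applied follows; the BoardFeatures struct and evaluate method are illustrative assumptions and are not part of this commit.

// Illustrative only: a feature summary of the playfield after a candidate placement.
// None of this is in the commit; the field meanings follow the linked article.
pub struct BoardFeatures {
    pub total_height: f64,   // sum of every column's height
    pub bumpiness: f64,      // sum of |height difference| between adjacent columns
    pub holes: f64,          // empty cells with at least one filled cell above them
    pub complete_lines: f64, // rows the placement would clear
}

impl Parameters {
    // Linear evaluation: with negative weights on height, bumpiness, and holes,
    // a higher score means a better placement.
    pub fn evaluate(&self, f: &BoardFeatures) -> f64 {
        self.total_height * f.total_height
            + self.bumpiness * f.bumpiness
            + self.holes * f.holes
            + self.complete_lines * f.complete_lines
    }
}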
@@ -1,8 +1,9 @@
use crate::game::{Action, Game};
use crate::game::{Action, Controllable, Game};
use crate::playfield::{Matrix, PlayField};
use crate::tetromino::{Tetromino, TetrominoType};
use rand::rngs::SmallRng;

pub mod genetic;
pub mod qlearning;

#[derive(Hash, PartialEq, Eq, Clone, Debug)]

@@ -56,3 +57,26 @@ pub trait Actor {

    fn dbg(&self);
}

pub trait Predictable {
    fn get_next_state(&self, action: Action) -> Self;
}

impl Predictable for Game {
    /// Expensive, performs a full clone.
    fn get_next_state(&self, action: Action) -> Self {
        let mut game = self.clone();
        match action {
            Action::Nothing => (),
            Action::MoveLeft => game.move_left(),
            Action::MoveRight => game.move_right(),
            Action::SoftDrop => game.move_down(),
            Action::HardDrop => game.hard_drop(),
            Action::Hold => game.hold(),
            Action::RotateLeft => game.rotate_left(),
            Action::RotateRight => game.rotate_right(),
        };

        game
    }
}
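Predictable::get_next_state is what lets a heuristic agent do one-step lookahead: clone the game, apply a candidate action, and score the resulting board. A sketch of that loop is below; it relies on the evaluate/BoardFeatures sketch above, a hypothetical features_of(&Game) -> BoardFeatures extractor, and the assumption that Action implements Copy. None of these helpers exist in this commit.

// Sketch: pick the legal action whose successor state scores best under the weights.
// Assumes Action: Copy and a hypothetical features_of(&Game) -> BoardFeatures helper.
fn best_action(game: &Game, params: &Parameters, legal_actions: &[Action]) -> Option<Action> {
    let mut best: Option<(Action, f64)> = None;
    for &action in legal_actions {
        let next = game.get_next_state(action); // full clone per candidate, hence "expensive"
        let score = params.evaluate(&features_of(&next));
        if best.as_ref().map_or(true, |(_, s)| score > *s) {
            best = Some((action, score));
        }
    }
    best.map(|(action, _)| action)
}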
@@ -1,12 +1,14 @@
use crate::actors::{Actor, State};
use crate::{
    game::Action,
    game::{Action, Controllable, Game, Tickable},
    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use log::debug;
use indicatif::ProgressIterator;
use log::{debug, info};
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use std::collections::HashMap;

pub struct QLearningAgent {

@@ -250,3 +252,53 @@ impl Actor for ApproximateQLearning {
        dbg!(&self.weights);
    }
}

pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
    let mut rng = SmallRng::from_entropy();
    let mut avg = 0.0;

    for i in (0..episodes).progress() {
        if i != 0 && i % (episodes / 10) == 0 {
            info!("Last {} scores avg: {}", (episodes / 10), avg);
            println!();
            avg = 0.0;
        }
        let mut game = Game::default();
        while (&game).is_game_over().is_none() {
            let cur_state = (&game).into();
            let cur_score = game.score();
            let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));

            match action {
                Action::Nothing => (),
                Action::MoveLeft => game.move_left(),
                Action::MoveRight => game.move_right(),
                Action::SoftDrop => game.move_down(),
                Action::HardDrop => game.hard_drop(),
                Action::Hold => game.hold(),
                Action::RotateLeft => game.rotate_left(),
                Action::RotateRight => game.rotate_right(),
            }

            let new_state = (&game).into();
            let mut reward = game.score() as f64 - cur_score as f64;
            if action != Action::Nothing {
                reward -= 0.0;
            }

            if game.is_game_over().is_some() {
                reward = -1.0;
            }

            let new_legal_actions = game.get_legal_actions();

            actor.update(cur_state, action, new_state, &new_legal_actions, reward);

            game.tick();
        }

        avg += game.score() as f64 / (episodes / 10) as f64;
    }

    actor
}
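train_actor now lives in the qlearning module and src/main.rs imports it. A hypothetical call site, for illustration only; the ApproximateQLearning::default() constructor is an assumption, since only the type name appears in this diff.

// Hypothetical call site inside `async fn main` (see src/main.rs): train for
// 10_000 episodes, then hand the trained actor to the game loop.
// ApproximateQLearning::default() is assumed; it is not shown in this commit.
let trained: Box<dyn Actor> = train_actor(10_000, Box::new(ApproximateQLearning::default()));
play_game(Some(trained)).await?;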
@@ -18,6 +18,8 @@ pub enum LossReason {
    LockOut,
    BlockOut(Position),
}

#[derive(Clone)]
// Logic is based on 60 ticks / second
pub struct Game {
    playfield: PlayField,

@@ -100,6 +102,7 @@ impl Tickable for Game {
    }
}

#[derive(Clone, Copy)]
enum ClearAction {
    Single,
    Double,

@@ -266,7 +269,7 @@ impl Controllable for Game {
                active_piece.position = active_piece.position.offset(x, y);
                active_piece.rotate_left();
                self.playfield.active_piece = Some(active_piece);
                self.update_lock_tick();
                // self.update_lock_tick();
            }
            Err(_) => (),
        }

@@ -283,7 +286,7 @@ impl Controllable for Game {
                active_piece.position = active_piece.position.offset(x, y);
                active_piece.rotate_right();
                self.playfield.active_piece = Some(active_piece);
                self.update_lock_tick();
                // self.update_lock_tick();
            }
            Err(_) => (),
        }
src/main.rs  (51 lines changed)

@@ -6,6 +6,7 @@ use graphics::standard_renderer;
use graphics::COLOR_BACKGROUND;
use indicatif::ProgressIterator;
use log::{debug, info, trace};
use qlearning::train_actor;
use rand::SeedableRng;
use sdl2::event::Event;
use sdl2::keyboard::Keycode;

@@ -61,56 +62,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}

fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
    let mut rng = rand::rngs::SmallRng::from_entropy();
    let mut avg = 0.0;

    for i in (0..episodes).progress() {
        if i != 0 && i % (episodes / 10) == 0 {
            info!("Last {} scores avg: {}", (episodes / 10), avg);
            println!();
            avg = 0.0;
        }
        let mut game = Game::default();
        while (&game).is_game_over().is_none() {
            let cur_state = (&game).into();
            let cur_score = game.score();
            let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));

            match action {
                Action::Nothing => (),
                Action::MoveLeft => game.move_left(),
                Action::MoveRight => game.move_right(),
                Action::SoftDrop => game.move_down(),
                Action::HardDrop => game.hard_drop(),
                Action::Hold => game.hold(),
                Action::RotateLeft => game.rotate_left(),
                Action::RotateRight => game.rotate_right(),
            }

            let new_state = (&game).into();
            let mut reward = game.score() as f64 - cur_score as f64;
            if action != Action::Nothing {
                reward -= 0.0;
            }

            if game.is_game_over().is_some() {
                reward = -1.0;
            }

            let new_legal_actions = game.get_legal_actions();

            actor.update(cur_state, action, new_state, &new_legal_actions, reward);

            game.tick();
        }

        avg += game.score() as f64 / (episodes / 10) as f64;
    }

    actor
}

async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
    let mut rng = rand::rngs::SmallRng::from_entropy();
    let sdl_context = sdl2::init()?;
@@ -10,6 +10,7 @@ pub trait RotationSystem {
    fn rotate_right(&self, playfield: &PlayField) -> Result<Position, ()>;
}

#[derive(Clone)]
pub struct SRS {
    jlstz_offset_data: HashMap<RotationState, Vec<Position>>,
    i_offset_data: HashMap<RotationState, Vec<Position>>,