tetris/src/actors/genetic.rs

316 lines
10 KiB
Rust

// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
use super::{Actor, Predictable, State};
use crate::{
cli::Train,
game::{Action, Controllable, Game, Tickable},
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use indicatif::ProgressBar;
use log::debug;
use rand::rngs::SmallRng;
use rand::{seq::SliceRandom, Rng, SeedableRng};
/// A six-component vector over the board features used by the genetic agent.
///
/// The type is used in two roles in this file: as the weight vector stored in
/// `GeneticHeuristicAgent::params`, and as the raw feature measurements
/// returned by `GeneticHeuristicAgent::extract_features_from_state`.
/// The field comments below describe the feature each component refers to.
#[derive(Copy, Clone, Debug)]
pub struct Parameters {
/// Sum of the heights of all columns.
total_height: f64,
/// Accumulated absolute height difference between neighbouring columns.
bumpiness: f64,
/// Number of empty cells that have a filled cell directly above them.
holes: f64,
/// Number of rows in which every cell is filled.
complete_lines: f64,
/// Height of the tallest column.
max_height: f64,
/// Depth of the deepest well, i.e. how far a column sits below the lower of
/// its two neighbours.
max_well_depth: f64,
}
impl Default for Parameters {
fn default() -> Self {
Self {
total_height: 1.0,
bumpiness: 1.0,
holes: 1.0,
complete_lines: 1.0,
max_height: 1.0,
max_well_depth: 1.0,
}
}
}
impl Parameters {
    /// Randomly perturbs one component, then rescales the vector to unit length.
    ///
    /// One of the six components is selected with `gen_range(0, 6)` and shifted
    /// by an amount drawn with `gen_range(-0.2, 0.2)`. Afterwards all six
    /// components are divided by the Euclidean norm of the vector.
    ///
    /// The receiver is `&mut self` on purpose: `Parameters` is `Copy`, so a
    /// by-value receiver would only modify a temporary copy and the caller
    /// would never observe the mutation.
    fn mutate(&mut self, rng: &mut SmallRng) {
        let mutation_amt = rng.gen_range(-0.2, 0.2);
        match rng.gen_range(0, 6) {
            0 => self.total_height += mutation_amt,
            1 => self.bumpiness += mutation_amt,
            2 => self.holes += mutation_amt,
            3 => self.complete_lines += mutation_amt,
            4 => self.max_height += mutation_amt,
            5 => self.max_well_depth += mutation_amt,
            _ => unreachable!(),
        }

        // Euclidean norm over *all* six components; the square root applies to
        // the whole sum of squares, not just to the last term.
        let norm = (self.total_height.powi(2)
            + self.bumpiness.powi(2)
            + self.holes.powi(2)
            + self.complete_lines.powi(2)
            + self.max_height.powi(2)
            + self.max_well_depth.powi(2))
        .sqrt();

        // A zero (or NaN) norm cannot be normalised; leave the vector as it is
        // instead of filling it with NaN/inf.
        if norm > 0.0 {
            self.total_height /= norm;
            self.bumpiness /= norm;
            self.holes /= norm;
            self.complete_lines /= norm;
            self.max_height /= norm;
            self.max_well_depth /= norm;
        }
    }

    /// Returns the dot product of `self` and `other` over all six components.
    fn dot_multiply(&self, other: &Self) -> f64 {
        self.total_height * other.total_height
            + self.bumpiness * other.bumpiness
            + self.holes * other.holes
            + self.complete_lines * other.complete_lines
            + self.max_height * other.max_height
            + self.max_well_depth * other.max_well_depth
    }
}
/// Tetris-playing agent that scores candidate actions with a weighted sum of
/// board features; the weights are what the genetic training evolves.
#[derive(Clone, Copy, Debug)]
pub struct GeneticHeuristicAgent {
/// Weight vector multiplied against the features of each candidate next state.
params: Parameters,
}
impl Default for GeneticHeuristicAgent {
fn default() -> Self {
let mut rng = SmallRng::from_entropy();
Self {
params: Parameters {
total_height: rng.gen::<f64>(),
bumpiness: rng.gen::<f64>(),
holes: rng.gen::<f64>(),
complete_lines: rng.gen::<f64>(),
max_height: rng.gen::<f64>(),
max_well_depth: rng.gen::<f64>(),
},
}
}
}
impl GeneticHeuristicAgent {
    /// Measures the six board features of `state` and returns them as a
    /// [`Parameters`] value.
    ///
    /// * `total_height` – sum of all column heights, where a column's height is
    ///   `PLAYFIELD_HEIGHT - r` for the first row `r` holding a filled cell
    ///   (0 for an empty column).
    /// * `bumpiness` – sum of absolute height differences between each pair of
    ///   adjacent columns.
    /// * `complete_lines` – number of rows whose cells are all filled.
    /// * `holes` – number of empty cells with a filled cell directly above.
    ///   Only the cell immediately beneath a filled cell is counted.
    /// * `max_height` – height of the tallest column.
    /// * `max_well_depth` – largest amount by which a column sits below the
    ///   lower of its two neighbours; the playfield walls count as neighbours
    ///   of height 20.
    fn extract_features_from_state(state: &State) -> Parameters {
        // Height attributed to the left/right wall when measuring wells.
        // NOTE(review): hard-coded as in the original; confirm whether this is
        // meant to track PLAYFIELD_HEIGHT.
        const WALL_HEIGHT: usize = 20;

        // Scan every column from row 0 onwards and stop at its first filled cell.
        let mut heights = [0usize; PLAYFIELD_WIDTH];
        for c in 0..PLAYFIELD_WIDTH {
            for r in 0..PLAYFIELD_HEIGHT {
                if state.matrix[r][c].is_some() {
                    heights[c] = PLAYFIELD_HEIGHT - r;
                    break;
                }
            }
        }

        let total_height = heights.iter().sum::<usize>() as f64;

        // Only real neighbour pairs contribute: the first column is not
        // compared against an imaginary zero-height column to its left.
        let bumpiness = heights
            .windows(2)
            .map(|pair| (pair[0] as f64 - pair[1] as f64).abs())
            .sum::<f64>();

        let complete_lines = state
            .matrix
            .iter()
            .filter(|row| row.iter().all(Option::is_some))
            .count() as f64;

        let mut holes = 0usize;
        for r in 1..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
                    holes += 1;
                }
            }
        }

        let max_height = heights.iter().copied().max().unwrap_or(0) as f64;

        let mut max_well_depth = 0usize;
        for i in 0..heights.len() {
            let left = if i == 0 { WALL_HEIGHT } else { heights[i - 1] };
            let right = if i + 1 == heights.len() {
                WALL_HEIGHT
            } else {
                heights[i + 1]
            };
            // A column taller than its lower neighbour is not a well; the
            // saturating subtraction yields 0 there instead of underflowing.
            let depth = left.min(right).saturating_sub(heights[i]);
            max_well_depth = max_well_depth.max(depth);
        }

        Parameters {
            total_height,
            bumpiness,
            complete_lines,
            holes: holes as f64,
            max_height,
            max_well_depth: max_well_depth as f64,
        }
    }

    /// Scores `action` by taking the state `game.get_next_state(*action)`
    /// produces, extracting its features and dotting them with this agent's
    /// weights.
    fn get_heuristic(&self, game: &Game, action: &Action) -> f64 {
        self.params.dot_multiply(&Self::extract_features_from_state(
            &game.get_next_state(*action).into(),
        ))
    }

    /// Produces a child whose weights are the fitness-weighted average of the
    /// two parents' weights.
    ///
    /// Each parent contributes in proportion to its share of
    /// `self_fitness + other_fitness`. When both fitness values are zero there
    /// is nothing to weight by, and a fresh `Self::default()` agent is
    /// returned instead.
    pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self {
        // Summing in f64 cannot overflow, unlike `u32 + u32`.
        let total = f64::from(self_fitness) + f64::from(other_fitness);
        if total > 0.0 {
            let self_weight = f64::from(self_fitness) / total;
            let other_weight = f64::from(other_fitness) / total;
            let blend = |mine: f64, theirs: f64| mine * self_weight + theirs * other_weight;
            // Every component is blended from the *same* component of both parents.
            Self {
                params: Parameters {
                    total_height: blend(self.params.total_height, other.params.total_height),
                    bumpiness: blend(self.params.bumpiness, other.params.bumpiness),
                    holes: blend(self.params.holes, other.params.holes),
                    complete_lines: blend(
                        self.params.complete_lines,
                        other.params.complete_lines,
                    ),
                    max_height: blend(self.params.max_height, other.params.max_height),
                    max_well_depth: blend(
                        self.params.max_well_depth,
                        other.params.max_well_depth,
                    ),
                },
            }
        } else {
            Self::default()
        }
    }

    /// Mutates this agent's weights by delegating to `Parameters::mutate`.
    ///
    /// NOTE(review): `Parameters` is `Copy`, so this only has an effect if
    /// `Parameters::mutate` takes `&mut self`; a by-value receiver would
    /// mutate a temporary copy.
    fn mutate(&mut self, rng: &mut SmallRng) {
        self.params.mutate(rng);
    }
}
impl Actor for GeneticHeuristicAgent {
    /// Returns the legal action with the highest heuristic score, choosing
    /// uniformly at random among actions that tie for the best score.
    ///
    /// Scores are multiplied by 1,000,000 and truncated to a signed integer,
    /// so scores that agree to six decimal places count as a tie. The integer
    /// is signed on purpose: an unsigned cast saturates every negative score to
    /// 0, which would make all negatively scored actions indistinguishable.
    /// Per Rust's float-to-int `as` semantics, a NaN score becomes 0.
    ///
    /// # Panics
    ///
    /// Panics if `legal_actions` is empty.
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        let scored: Vec<(Action, i64)> = legal_actions
            .iter()
            .map(|action| {
                (
                    *action,
                    (self.get_heuristic(game, action) * 1_000_000.0) as i64,
                )
            })
            .collect();

        let best_score = scored
            .iter()
            .map(|&(_, score)| score)
            .max()
            .expect("get_action requires at least one legal action");

        let candidates: Vec<Action> = scored
            .iter()
            .filter(|&&(_, score)| score == best_score)
            .map(|&(action, _)| action)
            .collect();

        *candidates
            .choose(rng)
            .expect("at least one action attains the best score")
    }

    /// Logs the agent's weights at debug level.
    fn dbg(&self) {
        debug!("{:?}", self.params);
    }
}
/// Evolves a population of [`GeneticHeuristicAgent`]s and returns the fittest.
///
/// For each of `opts.episodes` generations:
///
/// 1. Every agent is evaluated in its own spawned tokio task. It plays 10
///    games, each configured with `set_piece_limit(500)` and a 60-second
///    `set_time_limit`, and its fitness is the sum of `line_clears`.
/// 2. Three tournaments are held. Each samples 10% of the population, breeds
///    the two fittest contenders, and contributes 10% of the population size
///    in offspring; each offspring is mutated with probability 0.05.
/// 3. The population is sorted by fitness and its weakest 30% is replaced by
///    the offspring.
///
/// After the last generation the agent with the highest recorded fitness is
/// returned.
///
/// # Panics
///
/// Panics if an evaluation task panics or if the progress-bar mutex is poisoned.
pub async fn train_actor(opts: &Train) -> Box<dyn Actor> {
    use std::sync::{Arc, Mutex};

    const POPULATION_SIZE: usize = 1000;
    const ROUNDS_PER_AGENT: u32 = 10;
    const MUTATION_RATE: f64 = 0.05;

    // Build every individual separately. `vec![value; n]` would clone one
    // randomly initialised agent `n` times and start the search with no
    // genetic diversity at all.
    let mut population: Vec<(GeneticHeuristicAgent, u32)> = (0..POPULATION_SIZE)
        .map(|_| (GeneticHeuristicAgent::default(), 0))
        .collect();
    let total_pb = Arc::new(Mutex::new(ProgressBar::new(
        population.len() as u64 * opts.episodes as u64,
    )));

    for _ in 0..opts.episodes {
        let mut evaluations = Vec::with_capacity(population.len());
        for &(agent, _) in &population {
            let total_pb = total_pb.clone();
            evaluations.push(tokio::task::spawn(async move {
                // Each task owns its generator. Sharing one generator behind a
                // mutex forces every task to wait for the lock on each game
                // step, which serialises the whole evaluation.
                let mut rng = SmallRng::from_entropy();
                let mut fitness: u32 = 0;
                for _ in 0..ROUNDS_PER_AGENT {
                    let mut game = Game::default();
                    game.set_piece_limit(500);
                    game.set_time_limit(std::time::Duration::from_secs(60));
                    while (&game).is_game_over().is_none() {
                        // NOTE(review): `game` is already a `Game`, so the
                        // clone + `into()` looks redundant; confirm and pass
                        // `&game` directly if so.
                        super::apply_action_to_game(
                            agent.get_action(
                                &mut rng,
                                &game.clone().into(),
                                &game.get_legal_actions(),
                            ),
                            &mut game,
                        );
                        game.tick();
                    }
                    fitness += game.line_clears;
                }
                total_pb.lock().expect("progressbar failed").inc(1);
                (agent, fitness)
            }));
        }

        // Surface a failed evaluation right away instead of silently carrying
        // on with a shorter population.
        let mut evaluated: Vec<(GeneticHeuristicAgent, u32)> =
            Vec::with_capacity(population.len());
        for outcome in futures::future::join_all(evaluations).await {
            evaluated.push(outcome.expect("evaluation task panicked"));
        }

        let mut rng = SmallRng::from_entropy();
        let offspring_total = population.len() * 3 / 10;
        let tournament_size = offspring_total / 3;
        let mut offspring: Vec<(GeneticHeuristicAgent, u32)> =
            Vec::with_capacity(offspring_total);
        // NOTE(review): all offspring of one tournament share the same two
        // parents, so they differ only through mutation (or through the random
        // fallback in `breed` when both parents scored 0). Confirm this is
        // intended rather than one tournament per offspring.
        for _ in 0..3 {
            let mut contenders = evaluated
                .choose_multiple(&mut rng, tournament_size)
                .collect::<Vec<_>>();
            // Fittest first.
            contenders.sort_unstable_by(|a, b| b.1.cmp(&a.1));
            let parent1 = contenders[0];
            println!("{:?}", parent1);
            let parent2 = contenders[1];
            println!("{:?}", parent2);
            for _ in 0..tournament_size {
                let mut child = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
                if rng.gen::<f64>() < MUTATION_RATE {
                    child.mutate(&mut rng);
                }
                offspring.push((child, 0));
            }
        }

        // Ascending by fitness, so the leading slice holds the weakest agents;
        // those are the ones the offspring replace.
        evaluated.sort_unstable_by_key(|entry| entry.1);
        evaluated.splice(..offspring.len(), offspring);
        population = evaluated;
    }

    total_pb.lock().unwrap().finish();
    let champion = population
        .iter()
        .max_by_key(|entry| entry.1)
        .expect("population is never empty")
        .0;
    Box::new(champion)
}