// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
use super::{Actor, State};
use crate::{
    cli::Train,
    game::{Action, Controllable, Game},
    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use rand::rngs::SmallRng;
use rand::{seq::SliceRandom, Rng, SeedableRng};

// The four board features from the article linked above, doubling as the
// weight vector an agent learns for them.
#[derive(Copy, Clone, Debug)]
pub struct Parameters {
    total_height: f64,
    bumpiness: f64,
    holes: f64,
    complete_lines: f64,
}

impl Default for Parameters {
    fn default() -> Self {
        Self {
            total_height: 1.0,
            bumpiness: 1.0,
            holes: 1.0,
            complete_lines: 1.0,
        }
    }
}

impl Parameters {
    // Nudge one randomly chosen weight, then renormalize the vector to unit
    // length so weight magnitudes stay comparable across generations.
    fn mutate(&mut self, rng: &mut SmallRng) {
        let mutation_amt = rng.gen_range(-0.2, 0.2);
        match rng.gen_range(0, 4) {
            0 => self.total_height += mutation_amt,
            1 => self.bumpiness += mutation_amt,
            2 => self.holes += mutation_amt,
            3 => self.complete_lines += mutation_amt,
            _ => unreachable!(),
        }

        let normalization_factor = (self.total_height.powi(2)
            + self.bumpiness.powi(2)
            + self.holes.powi(2)
            + self.complete_lines.powi(2))
        .sqrt();
        self.total_height /= normalization_factor;
        self.bumpiness /= normalization_factor;
        self.holes /= normalization_factor;
        self.complete_lines /= normalization_factor;
    }

    fn dot_multiply(&self, other: &Self) -> f64 {
        self.total_height * other.total_height
            + self.bumpiness * other.bumpiness
            + self.holes * other.holes
            + self.complete_lines * other.complete_lines
    }
}

#[derive(Clone, Copy, Default)]
pub struct GeneticHeuristicAgent {
    params: Parameters,
}

impl GeneticHeuristicAgent {
    fn extract_features_from_state(state: &State) -> Parameters {
        // Column heights: the first occupied row seen from the top determines
        // each column's height.
        let mut heights = [None; PLAYFIELD_WIDTH];
        for r in 0..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if heights[c].is_none() && state.matrix[r][c].is_some() {
                    heights[c] = Some(PLAYFIELD_HEIGHT - r);
                }
            }
        }

        let total_height = heights.iter().map(|o| o.unwrap_or(0)).sum::<usize>() as f64;

        // Bumpiness: sum of absolute height differences between adjacent
        // columns.
        let heights = heights
            .iter()
            .map(|o| o.unwrap_or(0) as isize)
            .collect::<Vec<_>>();
        let bumpiness = heights
            .windows(2)
            .map(|w| (w[0] - w[1]).abs())
            .sum::<isize>() as f64;

        let complete_lines = state
            .matrix
            .iter()
            .map(|row| row.iter().all(Option::is_some))
            .map(|c| if c { 1.0 } else { 0.0 })
            .sum::<f64>();

        // A cell counts as a hole when it is empty but the cell directly
        // above it is filled.
        let mut holes = 0;
        for r in 1..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
                    holes += 1;
                }
            }
        }

        Parameters {
            total_height,
            bumpiness,
            complete_lines,
            holes: holes as f64,
        }
    }

    // Sketch of the heuristic (the body was left unimplemented): simulate the
    // action on a cloned game and score the resulting board with the learned
    // weights. Assumes `Game: Clone` and the same `Game -> State` conversion
    // used in `train_actor` below.
    fn get_heuristic(&self, game: &Game, action: &Action) -> f64 {
        let mut simulated = game.clone();
        super::apply_action_to_game(*action, &mut simulated);
        let state: State = simulated.into();
        self.params
            .dot_multiply(&Self::extract_features_from_state(&state))
    }

    // Fitness-weighted crossover: each child weight is the average of the
    // parents' corresponding weights, weighted by fitness. At least one
    // parent should have nonzero fitness to avoid dividing by zero.
    pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self {
        let weight = (self_fitness + other_fitness) as f64;
        let self_weight = self_fitness as f64 / weight;
        let other_weight = other_fitness as f64 / weight;
        Self {
            params: Parameters {
                total_height: self.params.total_height * self_weight
                    + other.params.total_height * other_weight,
                bumpiness: self.params.bumpiness * self_weight
                    + other.params.bumpiness * other_weight,
                holes: self.params.holes * self_weight
                    + other.params.holes * other_weight,
                complete_lines: self.params.complete_lines * self_weight
                    + other.params.complete_lines * other_weight,
            },
        }
    }

    fn mutate(&mut self, rng: &mut SmallRng) {
        self.params.mutate(rng);
    }
}

impl Actor for GeneticHeuristicAgent {
    fn get_action(&self, _: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        // Pick the legal action with the highest heuristic score. Floats are
        // not `Ord`, so the score is scaled and cast to a signed integer for
        // comparison; a `usize` cast would clamp negative scores to zero.
        *legal_actions
            .iter()
            .map(|action| (action, self.get_heuristic(game, action)))
            .max_by_key(|(_, heuristic)| (heuristic * 100_000.0) as isize)
            .unwrap()
            .0
    }

    fn dbg(&self) {
        // unimplemented!()
    }
}

// Evolve a population of 1000 agents with tournament selection: each episode
// evaluates every agent, breeds 300 children from tournament winners, and
// replaces the 300 least fit agents with them.
pub fn train_actor(opts: &Train) -> Box<dyn Actor> {
    let mut rng = SmallRng::from_entropy();
    let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000];

    for _ in 0..opts.episodes {
        // Evaluation: an agent's fitness is its total line clears over 100
        // games, each capped at 500 pieces.
        let mut new_population: Vec<(GeneticHeuristicAgent, u32)> = population
            .iter()
            .map(|(agent, _)| {
                let mut fitness = 0;
                for _ in 0..100 {
                    let mut game = Game::default();
                    game.set_piece_limit(500);
                    while game.is_game_over().is_none() {
                        super::apply_action_to_game(
                            agent.get_action(
                                &mut rng,
                                &game.clone().into(),
                                &game.get_legal_actions(),
                            ),
                            &mut game,
                        );
                    }
                    fitness += game.line_clears;
                }
                (*agent, fitness)
            })
            .collect::<Vec<_>>();

        // Selection and crossover: three tournaments of 100 randomly chosen
        // agents each; the two fittest in each tournament breed 100 children,
        // and each child mutates with probability 0.05.
        let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> = Vec::with_capacity(300);
        for _ in 0..3 {
            let mut random_selection = new_population
                .choose_multiple(&mut rng, 100)
                .collect::<Vec<_>>();
            random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1));
            let best_two = random_selection.iter().rev().take(2).collect::<Vec<_>>();
            let parent1 = best_two[0];
            let parent2 = best_two[1];
            for _ in 0..100 {
                let mut child = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
                if rng.gen::<f64>() < 0.05 {
                    child.mutate(&mut rng);
                }
                breeded_population.push((child, 0));
            }
        }

        // Replacement: sort ascending by fitness and overwrite the weakest
        // 300 agents with the new children.
        new_population.sort_unstable_by_key(|e| e.1);
        new_population.splice(..breeded_population.len(), breeded_population);
        population = new_population;
    }

    // The population is sorted ascending by fitness, so the last entry is the
    // fittest agent.
    population.sort_unstable_by_key(|e| e.1);
    Box::new(population.last().unwrap().0)
}
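
// Usage sketch (an illustrative addition, not part of the original module):
// train an agent from the CLI options and let it play one capped game,
// mirroring the evaluation loop in `train_actor`. The helper name
// `demo_trained_actor` is hypothetical; all calls it makes appear above.
#[allow(dead_code)]
pub fn demo_trained_actor(opts: &Train) -> u32 {
    let mut rng = SmallRng::from_entropy();
    let actor = train_actor(opts);
    let mut game = Game::default();
    game.set_piece_limit(500);
    while game.is_game_over().is_none() {
        let action = actor.get_action(&mut rng, &game.clone().into(), &game.get_legal_actions());
        super::apply_action_to_game(action, &mut game);
    }
    // Report the quantity the trainer optimizes: total line clears.
    game.line_clears
}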