// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/ use super::{Actor, Predictable, State}; use crate::{ cli::Train, game::{Action, Controllable, Game, Tickable}, playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH}, }; use indicatif::ProgressBar; use log::debug; use rand::rngs::SmallRng; use rand::{seq::SliceRandom, Rng, SeedableRng}; #[derive(Copy, Clone, Debug)] pub struct Parameters { total_height: f64, bumpiness: f64, holes: f64, complete_lines: f64, max_height: f64, max_well_depth: f64, } impl Default for Parameters { fn default() -> Self { Self { total_height: 1.0, bumpiness: 1.0, holes: 1.0, complete_lines: 1.0, max_height: 1.0, max_well_depth: 1.0, } } } impl Parameters { fn mutate(mut self, rng: &mut SmallRng) { let mutation_amt = rng.gen_range(-0.2, 0.2); match rng.gen_range(0, 6) { 0 => self.total_height += mutation_amt, 1 => self.bumpiness += mutation_amt, 2 => self.holes += mutation_amt, 3 => self.complete_lines += mutation_amt, 4 => self.max_height += mutation_amt, 5 => self.max_well_depth += mutation_amt, _ => unreachable!(), } let normalization_factor = (self.total_height.powi(2) + self.bumpiness.powi(2) + self.holes.powi(2) + self.complete_lines.powi(2)) + self.max_height.powi(2) + self.max_well_depth.powi(2).sqrt(); self.total_height /= normalization_factor; self.bumpiness /= normalization_factor; self.holes /= normalization_factor; self.complete_lines /= normalization_factor; } fn dot_multiply(&self, other: &Self) -> f64 { self.total_height * other.total_height + self.bumpiness * other.bumpiness + self.holes * other.holes + self.complete_lines * other.complete_lines } } #[derive(Clone, Copy, Debug)] pub struct GeneticHeuristicAgent { params: Parameters, } impl Default for GeneticHeuristicAgent { fn default() -> Self { let mut rng = SmallRng::from_entropy(); Self { params: Parameters { total_height: rng.gen::(), bumpiness: rng.gen::(), holes: rng.gen::(), complete_lines: rng.gen::(), max_height: rng.gen::(), max_well_depth: rng.gen::(), }, } } } impl GeneticHeuristicAgent { fn extract_features_from_state(state: &State) -> Parameters { let mut heights = [None; PLAYFIELD_WIDTH]; for r in 0..PLAYFIELD_HEIGHT { for c in 0..PLAYFIELD_WIDTH { if heights[c].is_none() && state.matrix[r][c].is_some() { heights[c] = Some(PLAYFIELD_HEIGHT - r); } } } let total_height = heights .iter() .map(|o| o.unwrap_or_else(|| 0)) .sum::() as f64; let bumpiness = heights .iter() .map(|o| o.unwrap_or_else(|| 0) as isize) .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur)) .0 as f64; let complete_lines = state .matrix .iter() .map(|row| row.iter().all(Option::is_some)) .map(|c| if c { 1.0 } else { 0.0 }) .sum::(); let mut holes = 0; for r in 1..PLAYFIELD_HEIGHT { for c in 0..PLAYFIELD_WIDTH { if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() { holes += 1; } } } let max_height = heights.iter().max().unwrap().unwrap_or_else(|| 0) as f64; let mut max_well_height = 0; for i in 0..heights.len() { let left = if i == 0 { 20 } else { heights[i - 1].unwrap_or_else(|| 0) }; let right = if i == heights.len() - 1 { 20 } else { heights[i + 1].unwrap_or_else(|| 0) }; let well_height = if left > right { right } else { left }; max_well_height = *[ max_well_height, well_height - heights[i].unwrap_or_else(|| 0), ] .iter() .max() .unwrap(); } Parameters { total_height, bumpiness, complete_lines, holes: holes as f64, max_height, max_well_depth: max_well_height as f64, } } fn get_heuristic(&self, game: &Game, action: &Action) -> f64 { self.params.dot_multiply(&Self::extract_features_from_state( &game.get_next_state(*action).into(), )) } pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self { let weight = self_fitness + other_fitness; if weight != 0 { let self_weight = self_fitness as f64 / weight as f64; let other_weight = other_fitness as f64 / weight as f64; Self { params: Parameters { total_height: self.params.total_height * self_weight + other.params.total_height * other_weight, bumpiness: self.params.total_height * self_weight + other.params.total_height * other_weight, holes: self.params.total_height * self_weight + other.params.total_height * other_weight, complete_lines: self.params.total_height * self_weight + other.params.total_height * other_weight, max_height: self.params.max_height * self_weight + other.params.max_height * other_weight, max_well_depth: self.params.max_well_depth * self_weight + other.params.max_well_depth * other_weight, }, } } else { Self::default() } } fn mutate(&mut self, rng: &mut SmallRng) { self.params.mutate(rng); } } impl Actor for GeneticHeuristicAgent { fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action { let actions = legal_actions .iter() .map(|action| { ( action, (self.get_heuristic(game, action) * 1_000_000.0) as usize, ) }) .collect::>(); let max_val = actions .iter() .max_by_key(|(_, heuristic)| heuristic) .unwrap() .1; *actions .iter() .filter(|e| e.1 == max_val) .collect::>() .choose(rng) .unwrap() .0 } fn dbg(&self) { debug!("{:?}", self.params); } } pub async fn train_actor(opts: &Train) -> Box { use std::sync::{Arc, Mutex}; let rng = Arc::new(Mutex::new(SmallRng::from_entropy())); let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000]; let total_pb = Arc::new(Mutex::new(ProgressBar::new( population.len() as u64 * opts.episodes as u64, ))); for _ in 0..opts.episodes { let mut new_pop_futs = Vec::with_capacity(population.len()); let new_population = Arc::new(Mutex::new(Vec::with_capacity(population.len()))); let num_rounds = 10; for i in 0..population.len() { let rng = rng.clone(); let new_population = new_population.clone(); let total_pb = total_pb.clone(); let agent = population[i].0; new_pop_futs.push(tokio::task::spawn(async move { let mut fitness = 0; for _ in 0..num_rounds { let mut game = Game::default(); game.set_piece_limit(500); game.set_time_limit(std::time::Duration::from_secs(60)); while (&game).is_game_over().is_none() { let mut rng = rng.lock().expect("rng failed"); super::apply_action_to_game( agent.get_action( &mut rng, &game.clone().into(), &game.get_legal_actions(), ), &mut game, ); game.tick(); } fitness += game.line_clears; } new_population.lock().unwrap().push((agent, fitness)); total_pb.lock().expect("progressbar failed").inc(1); })); } futures::future::join_all(new_pop_futs).await; let mut rng = SmallRng::from_entropy(); let mut new_population = new_population.lock().unwrap().clone(); let new_pop_size = population.len() * 3 / 10; let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> = Vec::with_capacity(new_pop_size); for _ in 0..3 { let mut random_selection = new_population .choose_multiple(&mut rng, new_pop_size / 3) .collect::>(); random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1)); let best_two = random_selection.iter().rev().take(2).collect::>(); let parent1 = &best_two[0]; println!("{:?}", &best_two[0]); let parent2 = &best_two[1]; println!("{:?}", &best_two[1]); for _ in 0..new_pop_size / 3 { let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1); let mut cloned = breeded.clone(); if rng.gen::() < 0.05 { cloned.mutate(&mut rng); } breeded_population.push((cloned, 0)); } } new_population.sort_unstable_by_key(|e| e.1); new_population.splice(..breeded_population.len(), breeded_population); population = new_population; } total_pb.lock().unwrap().finish(); population.sort_unstable_by_key(|e| e.1); Box::new(population.iter().rev().next().unwrap().0) }