use super::Predictable;
use crate::actors::{Actor, State};
use crate::{
    cli::Train,
    game::{Action, Controllable, Game, Tickable},
    playfield::{Matrix, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use log::{debug, error, info, trace};
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use std::collections::HashMap;

/// Interface shared by agents that learn from (state, action, next state, reward) transitions.
pub trait QLearningActor: Actor {
    fn update(
        &mut self,
        game_state: Game,
        action: Action,
        next_game_state: Game,
        next_legal_actions: &[Action],
        reward: f64,
    );

    fn set_learning_rate(&mut self, learning_rate: f64);
    fn set_exploration_prob(&mut self, exploration_prob: f64);
    fn set_discount_rate(&mut self, discount_rate: f64);
}

/// Tabular Q-learning agent: stores one Q-value per (state, action) pair.
#[derive(Debug)]
pub struct QLearningAgent {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    q_values: HashMap<State, HashMap<Action, f64>>,
}

impl Default for QLearningAgent {
    fn default() -> Self {
        Self {
            learning_rate: 0.0,
            exploration_prob: 0.0,
            discount_rate: 0.0,
            q_values: HashMap::default(),
        }
    }
}

impl QLearningAgent {
    fn get_q_value(&self, state: &State, action: Action) -> f64 {
        match self.q_values.get(state) {
            Some(action_qval) => *action_qval.get(&action).unwrap_or(&0.0),
            None => 0.0,
        }
    }

    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
        let legal_actions = legal_actions
            .iter()
            .map(|action| (action, self.get_q_value(state, *action)))
            .collect::<Vec<_>>();
        let max_val = legal_actions
            .iter()
            .max_by_key(|(_, q1)| (q1 * 1_000_000.0) as isize)
            .expect("Failed to select an action")
            .1;
        let actions_to_choose = legal_actions
            .iter()
            .filter(|(_, v)| max_val == *v)
            .collect::<Vec<_>>();
        if actions_to_choose.len() != 1 {
            trace!(
                "more than one best option, choosing randomly: {:?}",
                actions_to_choose
            );
        }
        *actions_to_choose
            .choose(&mut SmallRng::from_entropy())
            .unwrap()
            .0
    }

    fn get_value_from_q_values(&self, state: &State) -> f64 {
        *self
            .q_values
            .get(state)
            .and_then(|hashmap| {
                hashmap
                    .values()
                    .max_by_key(|q_val| (**q_val * 1_000_000.0) as isize)
                    .or(Some(&0.0))
            })
            .unwrap_or(&0.0)
    }
}

impl Actor for QLearningAgent {
    // `Action::Nothing` is always in the set of legal actions, so `legal_actions`
    // is never empty and the `unwrap` below cannot fail.
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            *legal_actions.choose(rng).unwrap()
        } else {
            self.get_action_from_q_values(&game.into(), legal_actions)
        }
    }

    fn dbg(&self) {
        debug!("Total states: {}", self.q_values.len());
    }
}

impl QLearningActor for QLearningAgent {
    fn update(
        &mut self,
        game_state: Game,
        action: Action,
        next_game_state: Game,
        _next_legal_actions: &[Action],
        reward: f64,
    ) {
        let state = (&game_state).into();
        let next_state = (&next_game_state).into();
        let cur_q_val = self.get_q_value(&state, action);
        // Standard Q-learning update:
        // Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        let new_q_val = cur_q_val
            + self.learning_rate
                * (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
                    - cur_q_val);
        if !self.q_values.contains_key(&state) {
            self.q_values.insert(state.clone(), HashMap::default());
        }
        self.q_values
            .get_mut(&state)
            .unwrap()
            .insert(action, new_q_val);
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}

/// Approximate Q-learning agent: represents Q(s, a) as a linear combination of
/// hand-crafted playfield features, so it generalizes across states instead of
/// keeping a table.
#[derive(Debug)]
pub struct ApproximateQLearning {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    weights: HashMap<Feature, f64>,
}

impl Default for ApproximateQLearning {
    fn default() -> Self {
        let mut weights = HashMap::default();
        weights.insert(Feature::TotalHeight, 1.0);
        weights.insert(Feature::Bumpiness, 1.0);
        weights.insert(Feature::LinesCleared, 1.0);
        weights.insert(Feature::Holes, 1.0);

        Self {
            learning_rate: 0.0,
            exploration_prob: 0.0,
            discount_rate: 0.0,
            weights,
        }
    }
}

/// Playfield features used by the approximate agent.
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
enum Feature {
    TotalHeight,
    Bumpiness,
    LinesCleared,
    Holes,
}

impl ApproximateQLearning {
    /// Computes normalized feature values for the state reached by applying
    /// `action` to `game`.
    fn get_features(&self, game: &Game, action: &Action) -> HashMap<Feature, f64> {
        let game = game.get_next_state(*action);
        let mut features = HashMap::default();
        let field = game.playfield().field();
        let heights = self.get_heights(field);

        // Sum of column heights, normalized by the playfield area.
        features.insert(
            Feature::TotalHeight,
            heights.iter().sum::<usize>() as f64 / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
        );

        // Sum of absolute height differences between adjacent columns (the leftmost
        // column is compared against an implicit height of zero).
        features.insert(
            Feature::Bumpiness,
            heights
                .iter()
                .fold((0, 0), |(acc, prev), cur| {
                    (acc + (prev as isize - *cur as isize).abs(), *cur)
                })
                .0 as f64
                / (PLAYFIELD_WIDTH * 40) as f64,
        );

        // Number of full rows, normalized by the maximum of four lines per clear.
        features.insert(
            Feature::LinesCleared,
            game.playfield()
                .field()
                .iter()
                .map(|r| r.iter().all(Option::is_some))
                .map(|r| if r { 1 } else { 0 })
                .sum::<usize>() as f64
                / 4.0,
        );

        // Holes: empty cells with a filled cell directly above them.
        let mut holes = 0;
        for r in 1..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if field[r][c].is_none() && field[r - 1][c].is_some() {
                    holes += 1;
                }
            }
        }
        features.insert(Feature::Holes, holes as f64);

        features
    }

    /// Returns the height of each column, measured from the bottom of the matrix.
    fn get_heights(&self, matrix: &Matrix) -> Vec<usize> {
        let mut heights = vec![0; matrix[0].len()];
        for r in 0..matrix.len() {
            for c in 0..matrix[0].len() {
                if heights[c] == 0 && matrix[r][c].is_some() {
                    heights[c] = matrix.len() - r;
                }
            }
        }
        heights
    }

    /// Q(s, a) is the dot product of the feature vector and the weight vector.
    fn get_q_value(&self, game: &Game, action: &Action) -> f64 {
        self.get_features(game, action)
            .iter()
            .map(|(key, val)| val * *self.weights.get(key).unwrap())
            .sum()
    }

    fn get_action_from_q_values(&self, game: &Game) -> Action {
        let legal_actions = game.get_legal_actions();
        let legal_actions = legal_actions
            .iter()
            .map(|action| (action, self.get_q_value(game, action)))
            .collect::<Vec<_>>();
        let max_val = legal_actions
            .iter()
            .max_by_key(|(_, q1)| (q1 * 1_000_000.0) as isize)
            .expect("Failed to select an action")
            .1;
        let actions_to_choose = legal_actions
            .iter()
            .filter(|(_, v)| max_val == *v)
            .collect::<Vec<_>>();
        if actions_to_choose.len() != 1 {
            trace!(
                "more than one best option, choosing randomly: {:?}",
                actions_to_choose
            );
        }
        match actions_to_choose.choose(&mut SmallRng::from_entropy()) {
            Some(a) => *a.0,
            None => {
                error!("no actions to choose from: {:?}", legal_actions);
                panic!("failed to select an action");
            }
        }
    }

    fn get_value(&self, game: &Game) -> f64 {
        game.get_legal_actions()
            .iter()
            .map(|action| self.get_q_value(game, action))
            .max_by_key(|v| (v * 1_000_000.0) as isize)
            .unwrap_or(0.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tetromino::TetrominoType;

    #[test]
    fn test_height() {
        let agent = ApproximateQLearning::default();
        let matrix = vec![
            vec![None, None, None],
            vec![None, Some(TetrominoType::T), None],
            vec![None, None, None],
        ];

        assert_eq!(agent.get_heights(&matrix), vec![0, 2, 0]);
    }
}

impl Actor for ApproximateQLearning {
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            *legal_actions.choose(rng).unwrap()
        } else {
            self.get_action_from_q_values(game)
        }
    }

    fn dbg(&self) {
        dbg!(&self.weights);
    }
}

impl QLearningActor for ApproximateQLearning {
    fn update(
        &mut self,
        game_state: Game,
        action: Action,
        next_game_state: Game,
        _: &[Action],
        reward: f64,
    ) {
        // Gradient-style update for a linear Q-function:
        // w_i <- w_i + alpha * (r + gamma * V(s') - Q(s, a)) * f_i(s, a)
        let difference = reward
            + self.discount_rate * self.get_value(&next_game_state)
            - self.get_q_value(&game_state, &action);

        for (feat_key, feat_val) in self.get_features(&game_state, &action) {
            self.weights.insert(
                feat_key,
                *self.weights.get(&feat_key).unwrap() + self.learning_rate * difference * feat_val,
            );
        }
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}

/// Trains `actor` for `opts.episodes` games and returns it as a boxed `Actor`.
pub fn train_actor<T: QLearningActor + 'static>(
    mut actor: T,
    opts: &Train,
) -> Box<dyn Actor> {
    let mut rng = SmallRng::from_entropy();
    let mut avg = 0.0;
    let episodes = opts.episodes;

    actor.set_learning_rate(opts.learning_rate);
    actor.set_discount_rate(opts.discount_rate);
    actor.set_exploration_prob(opts.exploration_prob);

    info!(
        "Training an actor with learning_rate = {}, discount_rate = {}, exploration_prob = {}",
        opts.learning_rate, opts.discount_rate, opts.exploration_prob
    );

    for i in 0..episodes {
        // Report the running average score every tenth of the training run.
        if i != 0 && i % (episodes / 10) == 0 {
            println!("{}", avg);
            eprintln!("iteration {}", i);
            avg = 0.0;
        }

        let mut game = Game::default();
        while game.is_game_over().is_none() {
            let cur_state = game.clone();
            let cur_score = game.score();
            let action = actor.get_action(&mut rng, &cur_state, &game.get_legal_actions());
            super::apply_action_to_game(action, &mut game);

            let new_state = game.clone();
            // Reward is the score gained by the action, or -1 if it ended the game.
            let mut reward = game.score() as f64 - cur_score as f64;
            if game.is_game_over().is_some() {
                reward = -1.0;
            }
            let new_legal_actions = game.get_legal_actions();

            actor.update(cur_state, action, new_state, &new_legal_actions, reward);

            game.tick();
        }

        avg += game.score() as f64 / (episodes / 10) as f64;
    }

    if opts.no_explore_during_evaluation {
        actor.set_exploration_prob(0.0);
    }

    if opts.no_learn_during_evaluation {
        actor.set_learning_rate(0.0);
    }

    Box::new(actor)
}
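
// Usage sketch: `train_actor` is generic over the agent, so the same loop trains both the
// tabular and the approximate agent. The `Train` values below are illustrative assumptions
// (this file only shows that the struct has these fields, not their defaults or whether it
// supports struct-literal construction); in practice `Train` comes from the CLI parser.
//
//     let opts = Train {
//         episodes: 10_000,
//         learning_rate: 0.1,
//         discount_rate: 0.9,
//         exploration_prob: 0.1,
//         no_explore_during_evaluation: true,
//         no_learn_during_evaluation: true,
//     };
//     let trained: Box<dyn Actor> = train_actor(ApproximateQLearning::default(), &opts);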