use crate::actors::{Actor, State};
use crate::{
    cli::Train,
    game::{Action, Controllable, Game, Tickable},
    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use indicatif::ProgressIterator;
use log::{debug, info};
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use std::collections::HashMap;

pub trait QLearningActor: Actor {
    fn update(
        &mut self,
        state: State,
        action: Action,
        next_state: State,
        next_legal_actions: &[Action],
        reward: f64,
    );

    fn set_learning_rate(&mut self, learning_rate: f64);
    fn set_exploration_prob(&mut self, exploration_prob: f64);
    fn set_discount_rate(&mut self, discount_rate: f64);
}

/// Tabular Q-learning agent: stores a Q-value for every (state, action) pair it has seen.
pub struct QLearningAgent {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    q_values: HashMap<State, HashMap<Action, f64>>,
}

impl Default for QLearningAgent {
    fn default() -> Self {
        Self {
            learning_rate: 0.0,
            exploration_prob: 0.0,
            discount_rate: 0.0,
            q_values: HashMap::default(),
        }
    }
}

impl QLearningAgent {
    fn get_q_value(&self, state: &State, action: Action) -> f64 {
        match self.q_values.get(&state) {
            Some(action_qval) => *action_qval.get(&action).unwrap_or_else(|| &0.0),
            None => 0.0,
        }
    }

    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
        *legal_actions
            .iter()
            .map(|action| (action, self.get_q_value(&state, *action)))
            // f64 is not Ord, so compare on a scaled integer key
            .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
            .expect("Failed to select an action")
            .0
    }

    fn get_value_from_q_values(&self, state: &State) -> f64 {
        *self
            .q_values
            .get(state)
            .and_then(|hashmap| {
                hashmap
                    .values()
                    .max_by_key(|q_val| (**q_val * 1_000_000.0) as isize)
                    .or_else(|| Some(&0.0))
            })
            .unwrap_or_else(|| &0.0)
    }
}

impl Actor for QLearningAgent {
    // Because doing nothing (`Action::Nothing`) is always in the set of legal
    // actions, `legal_actions` is never empty and the unwrap below is safe.
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            *legal_actions.choose(rng).unwrap()
        } else {
            self.get_action_from_q_values(&game.into(), legal_actions)
        }
    }

    fn dbg(&self) {
        debug!("Total states: {}", self.q_values.len());
    }
}

impl QLearningActor for QLearningAgent {
    fn update(
        &mut self,
        state: State,
        action: Action,
        next_state: State,
        _next_legal_actions: &[Action],
        reward: f64,
    ) {
        // Standard Q-learning update:
        // Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        let cur_q_val = self.get_q_value(&state, action);
        let new_q_val = cur_q_val
            + self.learning_rate
                * (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
                    - cur_q_val);

        if !self.q_values.contains_key(&state) {
            self.q_values.insert(state.clone(), HashMap::default());
        }
        self.q_values
            .get_mut(&state)
            .unwrap()
            .insert(action, new_q_val);
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}

/// Approximate Q-learning agent: represents Q-values as a linear combination of
/// hand-crafted playfield features, so it generalizes across unseen states.
pub struct ApproximateQLearning {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    weights: HashMap<String, f64>,
}

impl Default for ApproximateQLearning {
    fn default() -> Self {
        Self {
            learning_rate: 0.0,
            exploration_prob: 0.0,
            discount_rate: 0.0,
            weights: HashMap::default(),
        }
    }
}

impl ApproximateQLearning {
    fn get_features(
        &self,
        state: &State,
        _action: &Action,
        new_state: &State,
    ) -> HashMap<String, f64> {
        let mut features = HashMap::default();

        // Height of the topmost filled cell in each column (None if the column is empty).
        let mut heights = [None; PLAYFIELD_WIDTH];
        for r in 0..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if heights[c].is_none() && state.matrix[r][c].is_some() {
                    heights[c] = Some(PLAYFIELD_HEIGHT - r);
                }
            }
        }

        features.insert(
            "Total Height".into(),
            heights
                .iter()
                .map(|o| o.unwrap_or_else(|| 0))
                .sum::<usize>() as f64
                / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
        );

        // Sum of absolute height differences between adjacent columns.
        features.insert(
            "Bumpiness".into(),
            heights
                .iter()
                .map(|o| o.unwrap_or_else(|| 0) as isize)
                .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
                .0 as f64
                / (PLAYFIELD_WIDTH * 40) as f64,
        );

        features.insert(
            "Lines cleared".into(),
            (new_state.line_clears - state.line_clears) as f64 / 4.0,
        );

        // A hole is an empty cell with a filled cell directly above it.
        let mut holes = 0;
        for r in 1..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
                    holes += 1;
                }
            }
        }
        features.insert("Holes".into(), holes as f64);

        features
    }

    fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
        // Q(s, a) is the dot product of the feature vector with the learned weights.
        self.get_features(state, action, next_state)
            .iter()
            .map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0))
            .sum()
    }

    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
        *legal_actions
            .iter()
            .map(|action| (action, self.get_q_value(&state, action, state)))
            .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
            .expect("Failed to select an action")
            .0
    }

    fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
        legal_actions
            .iter()
            .map(|action| self.get_q_value(state, action, state))
            .max_by_key(|v| (v * 1_000_000.0) as isize)
            .unwrap_or_else(|| 0.0)
    }
}

impl Actor for ApproximateQLearning {
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            *legal_actions.choose(rng).unwrap()
        } else {
            self.get_action_from_q_values(&game.into(), legal_actions)
        }
    }

    fn dbg(&self) {
        dbg!(&self.weights);
    }
}

impl QLearningActor for ApproximateQLearning {
    fn update(
        &mut self,
        state: State,
        action: Action,
        next_state: State,
        next_legal_actions: &[Action],
        reward: f64,
    ) {
        // TD error between the observed return estimate and the current prediction.
        let difference = reward
            + self.discount_rate * self.get_value(&next_state, next_legal_actions)
            - self.get_q_value(&state, &action, &next_state);

        // Nudge each feature weight by learning_rate * TD-error * feature value.
        for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
            self.weights.insert(
                feat_key.clone(),
                *self.weights.get(&feat_key).unwrap_or_else(|| &0.0)
                    + self.learning_rate * difference * feat_val,
            );
        }
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}

pub fn train_actor<T: QLearningActor + 'static>(mut actor: T, opts: &Train) -> Box<dyn Actor> {
    let mut rng = SmallRng::from_entropy();
    let mut avg = 0.0;
    let episodes = opts.episodes;

    actor.set_learning_rate(opts.learning_rate);
    actor.set_discount_rate(opts.discount_rate);
    actor.set_exploration_prob(opts.exploration_prob);

    info!(
        "Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
        opts.learning_rate, opts.discount_rate, opts.exploration_prob
    );

    for i in (0..episodes).progress() {
        // Report the running average score every tenth of the training run.
        if i != 0 && i % (episodes / 10) == 0 {
            info!("Last {} scores avg: {}", (episodes / 10), avg);
            println!();
            avg = 0.0;
        }

        let mut game = Game::default();
        while (&game).is_game_over().is_none() {
            let cur_state = game.clone();
            let cur_score = game.score();
            let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
            super::apply_action_to_game(action, &mut game);
            let new_state = (&game).into();
            let mut reward = game.score() as f64 - cur_score as f64;
            if action != Action::Nothing {
                reward -= 0.0;
            }

            if game.is_game_over().is_some() {
                reward = -1.0;
            }

            let new_legal_actions = game.get_legal_actions();

            actor.update(
                cur_state.into(),
                action,
                new_state,
                &new_legal_actions,
                reward,
            );

            game.tick();
        }

        avg += game.score() as f64 / (episodes / 10) as f64;
    }

    if opts.no_explore_during_evaluation {
        actor.set_exploration_prob(0.0);
    }

    if opts.no_learn_during_evaluation {
        actor.set_learning_rate(0.0);
    }

    Box::new(actor)
}
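
// A minimal usage sketch, not part of the original module. It assumes `Train` is
// parsed elsewhere by the CLI and exposes the fields read above (`episodes`,
// `learning_rate`, `discount_rate`, `exploration_prob`,
// `no_explore_during_evaluation`, `no_learn_during_evaluation`). A caller might
// train either agent like this:
//
//     let opts: Train = /* parsed CLI options */;
//     let tabular: Box<dyn Actor> = train_actor(QLearningAgent::default(), &opts);
//     let approximate: Box<dyn Actor> = train_actor(ApproximateQLearning::default(), &opts);
//
// Both agents default to zeroed hyperparameters; `train_actor` overwrites them
// from `opts` before the first episode, so `Default` only supplies an empty
// Q-table or weight map.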