diff --git a/src/actors/mod.rs b/src/actors/mod.rs
index e4d17da..d25e7b9 100644
--- a/src/actors/mod.rs
+++ b/src/actors/mod.rs
@@ -10,6 +10,7 @@ pub struct State {
     matrix: Matrix,
     active_piece: Option,
     held_piece: Option,
+    line_clears: u32,
 }
 
 impl From for State {
@@ -20,7 +21,9 @@ impl From for State {
 
 impl From<&Game> for State {
     fn from(game: &Game) -> Self {
-        game.playfield().clone().into()
+        let mut state: State = game.playfield().clone().into();
+        state.line_clears = game.line_clears;
+        state
     }
 }
 
@@ -30,6 +33,7 @@ impl From for State {
             matrix: playfield.field().clone(),
             active_piece: playfield.active_piece,
             held_piece: playfield.hold_piece().map(|t| t.clone()),
+            line_clears: 0,
         }
     }
 }
@@ -37,7 +41,14 @@ impl From for State {
 pub trait Actor {
     fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
 
-    fn update(&mut self, state: State, action: Action, next_state: State, reward: f64);
+    fn update(
+        &mut self,
+        state: State,
+        action: Action,
+        next_state: State,
+        next_legal_actions: &[Action],
+        reward: f64,
+    );
 
     fn set_learning_rate(&mut self, learning_rate: f64);
     fn set_exploration_prob(&mut self, exploration_prob: f64);
diff --git a/src/actors/qlearning.rs b/src/actors/qlearning.rs
index 8bc3605..425da80 100644
--- a/src/actors/qlearning.rs
+++ b/src/actors/qlearning.rs
@@ -1,5 +1,8 @@
 use crate::actors::{Actor, State};
-use crate::game::Action;
+use crate::{
+    game::Action,
+    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
+};
 use log::debug;
 use rand::rngs::SmallRng;
 use rand::seq::SliceRandom;
@@ -15,10 +18,10 @@ pub struct QLearningAgent {
 
 impl Default for QLearningAgent {
     fn default() -> Self {
-        QLearningAgent {
-            learning_rate: 0.1,
-            exploration_prob: 0.5,
-            discount_rate: 1.0,
+        Self {
+            learning_rate: 0.0,
+            exploration_prob: 0.0,
+            discount_rate: 0.0,
             q_values: HashMap::default(),
         }
     }
@@ -66,7 +69,14 @@ impl Actor for QLearningAgent {
         }
     }
 
-    fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) {
+    fn update(
+        &mut self,
+        state: State,
+        action: Action,
+        next_state: State,
+        _next_legal_actions: &[Action],
+        reward: f64,
+    ) {
         let cur_q_val = self.get_q_value(&state, action);
         let new_q_val = cur_q_val
             + self.learning_rate
@@ -95,3 +105,148 @@ impl Actor for QLearningAgent {
         self.discount_rate = discount_rate;
     }
 }
+
+pub struct ApproximateQLearning {
+    pub learning_rate: f64,
+    pub exploration_prob: f64,
+    pub discount_rate: f64,
+    weights: HashMap<String, f64>,
+}
+
+impl Default for ApproximateQLearning {
+    fn default() -> Self {
+        Self {
+            learning_rate: 0.0,
+            exploration_prob: 0.0,
+            discount_rate: 0.0,
+            weights: HashMap::default(),
+        }
+    }
+}
+
+impl ApproximateQLearning {
+    fn get_features(
+        &self,
+        state: &State,
+        _action: &Action,
+        new_state: &State,
+    ) -> HashMap<String, f64> {
+        let mut features = HashMap::default();
+
+        let mut heights = [None; PLAYFIELD_WIDTH];
+        for r in 0..PLAYFIELD_HEIGHT {
+            for c in 0..PLAYFIELD_WIDTH {
+                if heights[c].is_none() && state.matrix[r][c].is_some() {
+                    heights[c] = Some(PLAYFIELD_HEIGHT - r);
+                }
+            }
+        }
+
+        features.insert(
+            "Total Height".into(),
+            heights
+                .iter()
+                .map(|o| o.unwrap_or_else(|| 0))
+                .sum::<usize>() as f64
+                / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
+        );
+
+        features.insert(
+            "Bumpiness".into(),
+            heights
+                .iter()
+                .map(|o| o.unwrap_or_else(|| 0) as isize)
+                .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
+                .0 as f64
+                / (PLAYFIELD_WIDTH * 40) as f64,
+        );
+
+        features.insert(
+            "Lines cleared".into(),
+            (new_state.line_clears - state.line_clears) as f64 / 4.0,
+        );
+
+        let mut holes = 0;
+        for r in 1..PLAYFIELD_HEIGHT {
+            for c in 0..PLAYFIELD_WIDTH {
+                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
+                    holes += 1;
+                }
+            }
+        }
+        features.insert("Holes".into(), holes as f64);
+
+        features
+    }
+
+    fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
+        self.get_features(state, action, next_state)
+            .iter()
+            .map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0))
+            .sum()
+    }
+
+    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
+        *legal_actions
+            .iter()
+            .map(|action| (action, self.get_q_value(&state, action, state)))
+            .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
+            .expect("Failed to select an action")
+            .0
+    }
+
+    fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
+        legal_actions
+            .iter()
+            .map(|action| self.get_q_value(state, action, state))
+            .max_by_key(|v| (v * 1_000_000.0) as isize)
+            .unwrap_or_else(|| 0.0)
+    }
+}
+
+impl Actor for ApproximateQLearning {
+    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
+        if rng.gen::<f64>() < self.exploration_prob {
+            *legal_actions.choose(rng).unwrap()
+        } else {
+            self.get_action_from_q_values(state, legal_actions)
+        }
+    }
+
+    fn update(
+        &mut self,
+        state: State,
+        action: Action,
+        next_state: State,
+        next_legal_actions: &[Action],
+        reward: f64,
+    ) {
+        let difference = reward
+            + self.discount_rate * self.get_value(&next_state, next_legal_actions)
+            - self.get_q_value(&state, &action, &next_state);
+
+        for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
+            self.weights.insert(
+                feat_key.clone(),
+                *self.weights.get(&feat_key).unwrap_or_else(|| &0.0)
+                    + self.learning_rate * difference * feat_val,
+            );
+        }
+    }
+
+    fn set_learning_rate(&mut self, learning_rate: f64) {
+        self.learning_rate = learning_rate;
+    }
+
+    fn set_exploration_prob(&mut self, exploration_prob: f64) {
+        self.exploration_prob = exploration_prob;
+    }
+
+    fn set_discount_rate(&mut self, discount_rate: f64) {
+        self.discount_rate = discount_rate;
+    }
+
+    fn dbg(&self) {
+        dbg!(&self.weights);
+    }
+}
diff --git a/src/cli.rs b/src/cli.rs
index 59c6abd..913b09c 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -65,6 +65,8 @@ pub struct Train {
     /// Number of episodes to train the agent
     #[clap(short = "n", long = "num", default_value = "10")]
    pub episodes: usize,
+    // #[clap(long = "use-epsilon-decreasing")]
+    // pub epsilon_decreasing: bool,
 }
 
 arg_enum! {
@@ -91,6 +93,6 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn Error>> {
 pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
     match agent {
         Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
-        Agent::ApproximateQLearning => todo!(),
+        Agent::ApproximateQLearning => Box::new(qlearning::ApproximateQLearning::default()),
     }
 }
diff --git a/src/game.rs b/src/game.rs
index e8b4681..1610df3 100644
--- a/src/game.rs
+++ b/src/game.rs
@@ -32,7 +32,7 @@ pub struct Game {
     /// The last clear action performed, used for determining if a back-to-back
     /// bonus is needed.
     last_clear_action: ClearAction,
-    line_clears: u32,
+    pub line_clears: u32,
 }
 
 impl fmt::Debug for Game {
@@ -178,7 +178,7 @@ impl Game {
         if cleared_lines > 0 {
             trace!("Lines were cleared.");
             self.line_clears += cleared_lines as u32;
-            self.score += (cleared_lines * self.level as usize) as u32;
+            self.score += (cleared_lines * 100 * self.level as usize) as u32;
             self.level = (self.line_clears / 10) as u8;
             self.playfield.active_piece = None;
             self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
diff --git a/src/main.rs b/src/main.rs
index 5bcf9bd..ece01dc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -91,14 +91,16 @@ fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
             let new_state = (&game).into();
             let mut reward = game.score() as f64 - cur_score as f64;
             if action != Action::Nothing {
-                reward -= 10.0;
+                reward -= 0.0;
             }
 
             if game.is_game_over().is_some() {
-                reward = -100.0;
+                reward = -1.0;
             }
 
-            actor.update(cur_state, action, new_state, reward);
+            let new_legal_actions = game.get_legal_actions();
+
+            actor.update(cur_state, action, new_state, &new_legal_actions, reward);
 
             game.tick();
         }
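
The new ApproximateQLearning agent is a standard feature-based (linear) Q-learner: Q(s, a) is a dot product of hand-crafted feature values and learned weights, and each update nudges every weight toward the temporal-difference target. That is also why the Actor trait now receives next_legal_actions, since the target needs V(s') = max over a' of Q(s', a'). Below is a minimal standalone sketch of that update rule, using plain HashMap<String, f64> maps instead of the crate's State and Action types; the names q_value and update_weights are illustrative and not part of the codebase.

use std::collections::HashMap;

/// Q(s, a) approximated as a dot product of feature values and learned weights.
fn q_value(weights: &HashMap<String, f64>, features: &HashMap<String, f64>) -> f64 {
    features
        .iter()
        .map(|(k, v)| v * weights.get(k).copied().unwrap_or(0.0))
        .sum()
}

/// One approximate Q-learning step:
/// w_i <- w_i + alpha * (r + gamma * V(s') - Q(s, a)) * f_i(s, a)
fn update_weights(
    weights: &mut HashMap<String, f64>,
    features: &HashMap<String, f64>, // f(s, a)
    reward: f64,                     // r
    next_value: f64,                 // V(s') = max over a' of Q(s', a')
    learning_rate: f64,              // alpha
    discount_rate: f64,              // gamma
) {
    let difference = reward + discount_rate * next_value - q_value(weights, features);
    for (key, feat) in features {
        let w = weights.entry(key.clone()).or_insert(0.0);
        *w += learning_rate * difference * feat;
    }
}

fn main() {
    let mut weights = HashMap::new();
    let features: HashMap<String, f64> =
        [("Holes".to_string(), 2.0), ("Bumpiness".to_string(), 0.25)].into();

    // Made-up numbers: a small negative reward and no estimated future value.
    update_weights(&mut weights, &features, -1.0, 0.0, 0.1, 0.9);
    println!("{weights:?}"); // each weight moves in proportion to its feature value
}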
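
The Bumpiness feature is the sum of absolute differences between adjacent column heights, computed with a fold that threads the previous height through the accumulator (the patch then normalises that sum by PLAYFIELD_WIDTH * 40). One quirk worth knowing: the fold is seeded with (0, 0), so the first column is compared against an implicit height of zero and its full height counts toward the total. A small self-contained check of that behaviour, with made-up sample heights:

fn bumpiness(heights: &[isize]) -> isize {
    heights
        .iter()
        .fold((0, 0), |(acc, prev), &cur| (acc + (prev - cur).abs(), cur))
        .0
}

fn main() {
    // |0 - 3| + |3 - 1| + |1 - 4| + |4 - 2| = 3 + 2 + 3 + 2
    assert_eq!(bumpiness(&[3, 1, 4, 2]), 10);
    // A perfectly flat stack still registers its height once, via the (0, 0) seed.
    assert_eq!(bumpiness(&[5, 5, 5, 5]), 5);
    println!("bumpiness checks passed");
}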
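
The agent defaults for learning rate, exploration probability, and discount rate are now zeroed, so the caller is expected to configure them through the trait setters, and the commented-out --use-epsilon-decreasing flag hints at an exploration schedule that decays over training. A hypothetical sketch of such a schedule, built only on a set_exploration_prob-style setter; the Explorer trait, Dummy struct, and decay bounds are placeholders and not part of the patch.

// Hypothetical epsilon-decreasing schedule; `Explorer` and `Dummy` stand in for
// the crate's Actor trait and a concrete agent.
trait Explorer {
    fn set_exploration_prob(&mut self, exploration_prob: f64);
}

struct Dummy {
    epsilon: f64,
}

impl Explorer for Dummy {
    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.epsilon = exploration_prob;
    }
}

fn main() {
    let episodes = 10;
    let (start, end) = (0.9, 0.05); // placeholder bounds for epsilon
    let mut agent = Dummy { epsilon: start };

    for episode in 0..episodes {
        // Linear decay from `start` to `end` over the course of training.
        let t = episode as f64 / (episodes - 1).max(1) as f64;
        agent.set_exploration_prob(start + (end - start) * t);
        println!("episode {episode}: epsilon = {:.2}", agent.epsilon);
        // ...run one episode of the game here...
    }
}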