finish approx q learning

This commit is contained in:
Edward Shen 2020-04-05 16:34:33 -04:00
parent ce65afa277
commit f3b48fbc85
Signed by: edward
GPG key ID: 19182661E818369F
5 changed files with 184 additions and 14 deletions

View file

@ -10,6 +10,7 @@ pub struct State {
matrix: Matrix, matrix: Matrix,
active_piece: Option<Tetromino>, active_piece: Option<Tetromino>,
held_piece: Option<TetrominoType>, held_piece: Option<TetrominoType>,
line_clears: u32,
} }
impl From<Game> for State { impl From<Game> for State {
@ -20,7 +21,9 @@ impl From<Game> for State {
impl From<&Game> for State { impl From<&Game> for State {
fn from(game: &Game) -> Self { fn from(game: &Game) -> Self {
game.playfield().clone().into() let mut state: State = game.playfield().clone().into();
state.line_clears = game.line_clears;
state
} }
} }
@ -30,6 +33,7 @@ impl From<PlayField> for State {
matrix: playfield.field().clone(), matrix: playfield.field().clone(),
active_piece: playfield.active_piece, active_piece: playfield.active_piece,
held_piece: playfield.hold_piece().map(|t| t.clone()), held_piece: playfield.hold_piece().map(|t| t.clone()),
line_clears: 0,
} }
} }
} }
@ -37,7 +41,14 @@ impl From<PlayField> for State {
pub trait Actor { pub trait Actor {
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action; fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
fn update(&mut self, state: State, action: Action, next_state: State, reward: f64); fn update(
&mut self,
state: State,
action: Action,
next_state: State,
next_legal_actions: &[Action],
reward: f64,
);
fn set_learning_rate(&mut self, learning_rate: f64); fn set_learning_rate(&mut self, learning_rate: f64);
fn set_exploration_prob(&mut self, exploration_prob: f64); fn set_exploration_prob(&mut self, exploration_prob: f64);

View file

@ -1,5 +1,8 @@
use crate::actors::{Actor, State}; use crate::actors::{Actor, State};
use crate::game::Action; use crate::{
game::Action,
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use log::debug; use log::debug;
use rand::rngs::SmallRng; use rand::rngs::SmallRng;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
@ -15,10 +18,10 @@ pub struct QLearningAgent {
impl Default for QLearningAgent { impl Default for QLearningAgent {
fn default() -> Self { fn default() -> Self {
QLearningAgent { Self {
learning_rate: 0.1, learning_rate: 0.0,
exploration_prob: 0.5, exploration_prob: 0.0,
discount_rate: 1.0, discount_rate: 0.0,
q_values: HashMap::default(), q_values: HashMap::default(),
} }
} }
@ -66,7 +69,14 @@ impl Actor for QLearningAgent {
} }
} }
fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) { fn update(
&mut self,
state: State,
action: Action,
next_state: State,
_next_legal_actions: &[Action],
reward: f64,
) {
let cur_q_val = self.get_q_value(&state, action); let cur_q_val = self.get_q_value(&state, action);
let new_q_val = cur_q_val let new_q_val = cur_q_val
+ self.learning_rate + self.learning_rate
@ -95,3 +105,148 @@ impl Actor for QLearningAgent {
self.discount_rate = discount_rate; self.discount_rate = discount_rate;
} }
} }
pub struct ApproximateQLearning {
pub learning_rate: f64,
pub exploration_prob: f64,
pub discount_rate: f64,
weights: HashMap<String, f64>,
}
impl Default for ApproximateQLearning {
fn default() -> Self {
Self {
learning_rate: 0.0,
exploration_prob: 0.0,
discount_rate: 0.0,
weights: HashMap::default(),
}
}
}
impl ApproximateQLearning {
fn get_features(
&self,
state: &State,
_action: &Action,
new_state: &State,
) -> HashMap<String, f64> {
let mut features = HashMap::default();
let mut heights = [None; PLAYFIELD_WIDTH];
for r in 0..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if heights[c].is_none() && state.matrix[r][c].is_some() {
heights[c] = Some(PLAYFIELD_HEIGHT - r);
}
}
}
features.insert(
"Total Height".into(),
heights
.iter()
.map(|o| o.unwrap_or_else(|| 0))
.sum::<usize>() as f64
/ (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
);
features.insert(
"Bumpiness".into(),
heights
.iter()
.map(|o| o.unwrap_or_else(|| 0) as isize)
.fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
.0 as f64
/ (PLAYFIELD_WIDTH * 40) as f64,
);
features.insert(
"Lines cleared".into(),
(new_state.line_clears - state.line_clears) as f64 / 4.0,
);
let mut holes = 0;
for r in 1..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
holes += 1;
}
}
}
features.insert("Holes".into(), holes as f64);
features
}
fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
self.get_features(state, action, next_state)
.iter()
.map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0))
.sum()
}
fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
*legal_actions
.iter()
.map(|action| (action, self.get_q_value(&state, action, state)))
.max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
.expect("Failed to select an action")
.0
}
fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
legal_actions
.iter()
.map(|action| self.get_q_value(state, action, state))
.max_by_key(|v| (v * 1_000_000.0) as isize)
.unwrap_or_else(|| 0.0)
}
}
impl Actor for ApproximateQLearning {
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
if rng.gen::<f64>() < self.exploration_prob {
*legal_actions.choose(rng).unwrap()
} else {
self.get_action_from_q_values(state, legal_actions)
}
}
fn update(
&mut self,
state: State,
action: Action,
next_state: State,
next_legal_actions: &[Action],
reward: f64,
) {
let difference = reward
+ self.discount_rate * self.get_value(&next_state, next_legal_actions)
- self.get_q_value(&state, &action, &next_state);
for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
self.weights.insert(
feat_key.clone(),
*self.weights.get(&feat_key).unwrap_or_else(|| &0.0)
+ self.learning_rate * difference * feat_val,
);
}
}
fn set_learning_rate(&mut self, learning_rate: f64) {
self.learning_rate = learning_rate;
}
fn set_exploration_prob(&mut self, exploration_prob: f64) {
self.exploration_prob = exploration_prob;
}
fn set_discount_rate(&mut self, discount_rate: f64) {
self.discount_rate = discount_rate;
}
fn dbg(&self) {
dbg!(&self.weights);
}
}

View file

@ -65,6 +65,8 @@ pub struct Train {
/// Number of episodes to train the agent /// Number of episodes to train the agent
#[clap(short = "n", long = "num", default_value = "10")] #[clap(short = "n", long = "num", default_value = "10")]
pub episodes: usize, pub episodes: usize,
// #[clap(long = "use-epsilon-decreasing")]
// pub epsilon_decreasing: bool,
} }
arg_enum! { arg_enum! {
@ -91,6 +93,6 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
pub fn get_actor(agent: Agent) -> Box<dyn Actor> { pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
match agent { match agent {
Agent::QLearning => Box::new(qlearning::QLearningAgent::default()), Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
Agent::ApproximateQLearning => todo!(), Agent::ApproximateQLearning => Box::new(qlearning::ApproximateQLearning::default()),
} }
} }

View file

@ -32,7 +32,7 @@ pub struct Game {
/// The last clear action performed, used for determining if a back-to-back /// The last clear action performed, used for determining if a back-to-back
/// bonus is needed. /// bonus is needed.
last_clear_action: ClearAction, last_clear_action: ClearAction,
line_clears: u32, pub line_clears: u32,
} }
impl fmt::Debug for Game { impl fmt::Debug for Game {
@ -178,7 +178,7 @@ impl Game {
if cleared_lines > 0 { if cleared_lines > 0 {
trace!("Lines were cleared."); trace!("Lines were cleared.");
self.line_clears += cleared_lines as u32; self.line_clears += cleared_lines as u32;
self.score += (cleared_lines * self.level as usize) as u32; self.score += (cleared_lines * 100 * self.level as usize) as u32;
self.level = (self.line_clears / 10) as u8; self.level = (self.line_clears / 10) as u8;
self.playfield.active_piece = None; self.playfield.active_piece = None;
self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY; self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;

View file

@ -91,14 +91,16 @@ fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
let new_state = (&game).into(); let new_state = (&game).into();
let mut reward = game.score() as f64 - cur_score as f64; let mut reward = game.score() as f64 - cur_score as f64;
if action != Action::Nothing { if action != Action::Nothing {
reward -= 10.0; reward -= 0.0;
} }
if game.is_game_over().is_some() { if game.is_game_over().is_some() {
reward = -100.0; reward = -1.0;
} }
actor.update(cur_state, action, new_state, reward); let new_legal_actions = game.get_legal_actions();
actor.update(cur_state, action, new_state, &new_legal_actions, reward);
game.tick(); game.tick();
} }