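//! Q-learning actors: a tabular `QLearningAgent`, an `ApproximateQLearning`
//! agent that learns weights over hand-crafted playfield features, and the
//! `train_actor` loop that trains either one.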

use crate::actors::{Actor, State};
use crate::{
    cli::Train,
    game::{Action, Controllable, Game, Tickable},
    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use indicatif::ProgressIterator;
use log::{debug, info};
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use std::collections::HashMap;
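
/// Extends [`Actor`] with what the training loop needs: a
/// temporal-difference update and setters for the hyperparameters.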
pub trait QLearningActor: Actor {
    fn update(
        &mut self,
        state: State,
        action: Action,
        next_state: State,
        next_legal_actions: &[Action],
        reward: f64,
    );

    fn set_learning_rate(&mut self, learning_rate: f64);
    fn set_exploration_prob(&mut self, exploration_prob: f64);
    fn set_discount_rate(&mut self, discount_rate: f64);
}
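
/// A tabular Q-learning agent. Q-values are stored per (state, action) pair,
/// so it cannot generalize to states it has never visited.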
pub struct QLearningAgent {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    q_values: HashMap<State, HashMap<Action, f64>>,
}
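
// Hyperparameters default to zero; `train_actor` overwrites them from the
// CLI options before training starts.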
impl Default for QLearningAgent {
    fn default() -> Self {
        Self {
            learning_rate: 0.0,
            exploration_prob: 0.0,
            discount_rate: 0.0,
            q_values: HashMap::default(),
        }
    }
}

impl QLearningAgent {
    fn get_q_value(&self, state: &State, action: Action) -> f64 {
        self.q_values
            .get(state)
            .and_then(|action_qval| action_qval.get(&action))
            .copied()
            .unwrap_or(0.0)
    }

    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
        *legal_actions
            .iter()
            .map(|action| (action, self.get_q_value(state, *action)))
            // `f64` is not `Ord`; `total_cmp` gives a total order without the
            // precision loss of scaling into an integer.
            .max_by(|(_, q1), (_, q2)| q1.total_cmp(q2))
            .expect("Failed to select an action")
            .0
    }

    fn get_value_from_q_values(&self, state: &State) -> f64 {
        self.q_values
            .get(state)
            .and_then(|actions| actions.values().copied().max_by(|a, b| a.total_cmp(b)))
            .unwrap_or(0.0)
    }
}
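
// The policy is epsilon-greedy: with probability `exploration_prob`, pick a
// uniformly random legal action; otherwise act greedily on the learned
// Q-values.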
impl Actor for QLearningAgent {
    // Because `Action::Nothing` is always a legal action, `legal_actions` is
    // never empty.
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            *legal_actions.choose(rng).unwrap()
        } else {
            self.get_action_from_q_values(&game.into(), legal_actions)
        }
    }

    fn dbg(&self) {
        debug!("Total states: {}", self.q_values.len());
    }
}

impl QLearningActor for QLearningAgent {
    fn update(
        &mut self,
        state: State,
        action: Action,
        next_state: State,
        _next_legal_actions: &[Action],
        reward: f64,
    ) {
        let cur_q_val = self.get_q_value(&state, action);
        // Standard Q-learning update:
        // Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        let new_q_val = cur_q_val
            + self.learning_rate
                * (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
                    - cur_q_val);
        self.q_values
            .entry(state)
            .or_default()
            .insert(action, new_q_val);
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}
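
/// An approximate Q-learning agent. Instead of a table, it learns one weight
/// per hand-crafted feature, with Q(s, a) = sum of w_i * f_i(s, a), trading
/// expressiveness for generalization across unseen states.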
pub struct ApproximateQLearning {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    weights: HashMap<String, f64>,
}

impl Default for ApproximateQLearning {
    fn default() -> Self {
        Self {
            learning_rate: 0.0,
            exploration_prob: 0.0,
            discount_rate: 0.0,
            weights: HashMap::default(),
        }
    }
}
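
// Feature extraction for the linear model. Most features are scaled into
// roughly [0, 1] so a single learning rate behaves comparably across them.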
impl ApproximateQLearning {
    fn get_features(
        &self,
        state: &State,
        _action: &Action,
        new_state: &State,
    ) -> HashMap<String, f64> {
        let mut features = HashMap::default();

        // Height of the topmost filled cell in each column, if any.
        let mut heights = [None; PLAYFIELD_WIDTH];
        for r in 0..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if heights[c].is_none() && state.matrix[r][c].is_some() {
                    heights[c] = Some(PLAYFIELD_HEIGHT - r);
                }
            }
        }

        features.insert(
            "Total Height".into(),
            heights.iter().map(|o| o.unwrap_or(0)).sum::<usize>() as f64
                / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
        );

        // Sum of absolute height differences between adjacent columns; the
        // first column is measured against an implicit height of zero.
        features.insert(
            "Bumpiness".into(),
            heights
                .iter()
                .map(|o| o.unwrap_or(0) as isize)
                .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
                .0 as f64
                / (PLAYFIELD_WIDTH * 40) as f64,
        );

        features.insert(
            "Lines cleared".into(),
            (new_state.line_clears - state.line_clears) as f64 / 4.0,
        );

        // A hole is an empty cell directly below a filled cell.
        let mut holes = 0;
        for r in 1..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
                    holes += 1;
                }
            }
        }
        features.insert("Holes".into(), holes as f64);

        features
    }

    /// Q(s, a) is the dot product of the learned weights and the feature
    /// vector; features without a learned weight contribute nothing.
    fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
        self.get_features(state, action, next_state)
            .iter()
            .map(|(key, val)| val * self.weights.get(key).copied().unwrap_or(0.0))
            .sum()
    }

    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
        *legal_actions
            .iter()
            // The successor state is approximated by the current state here.
            .map(|action| (action, self.get_q_value(state, action, state)))
            .max_by(|(_, q1), (_, q2)| q1.total_cmp(q2))
            .expect("Failed to select an action")
            .0
    }

    fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
        legal_actions
            .iter()
            .map(|action| self.get_q_value(state, action, state))
            .max_by(|a, b| a.total_cmp(b))
            .unwrap_or(0.0)
    }
}
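
// Action selection is epsilon-greedy, mirroring the tabular agent.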
impl Actor for ApproximateQLearning {
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            *legal_actions.choose(rng).unwrap()
        } else {
            self.get_action_from_q_values(&game.into(), legal_actions)
        }
    }

    fn dbg(&self) {
        dbg!(&self.weights);
    }
}

impl QLearningActor for ApproximateQLearning {
    fn update(
        &mut self,
        state: State,
        action: Action,
        next_state: State,
        next_legal_actions: &[Action],
        reward: f64,
    ) {
        // Temporal-difference error:
        // delta = r + gamma * max_a' Q(s', a') - Q(s, a)
        let difference = reward
            + self.discount_rate * self.get_value(&next_state, next_legal_actions)
            - self.get_q_value(&state, &action, &next_state);

        // Gradient step on each weight: w_i <- w_i + alpha * delta * f_i(s, a)
        for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
            *self.weights.entry(feat_key).or_insert(0.0) +=
                self.learning_rate * difference * feat_val;
        }
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}
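
/// Runs `opts.episodes` episodes of self-play, applying the actor's
/// Q-learning update after every action and logging a running average score
/// ten times over the course of training. Exploration and learning can be
/// disabled afterwards for evaluation via the corresponding CLI flags.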
pub fn train_actor<T: 'static + QLearningActor + Actor>(
    mut actor: T,
    opts: &Train,
) -> Box<dyn Actor> {
    let mut rng = SmallRng::from_entropy();
    let mut avg = 0.0;
    let episodes = opts.episodes;
    // Report progress ten times over the run; `max(1)` avoids a division by
    // zero when fewer than ten episodes are requested.
    let report_interval = (episodes / 10).max(1);

    actor.set_learning_rate(opts.learning_rate);
    actor.set_discount_rate(opts.discount_rate);
    actor.set_exploration_prob(opts.exploration_prob);
    info!(
        "Training an actor with learning_rate = {}, discount_rate = {}, exploration_prob = {}",
        opts.learning_rate, opts.discount_rate, opts.exploration_prob
    );

    for i in (0..episodes).progress() {
        if i != 0 && i % report_interval == 0 {
            info!("Last {} scores avg: {}", report_interval, avg);
            println!();
            avg = 0.0;
        }
        let mut game = Game::default();
        while game.is_game_over().is_none() {
            let cur_state = game.clone();
            let cur_score = game.score();
            let action = actor.get_action(&mut rng, &cur_state, &game.get_legal_actions());

            super::apply_action_to_game(action, &mut game);

            let new_state = (&game).into();
            let mut reward = game.score() as f64 - cur_score as f64;
            // Placeholder hook for penalizing non-idle actions; currently a
            // no-op since the penalty is zero.
            if action != Action::Nothing {
                reward -= 0.0;
            }

            // Losing the game overrides any score-based reward.
            if game.is_game_over().is_some() {
                reward = -1.0;
            }

            let new_legal_actions = game.get_legal_actions();

            actor.update(
                cur_state.into(),
                action,
                new_state,
                &new_legal_actions,
                reward,
            );

            game.tick();
        }

        avg += game.score() as f64 / report_interval as f64;
    }

    if opts.no_explore_during_evaluation {
        actor.set_exploration_prob(0.0);
    }

    if opts.no_learn_during_evaluation {
        actor.set_learning_rate(0.0);
    }
    Box::new(actor)
}
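
// Usage sketch (hypothetical; `Train` is normally parsed from the CLI and
// may carry more fields than the ones referenced above):
//
//     let opts: Train = /* parsed from the command line */;
//     let trained = train_actor(QLearningAgent::default(), &opts);
//     trained.dbg();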