tetris/src/actors/qlearning.rs

use super::Predictable;
use crate::actors::{Actor, State};
use crate::{
    cli::Train,
    game::{Action, Controllable, Game, Tickable},
    playfield::{Matrix, PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use log::{debug, info, trace};
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use std::collections::HashMap;
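
/// An [`Actor`] that can additionally be trained with Q-learning style
/// updates; hyperparameters are injected by `train_actor` before training.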
pub trait QLearningActor: Actor {
    fn update(
        &mut self,
        game_state: Game,
        action: Action,
        next_game_state: Game,
        next_legal_actions: &[Action],
        reward: f64,
    );

    fn set_learning_rate(&mut self, learning_rate: f64);
    fn set_exploration_prob(&mut self, exploration_prob: f64);
    fn set_discount_rate(&mut self, discount_rate: f64);
}
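
/// Tabular Q-learning agent: stores one Q-value per (state, action) pair it
/// has visited, keyed by the [`State`] representation derived from a [`Game`].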
#[derive(Debug)]
pub struct QLearningAgent {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    q_values: HashMap<State, HashMap<Action, f64>>,
}

impl Default for QLearningAgent {
    fn default() -> Self {
        Self {
            learning_rate: 0.0,
            exploration_prob: 0.0,
            discount_rate: 0.0,
            q_values: HashMap::default(),
        }
    }
}

impl QLearningAgent {
    /// Q(s, a) for a visited pair, or 0.0 if the pair has never been updated.
    fn get_q_value(&self, state: &State, action: Action) -> f64 {
        match self.q_values.get(state) {
            Some(action_qval) => action_qval.get(&action).copied().unwrap_or(0.0),
            None => 0.0,
        }
    }

    /// Greedy action: the legal action with the highest Q-value, breaking ties
    /// uniformly at random.
    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
        let legal_actions = legal_actions
            .iter()
            .map(|action| (action, self.get_q_value(state, *action)))
            .collect::<Vec<_>>();
        // f64 is not Ord, so scale and truncate to an integer key for max_by_key.
        let max_val = legal_actions
            .iter()
            .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
            .expect("Failed to select an action")
            .1;
        let actions_to_choose = legal_actions
            .iter()
            .filter(|(_, v)| max_val == *v)
            .collect::<Vec<_>>();
        if actions_to_choose.len() != 1 {
            trace!(
                "more than one best option, choosing randomly: {:?}",
                actions_to_choose
            );
        }
        *actions_to_choose
            .choose(&mut SmallRng::from_entropy())
            .unwrap()
            .0
    }

    /// V(s) = max_a Q(s, a) over the actions recorded for this state, or 0.0
    /// for an unseen state.
    fn get_value_from_q_values(&self, state: &State) -> f64 {
        self.q_values
            .get(state)
            .and_then(|action_qvals| {
                action_qvals
                    .values()
                    .max_by_key(|q_val| (**q_val * 1_000_000.0) as isize)
            })
            .copied()
            .unwrap_or(0.0)
    }
}
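
// Epsilon-greedy policy: with probability `exploration_prob` pick a random
// legal action, otherwise exploit the learned Q-values.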
impl Actor for QLearningAgent {
    // Because `Action::Nothing` is always in the set of legal actions,
    // `legal_actions` is never empty and `choose` cannot fail.
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            *legal_actions.choose(rng).unwrap()
        } else {
            self.get_action_from_q_values(&game.into(), legal_actions)
        }
    }

    fn dbg(&self) {
        debug!("Total states: {}", self.q_values.len());
    }
}
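
// Standard tabular Q-learning update:
//   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
// where max_a' ranges over the actions recorded for s' (0.0 for unseen states).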
impl QLearningActor for QLearningAgent {
    fn update(
        &mut self,
        game_state: Game,
        action: Action,
        next_game_state: Game,
        _next_legal_actions: &[Action],
        reward: f64,
    ) {
        let state = (&game_state).into();
        let next_state = (&next_game_state).into();
        let cur_q_val = self.get_q_value(&state, action);
        let new_q_val = cur_q_val
            + self.learning_rate
                * (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
                    - cur_q_val);
        self.q_values
            .entry(state)
            .or_default()
            .insert(action, new_q_val);
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}
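
/// Approximate Q-learning agent with a linear value function,
///   Q(s, a) = sum_i w_i * f_i(s, a),
/// where the features f_i are computed from the playfield after applying the
/// action, so the agent generalises across states instead of tabulating them.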
#[derive(Debug)]
pub struct ApproximateQLearning {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    weights: HashMap<Feature, f64>,
}

impl Default for ApproximateQLearning {
    fn default() -> Self {
        let mut weights = HashMap::default();
        weights.insert(Feature::TotalHeight, 1.0);
        weights.insert(Feature::Bumpiness, 1.0);
        weights.insert(Feature::LinesCleared, 1.0);
        weights.insert(Feature::Holes, 1.0);
        Self {
            learning_rate: 0.0,
            exploration_prob: 0.0,
            discount_rate: 0.0,
            weights,
        }
    }
}
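
/// The hand-crafted playfield features used by [`ApproximateQLearning`];
/// each variant maps to one learned weight.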
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
enum Feature {
    TotalHeight,
    Bumpiness,
    LinesCleared,
    Holes,
}

impl ApproximateQLearning {
    /// Evaluate the features of the state that results from applying `action`
    /// to `game`.
    fn get_features(&self, game: &Game, action: &Action) -> HashMap<Feature, f64> {
        let game = game.get_next_state(*action);
        let mut features = HashMap::default();
        let field = game.playfield().field();
        let heights = self.get_heights(field);

        // Sum of column heights, normalised by the playfield area.
        features.insert(
            Feature::TotalHeight,
            heights.iter().sum::<usize>() as f64 / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
        );

        // Sum of absolute differences between adjacent column heights (the
        // first column is compared against an implicit height of 0), normalised.
        features.insert(
            Feature::Bumpiness,
            heights
                .iter()
                .fold((0, 0), |(acc, prev), cur| {
                    (acc + (prev as isize - *cur as isize).abs(), *cur)
                })
                .0 as f64
                / (PLAYFIELD_WIDTH * 40) as f64,
        );

        // Number of completely filled rows, scaled by 1/4 (at most four lines
        // can be cleared at once).
        features.insert(
            Feature::LinesCleared,
            game.playfield()
                .field()
                .iter()
                .map(|r| r.iter().all(Option::is_some))
                .map(|r| if r { 1 } else { 0 })
                .sum::<usize>() as f64
                / 4.0,
        );

        // A hole is an empty cell with a filled cell directly above it.
        let mut holes = 0;
        for r in 1..PLAYFIELD_HEIGHT {
            for c in 0..PLAYFIELD_WIDTH {
                if field[r][c].is_none() && field[r - 1][c].is_some() {
                    holes += 1;
                }
            }
        }
        features.insert(Feature::Holes, holes as f64);

        features
    }

    /// Height of each column, measured from the bottom of the matrix;
    /// 0 means the column is empty.
    fn get_heights(&self, matrix: &Matrix) -> Vec<usize> {
        let mut heights = vec![0; matrix[0].len()];
        for r in 0..matrix.len() {
            for c in 0..matrix[0].len() {
                if heights[c] == 0 && matrix[r][c].is_some() {
                    heights[c] = matrix.len() - r;
                }
            }
        }
        heights
    }
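
    /// Q(s, a) as the dot product of the feature values and their learned weights.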
    fn get_q_value(&self, game: &Game, action: &Action) -> f64 {
        self.get_features(game, action)
            .iter()
            .map(|(key, val)| val * *self.weights.get(key).unwrap())
            .sum()
    }

    /// Greedy action with respect to the current weights, breaking ties
    /// uniformly at random.
    fn get_action_from_q_values(&self, game: &Game) -> Action {
        let legal_actions = game.get_legal_actions();
        let legal_actions = legal_actions
            .iter()
            .map(|action| (action, self.get_q_value(game, action)))
            .collect::<Vec<_>>();
        let max_val = legal_actions
            .iter()
            .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
            .expect("Failed to select an action")
            .1;
        let actions_to_choose = legal_actions
            .iter()
            .filter(|(_, v)| max_val == *v)
            .collect::<Vec<_>>();
        if actions_to_choose.len() != 1 {
            trace!(
                "more than one best option, choosing randomly: {:?}",
                actions_to_choose
            );
        }
        match actions_to_choose.choose(&mut SmallRng::from_entropy()) {
            Some(a) => *a.0,
            None => {
                dbg!(&legal_actions);
                dbg!(&actions_to_choose);
                panic!("no legal actions to choose from");
            }
        }
    }

    /// V(s) = max over the legal actions of Q(s, a).
    fn get_value(&self, game: &Game) -> f64 {
        game.get_legal_actions()
            .iter()
            .map(|action| self.get_q_value(game, action))
            .max_by_key(|v| (v * 1_000_000.0) as isize)
            .unwrap_or(0.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tetromino::TetrominoType;

    #[test]
    fn test_height() {
        let agent = ApproximateQLearning::default();
        let matrix = vec![
            vec![None, None, None],
            vec![None, Some(TetrominoType::T), None],
            vec![None, None, None],
        ];
        assert_eq!(agent.get_heights(&matrix), vec![0, 2, 0]);
    }
}
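
// Epsilon-greedy policy, as for the tabular agent, but exploitation uses the
// feature-based Q-values.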
impl Actor for ApproximateQLearning {
    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            *legal_actions.choose(rng).unwrap()
        } else {
            self.get_action_from_q_values(game)
        }
    }

    fn dbg(&self) {
        dbg!(&self.weights);
    }
}
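
// Approximate Q-learning weight update: with TD error
//   delta = r + gamma * V(s') - Q(s, a)
// each weight moves along its feature value, w_i <- w_i + alpha * delta * f_i(s, a).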
impl QLearningActor for ApproximateQLearning {
    fn update(
        &mut self,
        game_state: Game,
        action: Action,
        next_game_state: Game,
        _: &[Action],
        reward: f64,
    ) {
        let difference = reward + self.discount_rate * self.get_value(&next_game_state)
            - self.get_q_value(&game_state, &action);
        for (feat_key, feat_val) in self.get_features(&game_state, &action) {
            self.weights.insert(
                feat_key,
                *self.weights.get(&feat_key).unwrap() + self.learning_rate * difference * feat_val,
            );
        }
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}
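
/// Train `actor` for `opts.episodes` episodes of self-play, reporting the
/// average score every tenth of the run, then box it for evaluation.
/// Exploration and learning can optionally be disabled once training ends.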
pub fn train_actor<T: std::fmt::Debug + 'static + QLearningActor + Actor>(
    mut actor: T,
    opts: &Train,
) -> Box<dyn Actor> {
    let mut rng = SmallRng::from_entropy();
    let mut avg = 0.0;
    let episodes = opts.episodes;
    // Avoid dividing by zero when fewer than ten episodes are requested.
    let report_interval = (episodes / 10).max(1);

    actor.set_learning_rate(opts.learning_rate);
    actor.set_discount_rate(opts.discount_rate);
    actor.set_exploration_prob(opts.exploration_prob);

    info!(
        "Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
        opts.learning_rate, opts.discount_rate, opts.exploration_prob
    );

    for i in 0..episodes {
        if i != 0 && i % report_interval == 0 {
            println!("{}", avg);
            eprintln!("iteration {}", i);
            avg = 0.0;
        }

        let mut game = Game::default();
        while game.is_game_over().is_none() {
            let cur_state = game.clone();
            let cur_score = game.score();
            let action = actor.get_action(&mut rng, &cur_state, &game.get_legal_actions());
            super::apply_action_to_game(action, &mut game);
            let new_state = game.clone();
            let mut reward = game.score() as f64 - cur_score as f64;
            // if action != Action::Nothing {
            //     reward -= 0.0;
            // }
            if game.is_game_over().is_some() {
                reward = -1.0;
            }
            let new_legal_actions = game.get_legal_actions();
            actor.update(cur_state, action, new_state, &new_legal_actions, reward);
            game.tick();
        }
        avg += game.score() as f64 / report_interval as f64;
    }

    if opts.no_explore_during_evaluation {
        actor.set_exploration_prob(0.0);
    }
    if opts.no_learn_during_evaluation {
        actor.set_learning_rate(0.0);
    }

    Box::new(actor)
}