98 lines
2.9 KiB
Rust
98 lines
2.9 KiB
Rust
use crate::actors::{Actor, State};
|
|
use crate::game::Action;
|
|
use log::debug;
|
|
use rand::rngs::SmallRng;
|
|
use rand::seq::SliceRandom;
|
|
use rand::Rng;
|
|
use std::collections::HashMap;
|
|
|
|
/// A tabular Q-learning agent: stores one estimated value per
/// (state, action) pair and learns via temporal-difference updates.
pub struct QLearningAgent {
    /// Step size (alpha) applied to each temporal-difference update.
    pub learning_rate: f64,
    /// Probability (epsilon) of picking a random legal action instead
    /// of the greedy one — see `Actor::get_action`.
    pub exploration_prob: f64,
    /// Discount factor (gamma) applied to the value of the next state.
    pub discount_rate: f64,
    // Q-table: for each state seen so far, the estimated value of each
    // action tried in it. Missing entries are treated as 0.0.
    q_values: HashMap<State, HashMap<Action, f64>>,
}
|
|
|
|
impl Default for QLearningAgent {
|
|
fn default() -> Self {
|
|
QLearningAgent {
|
|
learning_rate: 0.1,
|
|
exploration_prob: 0.5,
|
|
discount_rate: 1.0,
|
|
q_values: HashMap::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl QLearningAgent {
|
|
fn get_q_value(&self, state: &State, action: Action) -> f64 {
|
|
match self.q_values.get(&state) {
|
|
Some(action_qval) => *action_qval.get(&action).unwrap_or_else(|| &0.0),
|
|
None => 0.0,
|
|
}
|
|
}
|
|
|
|
fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
|
|
*legal_actions
|
|
.iter()
|
|
.map(|action| (action, self.get_q_value(&state, *action)))
|
|
.max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
|
|
.expect("Failed to select an action")
|
|
.0
|
|
}
|
|
|
|
fn get_value_from_q_values(&self, state: &State) -> f64 {
|
|
*self
|
|
.q_values
|
|
.get(state)
|
|
.and_then(|hashmap| {
|
|
hashmap
|
|
.values()
|
|
.max_by_key(|q_val| (**q_val * 1_000_000.0) as isize)
|
|
.or_else(|| Some(&0.0))
|
|
})
|
|
.unwrap_or_else(|| &0.0)
|
|
}
|
|
}
|
|
|
|
impl Actor for QLearningAgent {
|
|
// Because doing (Nothing) is in the set of legal actions, this will never
|
|
// be empty
|
|
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
|
|
if rng.gen::<f64>() < self.exploration_prob {
|
|
*legal_actions.choose(rng).unwrap()
|
|
} else {
|
|
self.get_action_from_q_values(state, legal_actions)
|
|
}
|
|
}
|
|
|
|
fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) {
|
|
let cur_q_val = self.get_q_value(&state, action);
|
|
let new_q_val = cur_q_val
|
|
+ self.learning_rate
|
|
* (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
|
|
- cur_q_val);
|
|
if !self.q_values.contains_key(&state) {
|
|
self.q_values.insert(state.clone(), HashMap::default());
|
|
}
|
|
self.q_values
|
|
.get_mut(&state)
|
|
.unwrap()
|
|
.insert(action, new_q_val);
|
|
}
|
|
|
|
fn dbg(&self) {
|
|
debug!("Total states: {}", self.q_values.len());
|
|
}
|
|
|
|
fn set_learning_rate(&mut self, learning_rate: f64) {
|
|
self.learning_rate = learning_rate;
|
|
}
|
|
fn set_exploration_prob(&mut self, exploration_prob: f64) {
|
|
self.exploration_prob = exploration_prob;
|
|
}
|
|
fn set_discount_rate(&mut self, discount_rate: f64) {
|
|
self.discount_rate = discount_rate;
|
|
}
|
|
}
|