use crate::actors::{Actor, State};
use crate::game::Action;
use log::debug;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::RngCore;
use std::collections::HashMap;

/// A tabular Q-learning agent. Q-values are stored per (state, action) pair
/// in a nested `HashMap`, so unvisited pairs implicitly default to 0.0.
pub struct QLearningAgent {
    pub learning_rate: f64,
    pub exploration_prob: f64,
    pub discount_rate: f64,
    q_values: HashMap<State, HashMap<Action, f64>>,
}

impl Default for QLearningAgent {
    fn default() -> Self {
        QLearningAgent {
            learning_rate: 0.1,
            exploration_prob: 0.5,
            discount_rate: 1.0,
            q_values: HashMap::default(),
        }
    }
}

impl QLearningAgent {
    /// Q(s, a), defaulting to 0.0 for unseen state/action pairs.
    fn get_q_value(&self, state: &State, action: Action) -> f64 {
        self.q_values
            .get(state)
            .and_then(|action_qvals| action_qvals.get(&action))
            .copied()
            .unwrap_or(0.0)
    }

    /// Argmax of Q(s, a) over the legal actions. `f64` is not `Ord`, so the
    /// Q-values are scaled and cast to `isize` to get a total order for
    /// `max_by_key`.
    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
        *legal_actions
            .iter()
            .map(|action| (action, self.get_q_value(state, *action)))
            .max_by_key(|(_, q_val)| (q_val * 1_000_000.0) as isize)
            .expect("Failed to select an action")
            .0
    }

    /// V(s) = max over a of Q(s, a), or 0.0 for an unseen state.
    fn get_value_from_q_values(&self, state: &State) -> f64 {
        self.q_values
            .get(state)
            .and_then(|action_qvals| {
                action_qvals
                    .values()
                    .max_by_key(|q_val| (**q_val * 1_000_000.0) as isize)
            })
            .copied()
            .unwrap_or(0.0)
    }
}

impl Actor for QLearningAgent {
    // Epsilon-greedy action selection. Because doing (Nothing) is in the set
    // of legal actions, `legal_actions` will never be empty.
    fn get_action<T: RngCore>(
        &self,
        rng: &mut T,
        state: &State,
        legal_actions: &[Action],
    ) -> Action {
        if rng.gen::<f64>() < self.exploration_prob {
            // Explore: pick a uniformly random legal action.
            *legal_actions.choose(rng).unwrap()
        } else {
            // Exploit: pick the action with the highest Q-value.
            self.get_action_from_q_values(state, legal_actions)
        }
    }

    // Standard Q-learning update:
    // Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
    fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) {
        let cur_q_val = self.get_q_value(&state, action);
        let new_q_val = cur_q_val
            + self.learning_rate
                * (reward + self.discount_rate * self.get_value_from_q_values(&next_state)
                    - cur_q_val);
        // `entry` inserts an empty action table for a new state, avoiding the
        // separate contains_key / insert / get_mut dance (and the clone).
        self.q_values
            .entry(state)
            .or_default()
            .insert(action, new_q_val);
    }

    fn dbg(&self) {
        debug!("Total states: {}", self.q_values.len());
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }

    fn set_exploration_prob(&mut self, exploration_prob: f64) {
        self.exploration_prob = exploration_prob;
    }

    fn set_discount_rate(&mut self, discount_rate: f64) {
        self.discount_rate = discount_rate;
    }
}
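
// ---------------------------------------------------------------------------
// Usage sketch (not part of this module): roughly how an agent like this is
// driven by an episodic game loop. The `game` object and its `reset` /
// `legal_actions` / `step` methods are hypothetical stand-ins for whatever
// the surrounding crate provides; only `QLearningAgent` and the `Actor`
// methods above are real.
//
//     let mut rng = rand::thread_rng();
//     let mut agent = QLearningAgent::default();
//     for _episode in 0..10_000 {
//         let mut state = game.reset();                       // hypothetical
//         loop {
//             let legal_actions = game.legal_actions(&state); // hypothetical
//             let action = agent.get_action(&mut rng, &state, &legal_actions);
//             let (next_state, reward, done) = game.step(action); // hypothetical
//             agent.update(state, action, next_state.clone(), reward);
//             if done {
//                 break;
//             }
//             state = next_state;
//         }
//     }
//
// Decaying `exploration_prob` toward 0.0 across episodes (via
// `set_exploration_prob`) is the usual way to shift from exploration to
// exploitation as the Q-table fills in.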