make qlearning train_agent specific

This commit is contained in:
Edward Shen 2020-04-05 20:18:48 -04:00
parent a65f48f585
commit 710a7dbebb
Signed by: edward
GPG key ID: 19182661E818369F
4 changed files with 130 additions and 51 deletions

View file

@ -1,6 +1,10 @@
// https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/ // https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
use super::Actor; use super::{Actor, State};
use crate::{
game::Action,
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use rand::rngs::SmallRng; use rand::rngs::SmallRng;
use rand::Rng; use rand::Rng;
@ -11,6 +15,17 @@ pub struct Parameters {
complete_lines: f64, complete_lines: f64,
} }
impl Default for Parameters {
fn default() -> Self {
Self {
total_height: 1.0,
bumpiness: 1.0,
holes: 1.0,
complete_lines: 1.0,
}
}
}
impl Parameters { impl Parameters {
fn mutate(mut self, rng: &mut SmallRng) { fn mutate(mut self, rng: &mut SmallRng) {
let mutation_amt = rng.gen_range(-0.2, 0.2); let mutation_amt = rng.gen_range(-0.2, 0.2);
@ -33,25 +48,93 @@ impl Parameters {
self.holes /= normalization_factor; self.holes /= normalization_factor;
self.complete_lines /= normalization_factor; self.complete_lines /= normalization_factor;
} }
fn dot_multiply(&self, other: &Self) -> f64 {
self.total_height * other.total_height
+ self.bumpiness * other.bumpiness
+ self.holes * other.holes
+ self.complete_lines * other.complete_lines
}
} }
pub struct GeneticHeuristicAgent {} pub struct GeneticHeuristicAgent {
params: Parameters,
}
impl Default for GeneticHeuristicAgent {
fn default() -> Self {
Self {
params: Parameters::default(),
}
}
}
impl GeneticHeuristicAgent {
fn extract_features_from_state(state: &State) -> Parameters {
let mut heights = [None; PLAYFIELD_WIDTH];
for r in 0..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if heights[c].is_none() && state.matrix[r][c].is_some() {
heights[c] = Some(PLAYFIELD_HEIGHT - r);
}
}
}
let total_height = heights
.iter()
.map(|o| o.unwrap_or_else(|| 0))
.sum::<usize>() as f64;
let bumpiness = heights
.iter()
.map(|o| o.unwrap_or_else(|| 0) as isize)
.fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
.0 as f64;
let complete_lines = state
.matrix
.iter()
.map(|row| row.iter().all(Option::is_some))
.map(|c| if c { 1.0 } else { 0.0 })
.sum::<f64>();
let mut holes = 0;
for r in 1..PLAYFIELD_HEIGHT {
for c in 0..PLAYFIELD_WIDTH {
if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
holes += 1;
}
}
}
Parameters {
total_height,
bumpiness,
complete_lines,
holes: holes as f64,
}
}
fn get_heuristic(&self, state: &State, action: &Action) -> f64 {
todo!();
}
}
impl Actor for GeneticHeuristicAgent { impl Actor for GeneticHeuristicAgent {
fn get_action( fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
&self, *legal_actions
rng: &mut SmallRng, .iter()
state: &super::State, .map(|action| (action, self.get_heuristic(state, action)))
legal_actions: &[crate::game::Action], .max_by_key(|(action, heuristic)| (heuristic * 1_000_00.0) as usize)
) -> crate::game::Action { .unwrap()
unimplemented!() .0
} }
fn update( fn update(
&mut self, &mut self,
state: super::State, state: State,
action: crate::game::Action, action: Action,
next_state: super::State, next_state: State,
next_legal_actions: &[crate::game::Action], next_legal_actions: &[Action],
reward: f64, reward: f64,
) { ) {
unimplemented!() unimplemented!()

View file

@ -1,5 +1,6 @@
use crate::actors::{Actor, State}; use crate::actors::{Actor, State};
use crate::{ use crate::{
cli::Train,
game::{Action, Controllable, Game, Tickable}, game::{Action, Controllable, Game, Tickable},
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH}, playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
}; };
@ -253,10 +254,18 @@ impl Actor for ApproximateQLearning {
} }
} }
pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> { pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) -> Box<dyn Actor> {
let mut rng = SmallRng::from_entropy(); let mut rng = SmallRng::from_entropy();
let mut avg = 0.0; let mut avg = 0.0;
actor.set_learning_rate(opts.learning_rate);
actor.set_discount_rate(opts.discount_rate);
actor.set_exploration_prob(opts.exploration_prob);
info!(
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
opts.learning_rate, opts.discount_rate, opts.exploration_prob
);
for i in (0..episodes).progress() { for i in (0..episodes).progress() {
if i != 0 && i % (episodes / 10) == 0 { if i != 0 && i % (episodes / 10) == 0 {
info!("Last {} scores avg: {}", (episodes / 10), avg); info!("Last {} scores avg: {}", (episodes / 10), avg);
@ -300,5 +309,13 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor>
avg += game.score() as f64 / (episodes / 10) as f64; avg += game.score() as f64 / (episodes / 10) as f64;
} }
if opts.no_explore_during_evaluation {
actor.set_exploration_prob(0.0);
}
if opts.no_learn_during_evaluation {
actor.set_learning_rate(0.0);
}
actor actor
} }

View file

@ -89,10 +89,3 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
match agent {
Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
Agent::ApproximateQLearning => Box::new(qlearning::ApproximateQLearning::default()),
}
}

View file

@ -6,7 +6,6 @@ use graphics::standard_renderer;
use graphics::COLOR_BACKGROUND; use graphics::COLOR_BACKGROUND;
use indicatif::ProgressIterator; use indicatif::ProgressIterator;
use log::{debug, info, trace}; use log::{debug, info, trace};
use qlearning::train_actor;
use rand::SeedableRng; use rand::SeedableRng;
use sdl2::event::Event; use sdl2::event::Event;
use sdl2::keyboard::Keycode; use sdl2::keyboard::Keycode;
@ -29,37 +28,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let opts = crate::cli::Opts::parse(); let opts = crate::cli::Opts::parse();
init_verbosity(&opts)?; init_verbosity(&opts)?;
let mut actor = None;
match opts.subcmd { let agent = match opts.subcmd {
SubCommand::Play(sub_opts) => {} SubCommand::Play(sub_opts) => None,
SubCommand::Train(sub_opts) => { SubCommand::Train(sub_opts) => Some(match sub_opts.agent {
let mut to_train = get_actor(sub_opts.agent); Agent::QLearning => qlearning::train_actor(
to_train.set_learning_rate(sub_opts.learning_rate); sub_opts.episodes,
to_train.set_discount_rate(sub_opts.discount_rate); Box::new(qlearning::QLearningAgent::default()),
to_train.set_exploration_prob(sub_opts.exploration_prob); &sub_opts,
),
Agent::ApproximateQLearning => qlearning::train_actor(
sub_opts.episodes,
Box::new(qlearning::ApproximateQLearning::default()),
&sub_opts,
),
}),
};
info!( play_game(agent).await
"Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
sub_opts.learning_rate,
sub_opts.discount_rate,
sub_opts.exploration_prob
);
let mut trained_actor = train_actor(sub_opts.episodes, to_train);
if sub_opts.no_explore_during_evaluation {
trained_actor.set_exploration_prob(0.0);
}
if sub_opts.no_learn_during_evaluation {
trained_actor.set_learning_rate(0.0);
}
actor = Some(trained_actor);
}
}
play_game(actor).await?;
Ok(())
} }
async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> { async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {