make qlearning train_actor specific
parent a65f48f585
commit 710a7dbebb

4 changed files with 130 additions and 51 deletions
@@ -1,6 +1,10 @@
 // https://codemyroad.wordpress.com/2013/04/14/tetris-ai-the-near-perfect-player/
 
-use super::Actor;
+use super::{Actor, State};
+use crate::{
+    game::Action,
+    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
+};
 use rand::rngs::SmallRng;
 use rand::Rng;
 
@@ -11,6 +15,17 @@ pub struct Parameters {
     complete_lines: f64,
 }
 
+impl Default for Parameters {
+    fn default() -> Self {
+        Self {
+            total_height: 1.0,
+            bumpiness: 1.0,
+            holes: 1.0,
+            complete_lines: 1.0,
+        }
+    }
+}
+
 impl Parameters {
     fn mutate(mut self, rng: &mut SmallRng) {
         let mutation_amt = rng.gen_range(-0.2, 0.2);
@@ -33,25 +48,93 @@ impl Parameters {
         self.holes /= normalization_factor;
         self.complete_lines /= normalization_factor;
     }
+
+    fn dot_multiply(&self, other: &Self) -> f64 {
+        self.total_height * other.total_height
+            + self.bumpiness * other.bumpiness
+            + self.holes * other.holes
+            + self.complete_lines * other.complete_lines
+    }
 }
 
-pub struct GeneticHeuristicAgent {}
+pub struct GeneticHeuristicAgent {
+    params: Parameters,
+}
+
+impl Default for GeneticHeuristicAgent {
+    fn default() -> Self {
+        Self {
+            params: Parameters::default(),
+        }
+    }
+}
+
+impl GeneticHeuristicAgent {
+    fn extract_features_from_state(state: &State) -> Parameters {
+        let mut heights = [None; PLAYFIELD_WIDTH];
+        for r in 0..PLAYFIELD_HEIGHT {
+            for c in 0..PLAYFIELD_WIDTH {
+                if heights[c].is_none() && state.matrix[r][c].is_some() {
+                    heights[c] = Some(PLAYFIELD_HEIGHT - r);
+                }
+            }
+        }
+
+        let total_height = heights
+            .iter()
+            .map(|o| o.unwrap_or_else(|| 0))
+            .sum::<usize>() as f64;
+
+        let bumpiness = heights
+            .iter()
+            .map(|o| o.unwrap_or_else(|| 0) as isize)
+            .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
+            .0 as f64;
+
+        let complete_lines = state
+            .matrix
+            .iter()
+            .map(|row| row.iter().all(Option::is_some))
+            .map(|c| if c { 1.0 } else { 0.0 })
+            .sum::<f64>();
+
+        let mut holes = 0;
+        for r in 1..PLAYFIELD_HEIGHT {
+            for c in 0..PLAYFIELD_WIDTH {
+                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
+                    holes += 1;
+                }
+            }
+        }
+
+        Parameters {
+            total_height,
+            bumpiness,
+            complete_lines,
+            holes: holes as f64,
+        }
+    }
+
+    fn get_heuristic(&self, state: &State, action: &Action) -> f64 {
+        todo!();
+    }
+}
 
 impl Actor for GeneticHeuristicAgent {
-    fn get_action(
-        &self,
-        rng: &mut SmallRng,
-        state: &super::State,
-        legal_actions: &[crate::game::Action],
-    ) -> crate::game::Action {
-        unimplemented!()
+    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
+        *legal_actions
+            .iter()
+            .map(|action| (action, self.get_heuristic(state, action)))
+            .max_by_key(|(action, heuristic)| (heuristic * 1_000_00.0) as usize)
+            .unwrap()
+            .0
     }
     fn update(
         &mut self,
-        state: super::State,
-        action: crate::game::Action,
-        next_state: super::State,
-        next_legal_actions: &[crate::game::Action],
+        state: State,
+        action: Action,
+        next_state: State,
+        next_legal_actions: &[Action],
         reward: f64,
     ) {
         unimplemented!()
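Note: `get_heuristic` is left as `todo!()` in this commit. A minimal sketch of how it could later be completed, assuming a hypothetical helper that returns the board state after applying the action (neither the helper nor this body is part of the diff):

    // Sketch only: assumes a hypothetical State::with_action(&Action) -> State helper.
    fn get_heuristic(&self, state: &State, action: &Action) -> f64 {
        let next_state = state.with_action(action); // hypothetical, not in this repo
        Self::extract_features_from_state(&next_state).dot_multiply(&self.params)
    }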
@ -1,5 +1,6 @@
|
||||||
use crate::actors::{Actor, State};
|
use crate::actors::{Actor, State};
|
||||||
use crate::{
|
use crate::{
|
||||||
|
cli::Train,
|
||||||
game::{Action, Controllable, Game, Tickable},
|
game::{Action, Controllable, Game, Tickable},
|
||||||
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
|
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
|
||||||
};
|
};
|
||||||
|
@@ -253,10 +254,18 @@ impl Actor for ApproximateQLearning {
     }
 }
 
-pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
+pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) -> Box<dyn Actor> {
     let mut rng = SmallRng::from_entropy();
     let mut avg = 0.0;
 
+    actor.set_learning_rate(opts.learning_rate);
+    actor.set_discount_rate(opts.discount_rate);
+    actor.set_exploration_prob(opts.exploration_prob);
+    info!(
+        "Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
+        opts.learning_rate, opts.discount_rate, opts.exploration_prob
+    );
+
     for i in (0..episodes).progress() {
         if i != 0 && i % (episodes / 10) == 0 {
             info!("Last {} scores avg: {}", (episodes / 10), avg);
@@ -300,5 +309,13 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor>
         avg += game.score() as f64 / (episodes / 10) as f64;
     }
 
+    if opts.no_explore_during_evaluation {
+        actor.set_exploration_prob(0.0);
+    }
+
+    if opts.no_learn_during_evaluation {
+        actor.set_learning_rate(0.0);
+    }
+
     actor
 }
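The `cli::Train` struct itself is not part of this diff. Judging from the fields read here and in main.rs, it presumably carries at least the following; field types are inferred and the actual CLI derive attributes are omitted:

    // Assumed shape of cli::Train, inferred from the opts.* and sub_opts.* usages in this commit.
    pub struct Train {
        pub agent: Agent,                       // which actor implementation to train
        pub episodes: usize,                    // passed straight to train_actor
        pub learning_rate: f64,
        pub discount_rate: f64,
        pub exploration_prob: f64,
        pub no_explore_during_evaluation: bool, // zero out exploration after training
        pub no_learn_during_evaluation: bool,   // zero out learning rate after training
    }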
@@ -89,10 +89,3 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
 
     Ok(())
 }
-
-pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
-    match agent {
-        Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
-        Agent::ApproximateQLearning => Box::new(qlearning::ApproximateQLearning::default()),
-    }
-}
src/main.rs (46 changes)
@@ -6,7 +6,6 @@ use graphics::standard_renderer;
 use graphics::COLOR_BACKGROUND;
 use indicatif::ProgressIterator;
 use log::{debug, info, trace};
-use qlearning::train_actor;
 use rand::SeedableRng;
 use sdl2::event::Event;
 use sdl2::keyboard::Keycode;
@@ -29,37 +28,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let opts = crate::cli::Opts::parse();
 
     init_verbosity(&opts)?;
-    let mut actor = None;
-
-    match opts.subcmd {
-        SubCommand::Play(sub_opts) => {}
-        SubCommand::Train(sub_opts) => {
-            let mut to_train = get_actor(sub_opts.agent);
-            to_train.set_learning_rate(sub_opts.learning_rate);
-            to_train.set_discount_rate(sub_opts.discount_rate);
-            to_train.set_exploration_prob(sub_opts.exploration_prob);
-            info!(
-                "Training an actor with learning_rate = {}, discount_rate = {}, exploration_rate = {}",
-                sub_opts.learning_rate,
-                sub_opts.discount_rate,
-                sub_opts.exploration_prob
-            );
-            let mut trained_actor = train_actor(sub_opts.episodes, to_train);
-            if sub_opts.no_explore_during_evaluation {
-                trained_actor.set_exploration_prob(0.0);
-            }
-
-            if sub_opts.no_learn_during_evaluation {
-                trained_actor.set_learning_rate(0.0);
-            }
-
-            actor = Some(trained_actor);
-        }
-    }
-
-    play_game(actor).await?;
-    Ok(())
+    let agent = match opts.subcmd {
+        SubCommand::Play(sub_opts) => None,
+        SubCommand::Train(sub_opts) => Some(match sub_opts.agent {
+            Agent::QLearning => qlearning::train_actor(
+                sub_opts.episodes,
+                Box::new(qlearning::QLearningAgent::default()),
+                &sub_opts,
+            ),
+            Agent::ApproximateQLearning => qlearning::train_actor(
+                sub_opts.episodes,
+                Box::new(qlearning::ApproximateQLearning::default()),
+                &sub_opts,
+            ),
+        }),
+    };
+
+    play_game(agent).await
 }
 
 async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
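For context, `play_game` now receives the trained actor as an `Option<Box<dyn Actor>>`. A hypothetical sketch of how the game loop might consult it each tick; the `Game` methods named here are placeholders, not the crate's actual API:

    // Hypothetical sketch, not part of this commit. current_state, legal_actions,
    // and perform are placeholder names for whatever Game/Controllable exposes.
    if let Some(agent) = actor.as_ref() {
        let state = game.current_state();
        let legal = game.legal_actions();
        let action = agent.get_action(&mut rng, &state, &legal);
        game.perform(action);
    }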