much refactor

master
Edward Shen 2020-04-05 23:39:19 -04:00
parent 710a7dbebb
commit deb74da552
Signed by: edward
GPG Key ID: 19182661E818369F
7 changed files with 183 additions and 81 deletions

.gitignore

@ -1,2 +1,4 @@
/target
**/*.rs.bk
flamegraph.svg
perf.*


@ -2,12 +2,14 @@
use super::{Actor, State};
use crate::{
game::Action,
cli::Train,
game::{Action, Controllable, Game},
playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
};
use rand::rngs::SmallRng;
use rand::Rng;
use rand::{seq::SliceRandom, Rng, SeedableRng};
#[derive(Copy, Clone, Debug)]
pub struct Parameters {
total_height: f64,
bumpiness: f64,
@ -57,6 +59,7 @@ impl Parameters {
}
}
#[derive(Clone, Copy)]
pub struct GeneticHeuristicAgent {
params: Parameters,
}
@ -115,40 +118,101 @@ impl GeneticHeuristicAgent {
}
}
fn get_heuristic(&self, state: &State, action: &Action) -> f64 {
fn get_heuristic(&self, game: &Game, action: &Action) -> f64 {
todo!();
}
pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self {
let weight = (self_fitness + other_fitness) as f64;
let self_weight = self_fitness as f64 / weight;
let other_weight = other_fitness as f64 / weight;
Self {
params: Parameters {
total_height: self.params.total_height * self_weight
+ other.params.total_height * other_weight,
bumpiness: self.params.bumpiness * self_weight
+ other.params.bumpiness * other_weight,
holes: self.params.holes * self_weight
+ other.params.holes * other_weight,
complete_lines: self.params.complete_lines * self_weight
+ other.params.complete_lines * other_weight,
},
}
}
fn mutate(&mut self, rng: &mut SmallRng) {
self.params.mutate(rng);
}
}
impl Actor for GeneticHeuristicAgent {
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
fn get_action(&self, _: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
*legal_actions
.iter()
.map(|action| (action, self.get_heuristic(state, action)))
.max_by_key(|(action, heuristic)| (heuristic * 1_000_00.0) as usize)
.map(|action| (action, self.get_heuristic(game, action)))
.max_by_key(|(_, heuristic)| (heuristic * 1_000_00.0) as usize)
.unwrap()
.0
}
fn update(
&mut self,
state: State,
action: Action,
next_state: State,
next_legal_actions: &[Action],
reward: f64,
) {
unimplemented!()
}
fn set_learning_rate(&mut self, learning_rate: f64) {
unimplemented!()
}
fn set_exploration_prob(&mut self, exploration_prob: f64) {
unimplemented!()
}
fn set_discount_rate(&mut self, discount_rate: f64) {
unimplemented!()
}
fn dbg(&self) {
unimplemented!()
// unimplemented!()
}
}
pub fn train_actor(opts: &Train) -> Box<dyn Actor> {
let mut rng = SmallRng::from_entropy();
let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000];
for _ in 0..opts.episodes {
let mut new_population: Vec<(GeneticHeuristicAgent, u32)> = population
.iter()
.map(|(agent, _)| {
let mut fitness = 0;
for _ in 0..100 {
let mut game = Game::default();
game.set_piece_limit(500);
while (&game).is_game_over().is_none() {
super::apply_action_to_game(
agent.get_action(
&mut rng,
&game.clone().into(),
&game.get_legal_actions(),
),
&mut game,
)
}
fitness += game.line_clears;
}
(*agent, fitness)
})
.collect::<Vec<_>>();
let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> = Vec::with_capacity(300);
for _ in 0..3 {
let mut random_selection = new_population
.choose_multiple(&mut rng, 100)
.collect::<Vec<_>>();
random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1));
let best_two = random_selection.iter().rev().take(2).collect::<Vec<_>>();
let parent1 = best_two[0];
let parent2 = best_two[1];
for _ in 0..100 {
let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
let mut cloned = breeded.clone();
if rng.gen::<f64>() < 0.05 {
cloned.mutate(&mut rng);
}
breeded_population.push((cloned, 0));
}
}
new_population.sort_unstable_by_key(|e| e.1);
new_population.splice(..breeded_population.len(), breeded_population);
population = new_population;
}
population.sort_unstable_by_key(|e| e.1);
Box::new(population.iter().rev().next().unwrap().0)
}
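Note: get_heuristic is left as todo!() in this commit. The Parameters fields point at the usual four-feature board evaluation; below is a minimal sketch of just the scoring step, assuming the feature values for the playfield produced by `action` have already been measured by code not in this diff (the score_features helper is hypothetical):

// Hypothetical helper in the same module as Parameters; not part of this commit.
// Combines pre-computed board features into one score using the evolved weights.
fn score_features(
    params: &Parameters,
    total_height: f64,
    bumpiness: f64,
    holes: f64,
    complete_lines: f64,
) -> f64 {
    params.total_height * total_height
        + params.bumpiness * bumpiness
        + params.holes * holes
        + params.complete_lines * complete_lines
}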


@ -40,20 +40,7 @@ impl From<PlayField> for State {
}
pub trait Actor {
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
fn update(
&mut self,
state: State,
action: Action,
next_state: State,
next_legal_actions: &[Action],
reward: f64,
);
fn set_learning_rate(&mut self, learning_rate: f64);
fn set_exploration_prob(&mut self, exploration_prob: f64);
fn set_discount_rate(&mut self, discount_rate: f64);
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action;
fn dbg(&self);
}
@ -80,3 +67,16 @@ impl Predictable for Game {
game
}
}
pub fn apply_action_to_game(action: Action, game: &mut Game) {
match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
}
}
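apply_action_to_game centralizes the Action-to-game dispatch that train_actor previously inlined. A minimal sketch of driving one game with it under a random policy, assuming the crate's Game and Action types are in scope (the play_random_game helper is illustrative, not part of this commit):

use rand::{rngs::SmallRng, seq::SliceRandom, SeedableRng};

fn play_random_game() -> u32 {
    let mut rng = SmallRng::from_entropy();
    let mut game = Game::default();
    while game.is_game_over().is_none() {
        // Action::Nothing is always among the legal actions, so the unwrap is safe.
        let action = *game.get_legal_actions().choose(&mut rng).unwrap();
        apply_action_to_game(action, &mut game);
        game.tick();
    }
    game.line_clears
}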


@ -12,6 +12,21 @@ use rand::Rng;
use rand::SeedableRng;
use std::collections::HashMap;
pub trait QLearningActor: Actor {
fn update(
&mut self,
state: State,
action: Action,
next_state: State,
next_legal_actions: &[Action],
reward: f64,
);
fn set_learning_rate(&mut self, learning_rate: f64);
fn set_exploration_prob(&mut self, exploration_prob: f64);
fn set_discount_rate(&mut self, discount_rate: f64);
}
pub struct QLearningAgent {
pub learning_rate: f64,
pub exploration_prob: f64,
@ -64,14 +79,20 @@ impl QLearningAgent {
impl Actor for QLearningAgent {
// Because Action::Nothing is always in the set of legal actions, this will
// never be empty
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
if rng.gen::<f64>() < self.exploration_prob {
*legal_actions.choose(rng).unwrap()
} else {
self.get_action_from_q_values(state, legal_actions)
self.get_action_from_q_values(&game.into(), legal_actions)
}
}
fn dbg(&self) {
debug!("Total states: {}", self.q_values.len());
}
}
impl QLearningActor for QLearningAgent {
fn update(
&mut self,
state: State,
@ -94,10 +115,6 @@ impl Actor for QLearningAgent {
.insert(action, new_q_val);
}
fn dbg(&self) {
debug!("Total states: {}", self.q_values.len());
}
fn set_learning_rate(&mut self, learning_rate: f64) {
self.learning_rate = learning_rate;
}
@ -208,14 +225,20 @@ impl ApproximateQLearning {
}
impl Actor for ApproximateQLearning {
fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
if rng.gen::<f64>() < self.exploration_prob {
*legal_actions.choose(rng).unwrap()
} else {
self.get_action_from_q_values(state, legal_actions)
self.get_action_from_q_values(&game.into(), legal_actions)
}
}
fn dbg(&self) {
dbg!(&self.weights);
}
}
impl QLearningActor for ApproximateQLearning {
fn update(
&mut self,
state: State,
@ -248,15 +271,15 @@ impl Actor for ApproximateQLearning {
fn set_discount_rate(&mut self, discount_rate: f64) {
self.discount_rate = discount_rate;
}
fn dbg(&self) {
dbg!(&self.weights);
}
}
pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) -> Box<dyn Actor> {
pub fn train_actor<T: 'static + QLearningActor + Actor>(
mut actor: T,
opts: &Train,
) -> Box<dyn Actor> {
let mut rng = SmallRng::from_entropy();
let mut avg = 0.0;
let episodes = opts.episodes;
actor.set_learning_rate(opts.learning_rate);
actor.set_discount_rate(opts.discount_rate);
@ -274,20 +297,11 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) ->
}
let mut game = Game::default();
while (&game).is_game_over().is_none() {
let cur_state = (&game).into();
let cur_state = game.clone();
let cur_score = game.score();
let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));
match action {
Action::Nothing => (),
Action::MoveLeft => game.move_left(),
Action::MoveRight => game.move_right(),
Action::SoftDrop => game.move_down(),
Action::HardDrop => game.hard_drop(),
Action::Hold => game.hold(),
Action::RotateLeft => game.rotate_left(),
Action::RotateRight => game.rotate_right(),
}
super::apply_action_to_game(action, &mut game);
let new_state = (&game).into();
let mut reward = game.score() as f64 - cur_score as f64;
@ -301,7 +315,13 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) ->
let new_legal_actions = game.get_legal_actions();
actor.update(cur_state, action, new_state, &new_legal_actions, reward);
actor.update(
cur_state.into(),
action,
new_state,
&new_legal_actions,
reward,
);
game.tick();
}
@ -316,6 +336,5 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) ->
if opts.no_learn_during_evaluation {
actor.set_learning_rate(0.0);
}
actor
Box::new(actor)
}
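With update() and the hyper-parameter setters moved off Actor and onto the new QLearningActor trait, train_actor now takes any concrete learner by value and returns it boxed as a plain Actor. The following sketch shows the minimum a new learner would need to plug in; the NoopLearner name and its stub bodies are illustrative only, and the crate's Actor, QLearningActor, State, Action, and Game types are assumed to be in scope:

use rand::{rngs::SmallRng, seq::SliceRandom};

struct NoopLearner;

impl Actor for NoopLearner {
    fn get_action(&self, rng: &mut SmallRng, _game: &Game, legal_actions: &[Action]) -> Action {
        // Uniform random policy; Action::Nothing is always legal, so the slice
        // is never empty.
        *legal_actions.choose(rng).unwrap()
    }

    fn dbg(&self) {}
}

impl QLearningActor for NoopLearner {
    fn update(
        &mut self,
        _state: State,
        _action: Action,
        _next_state: State,
        _next_legal_actions: &[Action],
        _reward: f64,
    ) {
    }

    fn set_learning_rate(&mut self, _learning_rate: f64) {}
    fn set_exploration_prob(&mut self, _exploration_prob: f64) {}
    fn set_discount_rate(&mut self, _discount_rate: f64) {}
}

// Then, exactly like the built-in agents:
// let trained: Box<dyn Actor> = train_actor(NoopLearner, &sub_opts);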


@ -74,6 +74,7 @@ arg_enum! {
pub enum Agent {
QLearning,
ApproximateQLearning,
HeuristicGenetic
}
}


@ -16,6 +16,7 @@ const LINE_CLEAR_DELAY: u64 = TICKS_PER_SECOND as u64 * 41 / 60;
pub enum LossReason {
TopOut,
LockOut,
PieceLimitReached,
BlockOut(Position),
}
@ -35,6 +36,9 @@ pub struct Game {
/// bonus is needed.
last_clear_action: ClearAction,
pub line_clears: u32,
// used if we set a limit on how long a game can last.
pieces_placed: usize,
piece_limit: usize,
}
impl fmt::Debug for Game {
@ -59,6 +63,8 @@ impl Default for Game {
is_game_over: None,
last_clear_action: ClearAction::Single, // Doesn't matter what it's initialized to
line_clears: 0,
pieces_placed: 0,
piece_limit: 0,
}
}
}
@ -168,10 +174,18 @@ impl Game {
// It's possible that the player moved the piece in the meantime.
if !self.playfield.can_active_piece_move_down() {
let positions = self.playfield.lock_active_piece();
if self.pieces_placed < self.piece_limit {
self.pieces_placed += 1;
if self.pieces_placed >= self.piece_limit {
trace!("Loss due to piece limit!");
self.is_game_over = Some(LossReason::PieceLimitReached);
}
}
self.is_game_over = self.is_game_over.or_else(|| {
if positions.iter().map(|p| p.y).all(|y| y < 20) {
trace!("Loss due to topout! {:?}", positions);
Some(LossReason::TopOut)
trace!("Loss due to lockout! {:?}", positions);
Some(LossReason::LockOut)
} else {
None
}
@ -183,6 +197,7 @@ impl Game {
self.line_clears += cleared_lines as u32;
self.score += (cleared_lines * 100 * self.level as usize) as u32;
self.level = (self.line_clears / 10) as u8;
self.playfield.active_piece = None;
self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
} else {
@ -195,6 +210,10 @@ impl Game {
}
}
pub fn set_piece_limit(&mut self, size: usize) {
self.piece_limit = size;
}
pub fn playfield(&self) -> &PlayField {
&self.playfield
}


@ -30,18 +30,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
init_verbosity(&opts)?;
let agent = match opts.subcmd {
SubCommand::Play(sub_opts) => None,
SubCommand::Play(_) => None,
SubCommand::Train(sub_opts) => Some(match sub_opts.agent {
Agent::QLearning => qlearning::train_actor(
sub_opts.episodes,
Box::new(qlearning::QLearningAgent::default()),
&sub_opts,
),
Agent::ApproximateQLearning => qlearning::train_actor(
sub_opts.episodes,
Box::new(qlearning::ApproximateQLearning::default()),
&sub_opts,
),
Agent::QLearning => {
qlearning::train_actor(qlearning::QLearningAgent::default(), &sub_opts)
}
Agent::ApproximateQLearning => {
qlearning::train_actor(qlearning::ApproximateQLearning::default(), &sub_opts)
}
Agent::HeuristicGenetic => genetic::train_actor(&sub_opts),
}),
};
@ -70,7 +67,7 @@ async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std:
None => (),
}
let cur_state = (&game).into();
let cur_state = game.clone();
// If there's an actor, the player action will get overridden. If not,
// then the player action falls through, if there is one. This is