much refactor
parent 710a7dbebb
commit deb74da552
7 changed files with 183 additions and 81 deletions

.gitignore (vendored, 2 changes)

@@ -1,2 +1,4 @@
 /target
 **/*.rs.bk
+flamegraph.svg
+perf.*

@@ -2,12 +2,14 @@
 use super::{Actor, State};
 use crate::{
-    game::Action,
+    cli::Train,
+    game::{Action, Controllable, Game},
     playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
 };
 use rand::rngs::SmallRng;
-use rand::Rng;
+use rand::{seq::SliceRandom, Rng, SeedableRng};

 #[derive(Copy, Clone, Debug)]
 pub struct Parameters {
     total_height: f64,
     bumpiness: f64,

@@ -57,6 +59,7 @@ impl Parameters {
     }
 }

+#[derive(Clone, Copy)]
 pub struct GeneticHeuristicAgent {
     params: Parameters,
 }

@@ -115,40 +118,101 @@ impl GeneticHeuristicAgent {
         }
     }

-    fn get_heuristic(&self, state: &State, action: &Action) -> f64 {
+    fn get_heuristic(&self, game: &Game, action: &Action) -> f64 {
         todo!();
     }

+    pub fn breed(&self, self_fitness: u32, other: &Self, other_fitness: u32) -> Self {
+        let weight = (self_fitness + other_fitness) as f64;
+        let self_weight = self_fitness as f64 / weight;
+        let other_weight = other_fitness as f64 / weight;
+        Self {
+            params: Parameters {
+                total_height: self.params.total_height * self_weight
+                    + other.params.total_height * other_weight,
+                bumpiness: self.params.total_height * self_weight
+                    + other.params.total_height * other_weight,
+                holes: self.params.total_height * self_weight
+                    + other.params.total_height * other_weight,
+                complete_lines: self.params.total_height * self_weight
+                    + other.params.total_height * other_weight,
+            },
+        }
+    }
+
+    fn mutate(&mut self, rng: &mut SmallRng) {
+        self.params.mutate(rng);
+    }
 }

 impl Actor for GeneticHeuristicAgent {
-    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
+    fn get_action(&self, _: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
         *legal_actions
             .iter()
-            .map(|action| (action, self.get_heuristic(state, action)))
-            .max_by_key(|(action, heuristic)| (heuristic * 1_000_00.0) as usize)
+            .map(|action| (action, self.get_heuristic(game, action)))
+            .max_by_key(|(_, heuristic)| (heuristic * 1_000_00.0) as usize)
             .unwrap()
             .0
     }
-    fn update(
-        &mut self,
-        state: State,
-        action: Action,
-        next_state: State,
-        next_legal_actions: &[Action],
-        reward: f64,
-    ) {
-        unimplemented!()
-    }
-    fn set_learning_rate(&mut self, learning_rate: f64) {
-        unimplemented!()
-    }
-    fn set_exploration_prob(&mut self, exploration_prob: f64) {
-        unimplemented!()
-    }
-    fn set_discount_rate(&mut self, discount_rate: f64) {
-        unimplemented!()
-    }

     fn dbg(&self) {
-        unimplemented!()
+        // unimplemented!()
     }
 }
+
+pub fn train_actor(opts: &Train) -> Box<dyn Actor> {
+    let mut rng = SmallRng::from_entropy();
+    let mut population = vec![(GeneticHeuristicAgent::default(), 0); 1000];
+
+    for _ in 0..opts.episodes {
+        let mut new_population: Vec<(GeneticHeuristicAgent, u32)> = population
+            .iter()
+            .map(|(agent, _)| {
+                let mut fitness = 0;
+                for _ in 0..100 {
+                    let mut game = Game::default();
+                    game.set_piece_limit(500);
+                    while (&game).is_game_over().is_none() {
+                        super::apply_action_to_game(
+                            agent.get_action(
+                                &mut rng,
+                                &game.clone().into(),
+                                &game.get_legal_actions(),
+                            ),
+                            &mut game,
+                        )
+                    }
+                    fitness += game.line_clears;
+                }
+                (*agent, fitness)
+            })
+            .collect::<Vec<_>>();
+
+        let mut breeded_population: Vec<(GeneticHeuristicAgent, u32)> = Vec::with_capacity(300);
+        for _ in 0..3 {
+            let mut random_selection = new_population
+                .choose_multiple(&mut rng, 100)
+                .collect::<Vec<_>>();
+            random_selection.sort_unstable_by(|e1, e2| e1.1.cmp(&e2.1));
+            let best_two = random_selection.iter().rev().take(2).collect::<Vec<_>>();
+            let parent1 = best_two[0];
+            let parent2 = best_two[1];
+            for _ in 0..100 {
+                let breeded = parent1.0.breed(parent1.1, &parent2.0, parent2.1);
+                let mut cloned = breeded.clone();
+                if rng.gen::<f64>() < 0.05 {
+                    cloned.mutate(&mut rng);
+                }
+                breeded_population.push((cloned, 0));
+            }
+        }
+
+        new_population.sort_unstable_by_key(|e| e.1);
+
+        new_population.splice(..breeded_population.len(), breeded_population);
+
+        population = new_population;
+    }
+    population.sort_unstable_by_key(|e| e.1);
+    Box::new(population.iter().rev().next().unwrap().0)
+}

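Note on the new breed above: every field of the child's Parameters is blended from the parents' total_height, so bumpiness, holes, and complete_lines never inherit their own values. This looks like a copy-paste slip rather than intent. A minimal sketch of the fitness-weighted crossover with each field blended from its own counterpart; the Parameters struct here is a stand-in with just the four f64 weights shown in the diff:

// Sketch only: stand-in Parameters; blend each field from its own counterpart.
#[derive(Copy, Clone, Debug)]
struct Parameters {
    total_height: f64,
    bumpiness: f64,
    holes: f64,
    complete_lines: f64,
}

fn breed(a: &Parameters, a_fitness: u32, b: &Parameters, b_fitness: u32) -> Parameters {
    // Fitter parent contributes proportionally more to the child.
    // (Both fitnesses being zero would still divide by zero, as in the original.)
    let total = (a_fitness + b_fitness) as f64;
    let wa = a_fitness as f64 / total;
    let wb = b_fitness as f64 / total;
    let blend = |x: f64, y: f64| x * wa + y * wb;
    Parameters {
        total_height: blend(a.total_height, b.total_height),
        bumpiness: blend(a.bumpiness, b.bumpiness),
        holes: blend(a.holes, b.holes),
        complete_lines: blend(a.complete_lines, b.complete_lines),
    }
}

fn main() {
    let p1 = Parameters { total_height: -0.5, bumpiness: -0.2, holes: -0.7, complete_lines: 0.8 };
    let p2 = Parameters { total_height: -0.3, bumpiness: -0.4, holes: -0.9, complete_lines: 0.6 };
    println!("{:?}", breed(&p1, 30, &p2, 10));
}

Relatedly, the (heuristic * 1_000_00.0) as usize trick in get_action only orders non-negative heuristics (negative values all saturate to 0 when cast); max_by with a float comparison would avoid the scaling entirely.
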
@@ -40,20 +40,7 @@ impl From<PlayField> for State {
 }

 pub trait Actor {
-    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
-
-    fn update(
-        &mut self,
-        state: State,
-        action: Action,
-        next_state: State,
-        next_legal_actions: &[Action],
-        reward: f64,
-    );
-
-    fn set_learning_rate(&mut self, learning_rate: f64);
-    fn set_exploration_prob(&mut self, exploration_prob: f64);
-    fn set_discount_rate(&mut self, discount_rate: f64);
+    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action;

     fn dbg(&self);
 }

@@ -80,3 +67,16 @@ impl Predictable for Game {
         game
     }
 }
+
+pub fn apply_action_to_game(action: Action, game: &mut Game) {
+    match action {
+        Action::Nothing => (),
+        Action::MoveLeft => game.move_left(),
+        Action::MoveRight => game.move_right(),
+        Action::SoftDrop => game.move_down(),
+        Action::HardDrop => game.hard_drop(),
+        Action::Hold => game.hold(),
+        Action::RotateLeft => game.rotate_left(),
+        Action::RotateRight => game.rotate_right(),
+    }
+}

@@ -12,6 +12,21 @@ use rand::Rng;
 use rand::SeedableRng;
 use std::collections::HashMap;

+pub trait QLearningActor: Actor {
+    fn update(
+        &mut self,
+        state: State,
+        action: Action,
+        next_state: State,
+        next_legal_actions: &[Action],
+        reward: f64,
+    );
+
+    fn set_learning_rate(&mut self, learning_rate: f64);
+    fn set_exploration_prob(&mut self, exploration_prob: f64);
+    fn set_discount_rate(&mut self, discount_rate: f64);
+}
+
 pub struct QLearningAgent {
     pub learning_rate: f64,
     pub exploration_prob: f64,

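Note: the update signature carried over to the new QLearningActor trait has exactly the pieces of the standard tabular Q-learning rule (current state/action, reward, next state, and the next legal actions for the max term); the agents' existing update bodies, unchanged in this commit, are what actually implement it. A minimal sketch of that rule over a flat HashMap, assuming the same learning_rate/discount_rate fields the agents expose; state/action types are stand-ins, not the repo's State and Action:

use std::collections::HashMap;

// Sketch only: stand-in state/action types.
type State = u64;
type Action = u8;

struct TabularQ {
    q_values: HashMap<(State, Action), f64>,
    learning_rate: f64,
    discount_rate: f64,
}

impl TabularQ {
    // Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max over a' of Q(s', a'))
    fn update(&mut self, s: State, a: Action, s_next: State, next_legal: &[Action], reward: f64) {
        let old = *self.q_values.get(&(s, a)).unwrap_or(&0.0);
        let best_next = if next_legal.is_empty() {
            0.0
        } else {
            next_legal
                .iter()
                .map(|a2| *self.q_values.get(&(s_next, *a2)).unwrap_or(&0.0))
                .fold(f64::MIN, f64::max)
        };
        let target = reward + self.discount_rate * best_next;
        let new = (1.0 - self.learning_rate) * old + self.learning_rate * target;
        self.q_values.insert((s, a), new);
    }
}

fn main() {
    let mut q = TabularQ { q_values: HashMap::new(), learning_rate: 0.1, discount_rate: 0.9 };
    q.update(0, 1, 1, &[0, 1], 1.0);
    println!("{:?}", q.q_values);
}
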
@@ -64,14 +79,20 @@ impl QLearningAgent {
 impl Actor for QLearningAgent {
     // Because doing (Nothing) is in the set of legal actions, this will never
     // be empty
-    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
+    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
         if rng.gen::<f64>() < self.exploration_prob {
             *legal_actions.choose(rng).unwrap()
         } else {
-            self.get_action_from_q_values(state, legal_actions)
+            self.get_action_from_q_values(&game.into(), legal_actions)
         }
     }

+    fn dbg(&self) {
+        debug!("Total states: {}", self.q_values.len());
+    }
+}
+
+impl QLearningActor for QLearningAgent {
     fn update(
         &mut self,
         state: State,

@@ -94,10 +115,6 @@ impl Actor for QLearningAgent {
             .insert(action, new_q_val);
     }

-    fn dbg(&self) {
-        debug!("Total states: {}", self.q_values.len());
-    }
-
     fn set_learning_rate(&mut self, learning_rate: f64) {
         self.learning_rate = learning_rate;
     }

@@ -208,14 +225,20 @@ impl ApproximateQLearning {
 }

 impl Actor for ApproximateQLearning {
-    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
+    fn get_action(&self, rng: &mut SmallRng, game: &Game, legal_actions: &[Action]) -> Action {
         if rng.gen::<f64>() < self.exploration_prob {
             *legal_actions.choose(rng).unwrap()
         } else {
-            self.get_action_from_q_values(state, legal_actions)
+            self.get_action_from_q_values(&game.into(), legal_actions)
         }
     }
+
+    fn dbg(&self) {
+        dbg!(&self.weights);
+    }
+}
+
+impl QLearningActor for ApproximateQLearning {
     fn update(
         &mut self,
         state: State,

@@ -248,15 +271,15 @@ impl Actor for ApproximateQLearning {
     fn set_discount_rate(&mut self, discount_rate: f64) {
         self.discount_rate = discount_rate;
     }
-
-    fn dbg(&self) {
-        dbg!(&self.weights);
-    }
 }

-pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) -> Box<dyn Actor> {
+pub fn train_actor<T: 'static + QLearningActor + Actor>(
+    mut actor: T,
+    opts: &Train,
+) -> Box<dyn Actor> {
     let mut rng = SmallRng::from_entropy();
     let mut avg = 0.0;
+    let episodes = opts.episodes;

     actor.set_learning_rate(opts.learning_rate);
     actor.set_discount_rate(opts.discount_rate);

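Note: making the trainer generic over T presumably exists so the Q-learning-specific setters can be called on the concrete agent without the old Box<dyn Actor> indirection; the 'static bound is what lets the trained agent still be returned type-erased at the end, since Box<dyn Actor> means Box<dyn Actor + 'static>. A minimal sketch of that pattern with a dummy trait, just to show the moving parts:

// Sketch only: dummy trait and agent, not the repo's types.
trait Actor {
    fn name(&self) -> &'static str;
}

#[derive(Default)]
struct DummyAgent;

impl Actor for DummyAgent {
    fn name(&self) -> &'static str {
        "dummy"
    }
}

// Generic over the concrete agent while training, type-erased on the way out.
fn train_actor<T: 'static + Actor>(actor: T) -> Box<dyn Actor> {
    // ... training would mutate `actor` here ...
    Box::new(actor)
}

fn main() {
    let trained: Box<dyn Actor> = train_actor(DummyAgent::default());
    println!("{}", trained.name());
}
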
@@ -274,20 +297,11 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) ->
         }
         let mut game = Game::default();
         while (&game).is_game_over().is_none() {
-            let cur_state = (&game).into();
+            let cur_state = game.clone();
             let cur_score = game.score();
             let action = actor.get_action(&mut rng, &cur_state, &((&game).get_legal_actions()));

-            match action {
-                Action::Nothing => (),
-                Action::MoveLeft => game.move_left(),
-                Action::MoveRight => game.move_right(),
-                Action::SoftDrop => game.move_down(),
-                Action::HardDrop => game.hard_drop(),
-                Action::Hold => game.hold(),
-                Action::RotateLeft => game.rotate_left(),
-                Action::RotateRight => game.rotate_right(),
-            }
+            super::apply_action_to_game(action, &mut game);

             let new_state = (&game).into();
             let mut reward = game.score() as f64 - cur_score as f64;

@@ -301,7 +315,13 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) ->

             let new_legal_actions = game.get_legal_actions();

-            actor.update(cur_state, action, new_state, &new_legal_actions, reward);
+            actor.update(
+                cur_state.into(),
+                action,
+                new_state,
+                &new_legal_actions,
+                reward,
+            );

             game.tick();
         }

@@ -316,6 +336,5 @@ pub fn train_actor(episodes: usize, mut actor: Box<dyn Actor>, opts: &Train) ->
     if opts.no_learn_during_evaluation {
         actor.set_learning_rate(0.0);
     }
-
-    actor
+    Box::new(actor)
 }

@@ -74,6 +74,7 @@ arg_enum! {
     pub enum Agent {
         QLearning,
         ApproximateQLearning,
+        HeuristicGenetic
     }
 }

src/game.rs (23 changes)

@@ -16,6 +16,7 @@ const LINE_CLEAR_DELAY: u64 = TICKS_PER_SECOND as u64 * 41 / 60;
 pub enum LossReason {
     TopOut,
     LockOut,
+    PieceLimitReached,
     BlockOut(Position),
 }

@@ -35,6 +36,9 @@ pub struct Game {
     /// bonus is needed.
     last_clear_action: ClearAction,
     pub line_clears: u32,
+    // used if we set a limit on how long a game can last.
+    pieces_placed: usize,
+    piece_limit: usize,
 }

 impl fmt::Debug for Game {

@@ -59,6 +63,8 @@ impl Default for Game {
             is_game_over: None,
             last_clear_action: ClearAction::Single, // Doesn't matter what it's initialized to
             line_clears: 0,
+            pieces_placed: 0,
+            piece_limit: 0,
         }
     }
 }

@@ -168,10 +174,18 @@ impl Game {
         // It's possible that the player moved the piece in the meantime.
         if !self.playfield.can_active_piece_move_down() {
             let positions = self.playfield.lock_active_piece();
+            if self.pieces_placed < self.piece_limit {
+                self.pieces_placed += 1;
+                if self.pieces_placed >= self.piece_limit {
+                    trace!("Loss due to piece limit!");
+                    self.is_game_over = Some(LossReason::PieceLimitReached);
+                }
+            }
+
             self.is_game_over = self.is_game_over.or_else(|| {
                 if positions.iter().map(|p| p.y).all(|y| y < 20) {
-                    trace!("Loss due to topout! {:?}", positions);
-                    Some(LossReason::TopOut)
+                    trace!("Loss due to lockout! {:?}", positions);
+                    Some(LossReason::LockOut)
                 } else {
                     None
                 }

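Note: pieces_placed only increments while pieces_placed < piece_limit, so with the Default value of piece_limit: 0 the counter never moves and PieceLimitReached can never fire; the limit is effectively opt-in via set_piece_limit, as the genetic trainer earlier in this diff does with game.set_piece_limit(500). A small self-contained sketch of just that guard, to show the off-by-default behaviour (the struct and method names here are stand-ins):

// Sketch only: the piece-limit bookkeeping from the hunk above, isolated.
#[derive(Default)]
struct PieceLimiter {
    pieces_placed: usize,
    piece_limit: usize, // 0 (the Default) disables the limit entirely
}

impl PieceLimiter {
    fn set_piece_limit(&mut self, size: usize) {
        self.piece_limit = size;
    }

    // Returns true once the configured limit has been reached.
    fn on_piece_locked(&mut self) -> bool {
        if self.pieces_placed < self.piece_limit {
            self.pieces_placed += 1;
            if self.pieces_placed >= self.piece_limit {
                return true;
            }
        }
        false
    }
}

fn main() {
    let mut unlimited = PieceLimiter::default();
    assert!(!(0..1000).any(|_| unlimited.on_piece_locked())); // never trips with limit 0

    let mut limited = PieceLimiter::default();
    limited.set_piece_limit(500);
    assert_eq!((0..1000).position(|_| limited.on_piece_locked()), Some(499));
}
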
@@ -183,6 +197,7 @@ impl Game {
             self.line_clears += cleared_lines as u32;
             self.score += (cleared_lines * 100 * self.level as usize) as u32;
+            self.level = (self.line_clears / 10) as u8;

             self.playfield.active_piece = None;
             self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
         } else {

@@ -195,6 +210,10 @@ impl Game {
         }
     }

+    pub fn set_piece_limit(&mut self, size: usize) {
+        self.piece_limit = size;
+    }
+
     pub fn playfield(&self) -> &PlayField {
         &self.playfield
     }

src/main.rs (21 changes)

@@ -30,18 +30,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     init_verbosity(&opts)?;

     let agent = match opts.subcmd {
-        SubCommand::Play(sub_opts) => None,
+        SubCommand::Play(_) => None,
         SubCommand::Train(sub_opts) => Some(match sub_opts.agent {
-            Agent::QLearning => qlearning::train_actor(
-                sub_opts.episodes,
-                Box::new(qlearning::QLearningAgent::default()),
-                &sub_opts,
-            ),
-            Agent::ApproximateQLearning => qlearning::train_actor(
-                sub_opts.episodes,
-                Box::new(qlearning::ApproximateQLearning::default()),
-                &sub_opts,
-            ),
+            Agent::QLearning => {
+                qlearning::train_actor(qlearning::QLearningAgent::default(), &sub_opts)
+            }
+            Agent::ApproximateQLearning => {
+                qlearning::train_actor(qlearning::ApproximateQLearning::default(), &sub_opts)
+            }
+            Agent::HeuristicGenetic => genetic::train_actor(&sub_opts),
         }),
     };

@@ -70,7 +67,7 @@ async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std:
         None => (),
     }

-    let cur_state = (&game).into();
+    let cur_state = game.clone();

     // If there's an actor, the player action will get overridden. If not,
     // then then the player action falls through, if there is one. This is