Compare commits
No commits in common. "f3b48fbc85a42a04807c780ecc8ca797fdf22990" and "ea2c926c508422c554734e3acbd62c551be29446" have entirely different histories.
f3b48fbc85 ... ea2c926c50
5 changed files with 19 additions and 193 deletions
@@ -10,7 +10,6 @@ pub struct State {
     matrix: Matrix,
     active_piece: Option<Tetromino>,
     held_piece: Option<TetrominoType>,
-    line_clears: u32,
 }
 
 impl From<Game> for State {
@@ -21,9 +20,7 @@ impl From<Game> for State {
 
 impl From<&Game> for State {
     fn from(game: &Game) -> Self {
-        let mut state: State = game.playfield().clone().into();
-        state.line_clears = game.line_clears;
-        state
+        game.playfield().clone().into()
     }
 }
 
@@ -33,7 +30,6 @@ impl From<PlayField> for State {
             matrix: playfield.field().clone(),
            active_piece: playfield.active_piece,
            held_piece: playfield.hold_piece().map(|t| t.clone()),
-            line_clears: 0,
         }
     }
 }
@@ -41,14 +37,7 @@ impl From<PlayField> for State {
 pub trait Actor {
     fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
 
-    fn update(
-        &mut self,
-        state: State,
-        action: Action,
-        next_state: State,
-        next_legal_actions: &[Action],
-        reward: f64,
-    );
+    fn update(&mut self, state: State, action: Action, next_state: State, reward: f64);
 
     fn set_learning_rate(&mut self, learning_rate: f64);
     fn set_exploration_prob(&mut self, exploration_prob: f64);
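The hunk above collapses the trait's `update` hook from the multi-line signature that carried `next_legal_actions: &[Action]` down to a single line without it. A minimal standalone sketch of the slimmer trait shape with a do-nothing implementor (the `RandomActor` type and the stand-in `State`/`Action` definitions are illustrative only, not part of the repository; setters that fall outside the hunk are elided):

    use rand::{rngs::SmallRng, seq::SliceRandom};

    // Stand-ins for the crate's own types, just enough to make the sketch compile.
    struct State;
    #[derive(Clone, Copy, PartialEq)]
    enum Action {
        Nothing,
    }

    // Mirror of the trait on the new side of the hunk.
    trait Actor {
        fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action;
        fn update(&mut self, state: State, action: Action, next_state: State, reward: f64);
        fn set_learning_rate(&mut self, learning_rate: f64);
        fn set_exploration_prob(&mut self, exploration_prob: f64);
    }

    // A do-nothing actor: picks a legal action uniformly at random and never learns.
    struct RandomActor;

    impl Actor for RandomActor {
        fn get_action(&self, rng: &mut SmallRng, _state: &State, legal_actions: &[Action]) -> Action {
            *legal_actions.choose(rng).expect("no legal actions")
        }
        fn update(&mut self, _state: State, _action: Action, _next_state: State, _reward: f64) {}
        fn set_learning_rate(&mut self, _learning_rate: f64) {}
        fn set_exploration_prob(&mut self, _exploration_prob: f64) {}
    }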
@@ -1,8 +1,5 @@
 use crate::actors::{Actor, State};
-use crate::{
-    game::Action,
-    playfield::{PLAYFIELD_HEIGHT, PLAYFIELD_WIDTH},
-};
+use crate::game::Action;
 use log::debug;
 use rand::rngs::SmallRng;
 use rand::seq::SliceRandom;
@@ -18,10 +15,10 @@ pub struct QLearningAgent {
 
 impl Default for QLearningAgent {
     fn default() -> Self {
-        Self {
-            learning_rate: 0.0,
-            exploration_prob: 0.0,
-            discount_rate: 0.0,
+        QLearningAgent {
+            learning_rate: 0.1,
+            exploration_prob: 0.5,
+            discount_rate: 1.0,
             q_values: HashMap::default(),
         }
     }
@@ -69,14 +66,7 @@ impl Actor for QLearningAgent {
         }
     }
 
-    fn update(
-        &mut self,
-        state: State,
-        action: Action,
-        next_state: State,
-        _next_legal_actions: &[Action],
-        reward: f64,
-    ) {
+    fn update(&mut self, state: State, action: Action, next_state: State, reward: f64) {
         let cur_q_val = self.get_q_value(&state, action);
         let new_q_val = cur_q_val
             + self.learning_rate
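The tail of that hunk is cut off at the hunk boundary, but the shape `cur_q_val + self.learning_rate * ...` is consistent with the standard tabular Q-learning backup Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)). A self-contained sketch of that arithmetic on plain f64 values (the helper name and inputs are hypothetical):

    /// One tabular Q-learning backup; `max_next_q` stands in for the max over a' of Q(s', a').
    fn q_backup(cur_q: f64, reward: f64, max_next_q: f64, learning_rate: f64, discount_rate: f64) -> f64 {
        cur_q + learning_rate * (reward + discount_rate * max_next_q - cur_q)
    }

    fn main() {
        // With the defaults introduced above (learning_rate = 0.1, discount_rate = 1.0):
        let updated = q_backup(0.0, -10.0, 0.0, 0.1, 1.0);
        assert!((updated - (-1.0)).abs() < 1e-9);
        println!("Q(s, a) after one backup: {updated}");
    }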
@@ -105,148 +95,3 @@ impl Actor for QLearningAgent {
         self.discount_rate = discount_rate;
     }
 }
-
-pub struct ApproximateQLearning {
-    pub learning_rate: f64,
-    pub exploration_prob: f64,
-    pub discount_rate: f64,
-    weights: HashMap<String, f64>,
-}
-
-impl Default for ApproximateQLearning {
-    fn default() -> Self {
-        Self {
-            learning_rate: 0.0,
-            exploration_prob: 0.0,
-            discount_rate: 0.0,
-            weights: HashMap::default(),
-        }
-    }
-}
-
-impl ApproximateQLearning {
-    fn get_features(
-        &self,
-        state: &State,
-        _action: &Action,
-        new_state: &State,
-    ) -> HashMap<String, f64> {
-        let mut features = HashMap::default();
-
-        let mut heights = [None; PLAYFIELD_WIDTH];
-        for r in 0..PLAYFIELD_HEIGHT {
-            for c in 0..PLAYFIELD_WIDTH {
-                if heights[c].is_none() && state.matrix[r][c].is_some() {
-                    heights[c] = Some(PLAYFIELD_HEIGHT - r);
-                }
-            }
-        }
-
-        features.insert(
-            "Total Height".into(),
-            heights
-                .iter()
-                .map(|o| o.unwrap_or_else(|| 0))
-                .sum::<usize>() as f64
-                / (PLAYFIELD_HEIGHT * PLAYFIELD_WIDTH) as f64,
-        );
-
-        features.insert(
-            "Bumpiness".into(),
-            heights
-                .iter()
-                .map(|o| o.unwrap_or_else(|| 0) as isize)
-                .fold((0, 0), |(acc, prev), cur| (acc + (prev - cur).abs(), cur))
-                .0 as f64
-                / (PLAYFIELD_WIDTH * 40) as f64,
-        );
-
-        features.insert(
-            "Lines cleared".into(),
-            (new_state.line_clears - state.line_clears) as f64 / 4.0,
-        );
-
-        let mut holes = 0;
-        for r in 1..PLAYFIELD_HEIGHT {
-            for c in 0..PLAYFIELD_WIDTH {
-                if state.matrix[r][c].is_none() && state.matrix[r - 1][c].is_some() {
-                    holes += 1;
-                }
-            }
-        }
-        features.insert("Holes".into(), holes as f64);
-
-        features
-    }
-
-    fn get_q_value(&self, state: &State, action: &Action, next_state: &State) -> f64 {
-        self.get_features(state, action, next_state)
-            .iter()
-            .map(|(key, val)| val * *self.weights.get(key).unwrap_or_else(|| &0.0))
-            .sum()
-    }
-
-    fn get_action_from_q_values(&self, state: &State, legal_actions: &[Action]) -> Action {
-        *legal_actions
-            .iter()
-            .map(|action| (action, self.get_q_value(&state, action, state)))
-            .max_by_key(|(_, q1)| ((q1 * 1_000_000.0) as isize))
-            .expect("Failed to select an action")
-            .0
-    }
-
-    fn get_value(&self, state: &State, legal_actions: &[Action]) -> f64 {
-        legal_actions
-            .iter()
-            .map(|action| self.get_q_value(state, action, state))
-            .max_by_key(|v| (v * 1_000_000.0) as isize)
-            .unwrap_or_else(|| 0.0)
-    }
-}
-
-impl Actor for ApproximateQLearning {
-    fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action {
-        if rng.gen::<f64>() < self.exploration_prob {
-            *legal_actions.choose(rng).unwrap()
-        } else {
-            self.get_action_from_q_values(state, legal_actions)
-        }
-    }
-
-    fn update(
-        &mut self,
-        state: State,
-        action: Action,
-        next_state: State,
-        next_legal_actions: &[Action],
-        reward: f64,
-    ) {
-        let difference = reward
-            + self.discount_rate * self.get_value(&next_state, next_legal_actions)
-            - self.get_q_value(&state, &action, &next_state);
-
-        for (feat_key, feat_val) in self.get_features(&state, &action, &next_state) {
-            self.weights.insert(
-                feat_key.clone(),
-                *self.weights.get(&feat_key).unwrap_or_else(|| &0.0)
-                    + self.learning_rate * difference * feat_val,
-            );
-        }
-    }
-
-    fn set_learning_rate(&mut self, learning_rate: f64) {
-        self.learning_rate = learning_rate;
-    }
-
-    fn set_exploration_prob(&mut self, exploration_prob: f64) {
-        self.exploration_prob = exploration_prob;
-    }
-
-    fn set_discount_rate(&mut self, discount_rate: f64) {
-        self.discount_rate = discount_rate;
-    }
-
-    fn dbg(&self) {
-        dbg!(&self.weights);
-    }
-}
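Everything after the first three context lines of that hunk is a deletion: the entire linear ApproximateQLearning agent goes away. Its core idea was Q(s, a) = sum_i w_i * f_i(s, a), with each weight nudged by w_i <- w_i + learning_rate * difference * f_i, where difference = reward + discount * V(s') - Q(s, a). A condensed standalone sketch of that mechanism (feature names and values are hypothetical, and this is not a drop-in replacement for the removed code):

    use std::collections::HashMap;

    /// Q(s, a) under a linear model: the dot product of features and weights.
    fn q_value(weights: &HashMap<String, f64>, features: &HashMap<String, f64>) -> f64 {
        features
            .iter()
            .map(|(key, val)| val * weights.get(key).copied().unwrap_or(0.0))
            .sum()
    }

    /// One weight update: w_i <- w_i + learning_rate * difference * f_i.
    fn update_weights(
        weights: &mut HashMap<String, f64>,
        features: &HashMap<String, f64>,
        difference: f64,
        learning_rate: f64,
    ) {
        for (key, val) in features {
            let w = weights.entry(key.clone()).or_insert(0.0);
            *w += learning_rate * difference * val;
        }
    }

    fn main() {
        let mut weights = HashMap::new();
        let features = HashMap::from([("Holes".to_string(), 2.0), ("Bumpiness".to_string(), 0.25)]);
        update_weights(&mut weights, &features, -1.0, 0.1);
        println!("Q after one update: {}", q_value(&weights, &features));
    }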
12 src/cli.rs
@@ -65,15 +65,12 @@ pub struct Train {
     /// Number of episodes to train the agent
     #[clap(short = "n", long = "num", default_value = "10")]
     pub episodes: usize,
-    // #[clap(long = "use-epsilon-decreasing")]
-    // pub epsilon_decreasing: bool,
 }
 
 arg_enum! {
     #[derive(Debug)]
     pub enum Agent {
-        QLearning,
-        ApproximateQLearning,
+        QLearning
     }
 }
 
@@ -90,9 +87,6 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
     Ok(())
 }
 
-pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
-    match agent {
-        Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
-        Agent::ApproximateQLearning => Box::new(qlearning::ApproximateQLearning::default()),
-    }
+pub fn get_actor() -> impl Actor {
+    qlearning::QLearningAgent::default()
 }
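A side effect of the Box<dyn Actor> to impl Actor change above: an impl Trait return type must resolve to exactly one concrete type, so the match over the Agent enum that used to select an agent at runtime cannot survive the change. A minimal illustration, using stand-in types that mirror the names in the diff rather than the crate's real definitions:

    trait Actor {}
    struct QLearningAgent;
    struct ApproximateQLearning;
    impl Actor for QLearningAgent {}
    impl Actor for ApproximateQLearning {}

    enum Agent {
        QLearning,
        ApproximateQLearning,
    }

    // The old shape: each arm is boxed into the same trait-object type, so runtime selection works.
    fn get_actor_dyn(agent: Agent) -> Box<dyn Actor> {
        match agent {
            Agent::QLearning => Box::new(QLearningAgent),
            Agent::ApproximateQLearning => Box::new(ApproximateQLearning),
        }
    }

    // The new shape: only one concrete type can hide behind `impl Actor`.
    fn get_actor_impl() -> impl Actor {
        QLearningAgent
    }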
@@ -32,7 +32,7 @@ pub struct Game {
     /// The last clear action performed, used for determining if a back-to-back
     /// bonus is needed.
     last_clear_action: ClearAction,
-    pub line_clears: u32,
+    line_clears: u32,
 }
 
 impl fmt::Debug for Game {
@@ -178,7 +178,7 @@ impl Game {
         if cleared_lines > 0 {
             trace!("Lines were cleared.");
             self.line_clears += cleared_lines as u32;
-            self.score += (cleared_lines * 100 * self.level as usize) as u32;
+            self.score += (cleared_lines * self.level as usize) as u32;
             self.level = (self.line_clears / 10) as u8;
             self.playfield.active_piece = None;
             self.next_spawn_tick = self.tick + LINE_CLEAR_DELAY;
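For scale, the scoring change above is large: clearing 3 lines at level 2 is worth 3 * 100 * 2 = 600 points under the left-hand formula but only 3 * 2 = 6 under the right-hand one, which shrinks the score-difference signal the training loop in src/main.rs derives from it by two orders of magnitude.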
14 src/main.rs
@@ -33,7 +33,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     match opts.subcmd {
         SubCommand::Play(sub_opts) => {}
         SubCommand::Train(sub_opts) => {
-            let mut to_train = get_actor(sub_opts.agent);
+            let mut to_train = get_actor();
             to_train.set_learning_rate(sub_opts.learning_rate);
             to_train.set_discount_rate(sub_opts.discount_rate);
             to_train.set_exploration_prob(sub_opts.exploration_prob);
@@ -61,7 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     Ok(())
 }
 
-fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
+fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor {
     let mut rng = rand::rngs::SmallRng::from_entropy();
     let mut avg = 0.0;
 
@@ -91,16 +91,14 @@ fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
             let new_state = (&game).into();
             let mut reward = game.score() as f64 - cur_score as f64;
             if action != Action::Nothing {
-                reward -= 0.0;
+                reward -= 10.0;
             }
 
             if game.is_game_over().is_some() {
-                reward = -1.0;
+                reward = -100.0;
             }
 
-            let new_legal_actions = game.get_legal_actions();
-
-            actor.update(cur_state, action, new_state, &new_legal_actions, reward);
+            actor.update(cur_state, action, new_state, reward);
 
             game.tick();
         }
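The right-hand side of the hunk above shapes the per-tick reward in three steps: start from the score delta, subtract a flat 10 for any action other than Action::Nothing, and replace the whole value with -100 when the game is over. Pulled out into a standalone helper for clarity (the function and its parameters are illustrative, not code from the repository):

    #[derive(PartialEq)]
    enum Action {
        Nothing,
        Other,
    }

    fn shaped_reward(score_before: u32, score_after: u32, action: Action, game_over: bool) -> f64 {
        // Base signal: how much the score moved this tick.
        let mut reward = score_after as f64 - score_before as f64;
        // Flat penalty for doing anything other than Nothing.
        if action != Action::Nothing {
            reward -= 10.0;
        }
        // Terminal penalty replaces everything else.
        if game_over {
            reward = -100.0;
        }
        reward
    }

    fn main() {
        assert_eq!(shaped_reward(0, 6, Action::Other, false), -4.0);
        assert_eq!(shaped_reward(0, 0, Action::Nothing, true), -100.0);
        println!("ok");
    }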
@@ -111,7 +109,7 @@ fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
     actor
 }
 
-async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
+async fn play_game(mut actor: Option<impl Actor>) -> Result<(), Box<dyn std::error::Error>> {
     let mut rng = rand::rngs::SmallRng::from_entropy();
     let sdl_context = sdl2::init()?;
     let video_subsystem = sdl_context.video()?;