diff --git a/Cargo.lock b/Cargo.lock index 90df455..dde6bd9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -400,6 +400,7 @@ dependencies = [ "rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_pcg 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -427,6 +428,14 @@ dependencies = [ "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "rand_pcg" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "redox_syscall" version = "0.1.56" @@ -718,6 +727,7 @@ dependencies = [ "checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853" "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +"checksum rand_pcg 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429" "checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" "checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3" "checksum regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)" = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae" diff --git a/Cargo.toml b/Cargo.toml index b91cf54..456d62e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -rand = "0.7" +rand = { version = "0.7", features = ["default", "small_rng"] } tokio = { version = "0.2", features = ["full"] } log = "0.4" simple_logger = "1.6" diff --git a/src/actors/mod.rs b/src/actors/mod.rs index 6c14aaf..e4d17da 100644 --- a/src/actors/mod.rs +++ b/src/actors/mod.rs @@ -1,7 +1,7 @@ use crate::game::{Action, Game}; use crate::playfield::{Matrix, PlayField}; use crate::tetromino::{Tetromino, TetrominoType}; -use rand::RngCore; +use rand::rngs::SmallRng; pub mod qlearning; @@ -35,12 +35,7 @@ impl From for State { } pub trait Actor { - fn get_action( - &self, - rng: &mut T, - state: &State, - legal_actions: &[Action], - ) -> Action; + fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action; fn update(&mut self, state: State, action: Action, next_state: State, reward: f64); diff --git a/src/actors/qlearning.rs b/src/actors/qlearning.rs index ead215b..8bc3605 100644 --- a/src/actors/qlearning.rs +++ b/src/actors/qlearning.rs @@ -1,9 +1,9 @@ use crate::actors::{Actor, State}; use crate::game::Action; use log::debug; +use rand::rngs::SmallRng; use rand::seq::SliceRandom; use rand::Rng; -use rand::RngCore; use std::collections::HashMap; pub struct QLearningAgent { @@ -58,12 +58,7 @@ impl QLearningAgent { impl Actor for QLearningAgent { // Because doing (Nothing) is in the set of legal actions, this will never // be empty - fn get_action( - &self, - rng: &mut T, - state: &State, - legal_actions: &[Action], - ) -> Action { + fn get_action(&self, rng: &mut SmallRng, state: &State, legal_actions: &[Action]) -> Action { if rng.gen::() < self.exploration_prob { *legal_actions.choose(rng).unwrap() } else { diff --git a/src/main.rs b/src/main.rs index 4d3fd27..00339c3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -62,7 +62,7 @@ async fn main() -> Result<(), Box> { } fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor { - let mut rng = rand::rngs::StdRng::from_entropy(); + let mut rng = rand::rngs::SmallRng::from_entropy(); let mut avg = 0.0; for i in (0..episodes).progress() { @@ -110,7 +110,7 @@ fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor { } async fn play_game(mut actor: Option) -> Result<(), Box> { - let mut rng = rand::rngs::StdRng::from_entropy(); + let mut rng = rand::rngs::SmallRng::from_entropy(); let sdl_context = sdl2::init()?; let video_subsystem = sdl_context.video()?; let window = video_subsystem