use dynamic dispatch for actor selection

This commit is contained in:
Edward Shen 2020-04-05 13:31:05 -04:00
parent ea2c926c50
commit ce65afa277
Signed by: edward
GPG key ID: 19182661E818369F
2 changed files with 10 additions and 6 deletions

View file

@ -70,7 +70,8 @@ pub struct Train {
arg_enum! { arg_enum! {
#[derive(Debug)] #[derive(Debug)]
pub enum Agent { pub enum Agent {
QLearning QLearning,
ApproximateQLearning,
} }
} }
@ -87,6 +88,9 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
pub fn get_actor() -> impl Actor { pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
qlearning::QLearningAgent::default() match agent {
Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
Agent::ApproximateQLearning => todo!(),
}
} }

View file

@ -33,7 +33,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
match opts.subcmd { match opts.subcmd {
SubCommand::Play(sub_opts) => {} SubCommand::Play(sub_opts) => {}
SubCommand::Train(sub_opts) => { SubCommand::Train(sub_opts) => {
let mut to_train = get_actor(); let mut to_train = get_actor(sub_opts.agent);
to_train.set_learning_rate(sub_opts.learning_rate); to_train.set_learning_rate(sub_opts.learning_rate);
to_train.set_discount_rate(sub_opts.discount_rate); to_train.set_discount_rate(sub_opts.discount_rate);
to_train.set_exploration_prob(sub_opts.exploration_prob); to_train.set_exploration_prob(sub_opts.exploration_prob);
@ -61,7 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor { fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
let mut rng = rand::rngs::SmallRng::from_entropy(); let mut rng = rand::rngs::SmallRng::from_entropy();
let mut avg = 0.0; let mut avg = 0.0;
@ -109,7 +109,7 @@ fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor {
actor actor
} }
async fn play_game(mut actor: Option<impl Actor>) -> Result<(), Box<dyn std::error::Error>> { async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
let mut rng = rand::rngs::SmallRng::from_entropy(); let mut rng = rand::rngs::SmallRng::from_entropy();
let sdl_context = sdl2::init()?; let sdl_context = sdl2::init()?;
let video_subsystem = sdl_context.video()?; let video_subsystem = sdl_context.video()?;