use dynamic dispatch for actor selection
This commit is contained in:
parent
ea2c926c50
commit
ce65afa277
2 changed files with 10 additions and 6 deletions
10
src/cli.rs
10
src/cli.rs
|
@ -70,7 +70,8 @@ pub struct Train {
|
||||||
arg_enum! {
|
arg_enum! {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Agent {
|
pub enum Agent {
|
||||||
QLearning
|
QLearning,
|
||||||
|
ApproximateQLearning,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,6 +88,9 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_actor() -> impl Actor {
|
pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
|
||||||
qlearning::QLearningAgent::default()
|
match agent {
|
||||||
|
Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
|
||||||
|
Agent::ApproximateQLearning => todo!(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
match opts.subcmd {
|
match opts.subcmd {
|
||||||
SubCommand::Play(sub_opts) => {}
|
SubCommand::Play(sub_opts) => {}
|
||||||
SubCommand::Train(sub_opts) => {
|
SubCommand::Train(sub_opts) => {
|
||||||
let mut to_train = get_actor();
|
let mut to_train = get_actor(sub_opts.agent);
|
||||||
to_train.set_learning_rate(sub_opts.learning_rate);
|
to_train.set_learning_rate(sub_opts.learning_rate);
|
||||||
to_train.set_discount_rate(sub_opts.discount_rate);
|
to_train.set_discount_rate(sub_opts.discount_rate);
|
||||||
to_train.set_exploration_prob(sub_opts.exploration_prob);
|
to_train.set_exploration_prob(sub_opts.exploration_prob);
|
||||||
|
@ -61,7 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor {
|
fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
|
||||||
let mut rng = rand::rngs::SmallRng::from_entropy();
|
let mut rng = rand::rngs::SmallRng::from_entropy();
|
||||||
let mut avg = 0.0;
|
let mut avg = 0.0;
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor {
|
||||||
actor
|
actor
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn play_game(mut actor: Option<impl Actor>) -> Result<(), Box<dyn std::error::Error>> {
|
async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let mut rng = rand::rngs::SmallRng::from_entropy();
|
let mut rng = rand::rngs::SmallRng::from_entropy();
|
||||||
let sdl_context = sdl2::init()?;
|
let sdl_context = sdl2::init()?;
|
||||||
let video_subsystem = sdl_context.video()?;
|
let video_subsystem = sdl_context.video()?;
|
||||||
|
|
Loading…
Reference in a new issue