use dynamic dispatch for actor selection
This commit is contained in:
parent
ea2c926c50
commit
ce65afa277
2 changed files with 10 additions and 6 deletions
10
src/cli.rs
10
src/cli.rs
|
@ -70,7 +70,8 @@ pub struct Train {
|
|||
arg_enum! {
|
||||
#[derive(Debug)]
|
||||
pub enum Agent {
|
||||
QLearning
|
||||
QLearning,
|
||||
ApproximateQLearning,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,6 +88,9 @@ pub fn init_verbosity(opts: &Opts) -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_actor() -> impl Actor {
|
||||
qlearning::QLearningAgent::default()
|
||||
pub fn get_actor(agent: Agent) -> Box<dyn Actor> {
|
||||
match agent {
|
||||
Agent::QLearning => Box::new(qlearning::QLearningAgent::default()),
|
||||
Agent::ApproximateQLearning => todo!(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
match opts.subcmd {
|
||||
SubCommand::Play(sub_opts) => {}
|
||||
SubCommand::Train(sub_opts) => {
|
||||
let mut to_train = get_actor();
|
||||
let mut to_train = get_actor(sub_opts.agent);
|
||||
to_train.set_learning_rate(sub_opts.learning_rate);
|
||||
to_train.set_discount_rate(sub_opts.discount_rate);
|
||||
to_train.set_exploration_prob(sub_opts.exploration_prob);
|
||||
|
@ -61,7 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor {
|
||||
fn train_actor(episodes: usize, mut actor: Box<dyn Actor>) -> Box<dyn Actor> {
|
||||
let mut rng = rand::rngs::SmallRng::from_entropy();
|
||||
let mut avg = 0.0;
|
||||
|
||||
|
@ -109,7 +109,7 @@ fn train_actor(episodes: usize, mut actor: impl Actor) -> impl Actor {
|
|||
actor
|
||||
}
|
||||
|
||||
async fn play_game(mut actor: Option<impl Actor>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
async fn play_game(mut actor: Option<Box<dyn Actor>>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut rng = rand::rngs::SmallRng::from_entropy();
|
||||
let sdl_context = sdl2::init()?;
|
||||
let video_subsystem = sdl_context.video()?;
|
||||
|
|
Loading…
Reference in a new issue