use crate::{ util::{ debug_say, operators::{get_china_ops, get_global_ops, Operator}, }, DbConnPool, }; use log::warn; use serenity::framework::standard::{macros::command, Args, CommandResult}; use serenity::model::{channel::Message, id::UserId}; use serenity::{async_trait, prelude::Context}; use sqlx::Error; use std::{collections::HashSet, str::FromStr}; #[command] #[sub_commands(list, add, remove, missing, roll, pity)] async fn op(_: &Context, _: &Message) -> CommandResult { Ok(()) } #[command] async fn list(ctx: &Context, msg: &Message, args: Args) -> CommandResult { let db_pool = ctx.data.clone(); let db_pool = db_pool.read().await; let db_pool = db_pool .get::() .expect("No db pool in context?!"); let user_id = match args.current().map(|id| UserId::from_str(id.trim())) { Some(Ok(user_id)) => user_id, _ => msg.author.id, } .as_u64() .to_string(); debug_say( msg, ctx, format!( "{}'s 6\u{2605} Operators:\n{}", msg.author, db_pool .get_operators(user_id) .await? .iter() .map(|(op, pot)| format!("{:?} (Pot {})", op, pot)) .collect::>() .join("\n") ), ) .await?; Ok(()) } #[command] async fn add(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let db_pool = ctx.data.clone(); let db_pool = db_pool.read().await; let db_pool = db_pool .get::() .expect("No db pool in context?!"); let mut failed_parses = false; let mut num_added: usize = 0; for arg in args.iter::() { match arg { Ok(op) => { db_pool .add_operator(msg.author.id.as_u64().to_string(), op) .await?; num_added += 1; } Err(_) => { failed_parses = true; } }; } if failed_parses { debug_say( msg, ctx, "Unable to add some operators. Check your spelling and try again.", ) .await?; } match num_added { 0 => debug_say(msg, ctx, "Didn't add any operators...").await?, 1 => debug_say(msg, ctx, "Added an operator!").await?, n => debug_say(msg, ctx, format!("Added {} operators!", n)).await?, }; Ok(()) } #[command] async fn remove(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let db_pool = ctx.data.clone(); let db_pool = db_pool.read().await; let db_pool = db_pool .get::() .expect("No db pool in context?!"); let mut failed_parses = false; let mut num_added: usize = 0; for arg in args.iter::() { match arg { Ok(op) => { db_pool .remove_operator(msg.author.id.as_u64().to_string(), op) .await?; num_added += 1; } Err(_) => { failed_parses = true; } }; } if failed_parses { debug_say( msg, ctx, "Unable to remove some operators. Check your spelling and try again.", ) .await?; } match num_added { 0 => debug_say(msg, ctx, "Didn't remove any operators...").await?, 1 => debug_say(msg, ctx, "Removed an operator!").await?, n => debug_say(msg, ctx, format!("Removed {} operators!", n)).await?, }; Ok(()) } #[command] async fn missing(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let db_pool = ctx.data.clone(); let db_pool = db_pool.read().await; let db_pool = db_pool .get::() .expect("No db pool in context?!"); let user_id = match args.current().map(|id| UserId::from_str(id.trim())) { Some(Ok(user_id)) => user_id, _ => msg.author.id, } .as_u64() .to_string(); let operators = db_pool .get_operators(user_id) .await? .iter() .map(|(op, _)| *op) .collect::>(); let compare_ops = match args.single_quoted::() { Ok(arg) if arg == "china" || arg == "cn" => get_china_ops(), _ => get_global_ops(), }; let operators = operators .symmetric_difference(&compare_ops) .collect::>(); let resp = if operators.len() > 5 { format!( "Missing {} and {} more...", operators .iter() .take(5) .map(|op| format!("{:?}", op)) .collect::>() .join(", "), operators.len() - 5, ) } else { format!( "Missing {}", operators .iter() .map(|op| format!("{:?}", op)) .collect::>() .join(", ") ) }; debug_say(msg, ctx, resp).await?; Ok(()) } #[command] async fn roll(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let db_pool = ctx.data.clone(); let db_pool = db_pool.read().await; let db_pool = db_pool .get::() .expect("No db pool in context?!"); if args.current().is_none() { args = Args::new("1", &[]); } while args.current().is_some() { match args.quoted().current() { Some(amt) if amt.parse::().is_ok() => { db_pool .add_roll_count(msg.author.id.as_u64().to_string(), amt.parse().unwrap()) .await?; } Some(operator) if Operator::from_str(operator).is_ok() => { let new_op = Operator::from_str(operator).unwrap(); db_pool .add_operator(msg.author.id.as_u64().to_string(), new_op) .await?; debug_say(msg, ctx, format!("Congratulations on {:?}!", new_op)).await?; db_pool .reset_roll_count(msg.author.id.as_u64().to_string()) .await?; } _ => (), } args.advance(); } pity(ctx, msg, args).await?; Ok(()) } #[command] #[sub_commands(reset)] async fn pity(ctx: &Context, msg: &Message, args: Args) -> CommandResult { let db_pool = ctx.data.clone(); let db_pool = db_pool.read().await; let db_pool = db_pool .get::() .expect("No db pool in context?!"); let mut other_user = false; let user_id = match args.current().map(|id| UserId::from_str(id.trim())) { Some(Ok(user_id)) => { other_user = true; user_id } _ => msg.author.id, } .as_u64() .to_string(); match db_pool.get_roll_count(user_id).await { Ok(count) => { debug_say( msg, ctx, format!( "{}'s Current roll: {}.\n6\u{2605} chance: {}%", msg.author, count, 2 + (count as u16).saturating_sub(50) * 2 ), ) .await?; } Err(sqlx::Error::RowNotFound) => { if !other_user { reset(ctx, msg, Args::new("", &[])).await?; } else { debug_say( msg, ctx, "We don't know where you're currently at. Use ~pity reset first!", ) .await?; } } Err(_) => { warn!("Unable to communicate with database"); } } Ok(()) } #[command] async fn reset(ctx: &Context, msg: &Message) -> CommandResult { let db_pool = ctx.data.clone(); let db_pool = db_pool.read().await; let db_pool = db_pool .get::() .expect("No db pool in context?!"); db_pool .reset_roll_count(msg.author.id.as_u64().to_string()) .await?; pity(ctx, msg, Args::new("", &[])).await?; Ok(()) } #[async_trait] trait OpCommandQueries { async fn get_roll_count(&self, id: String) -> Result; async fn reset_roll_count(&self, id: String) -> Result<(), Error>; async fn add_roll_count(&self, id: String, amount: u32) -> Result<(), Error>; async fn get_operators(&self, id: String) -> Result, Error>; async fn add_operator(&self, id: String, op: Operator) -> Result<(), Error>; async fn remove_operator(&self, id: String, op: Operator) -> Result<(), Error>; } #[async_trait] impl OpCommandQueries for DbConnPool { async fn get_roll_count(&self, id: String) -> Result { let resp = sqlx::query!("SELECT count FROM RollCount WHERE user_id = ?", id) .fetch_one(&self.pool) .await?; Ok(resp.count) } async fn reset_roll_count(&self, id: String) -> Result<(), Error> { sqlx::query!( "INSERT INTO RollCount (user_id, count) VALUES (?, 0) ON CONFLICT(user_id) DO UPDATE SET count = 0", id ) .execute(&self.pool) .await?; Ok(()) } async fn add_roll_count(&self, id: String, amount: u32) -> Result<(), Error> { sqlx::query!( "INSERT INTO RollCount (user_id, count) VALUES (?, 0) ON CONFLICT(user_id) DO UPDATE SET count = MIN(99, count + ?)", id, amount as i32 ) .execute(&self.pool) .await?; Ok(()) } async fn get_operators(&self, id: String) -> Result, Error> { Ok(sqlx::query!( "SELECT operator, count FROM OperatorCount WHERE user_id = ?", id ) .fetch_all(&self.pool) .await? .iter() .map(|record| { ( Operator::from_str(record.operator.as_ref()).unwrap(), record.count as u32, ) }) .collect()) } async fn add_operator(&self, id: String, op: Operator) -> Result<(), Error> { sqlx::query!( "INSERT INTO OperatorCount (user_id, operator, count) VALUES (?, ?, 1) ON CONFLICT(user_id, operator) DO UPDATE SET count = count + 1", id, op ) .execute(&self.pool) .await?; Ok(()) } async fn remove_operator(&self, id: String, op: Operator) -> Result<(), Error> { sqlx::query!( "UPDATE OperatorCount SET count = MAX(count - 1, 0) WHERE user_id = ? AND operator = ?", id, op ) .execute(&self.pool) .await?; sqlx::query!("DELETE FROM OperatorCount WHERE count = 0") .execute(&self.pool) .await?; Ok(()) } }