diff --git a/src/commands/op/mod.rs b/src/commands/op/mod.rs new file mode 100644 index 0000000..440d1e0 --- /dev/null +++ b/src/commands/op/mod.rs @@ -0,0 +1,227 @@ +use crate::{ + util::{ + debug_say, + operators::{get_china_ops, get_global_ops, Operator}, + }, + DbConnPool, +}; +use pity::{pity, PITY_COMMAND}; +use queries::OpCommandQueries; +use serenity::framework::standard::{macros::command, Args, CommandResult}; +use serenity::model::{channel::Message, id::UserId}; +use serenity::prelude::Context; +use std::{collections::HashSet, str::FromStr}; + +mod pity; +mod queries; + +#[command] +#[sub_commands(list, add, remove, missing, roll, pity)] +async fn op(_: &Context, _: &Message) -> CommandResult { + Ok(()) +} + +#[command] +async fn list(ctx: &Context, msg: &Message, args: Args) -> CommandResult { + let db_pool = ctx.data.clone(); + let db_pool = db_pool.read().await; + let db_pool = db_pool + .get::() + .expect("No db pool in context?!"); + let user_id = match args.current().map(|id| UserId::from_str(id.trim())) { + Some(Ok(user_id)) => user_id, + _ => msg.author.id, + } + .as_u64() + .to_string(); + + debug_say( + msg, + ctx, + format!( + "{}'s 6\u{2605} Operators:\n{}", + msg.author, + db_pool + .get_operators(user_id) + .await? + .iter() + .map(|(op, pot)| format!("{:?} (Pot {})", op, pot)) + .collect::>() + .join("\n") + ), + ) + .await?; + Ok(()) +} + +#[command] +async fn add(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { + let db_pool = ctx.data.clone(); + let db_pool = db_pool.read().await; + let db_pool = db_pool + .get::() + .expect("No db pool in context?!"); + let mut failed_parses = false; + let mut num_added: usize = 0; + for arg in args.iter::() { + match arg { + Ok(op) => { + db_pool + .add_operator(msg.author.id.as_u64().to_string(), op) + .await?; + num_added += 1; + } + Err(_) => { + failed_parses = true; + } + }; + } + if failed_parses { + debug_say( + msg, + ctx, + "Unable to add some operators. Check your spelling and try again.", + ) + .await?; + } + + match num_added { + 0 => debug_say(msg, ctx, "Didn't add any operators...").await?, + 1 => debug_say(msg, ctx, "Added an operator!").await?, + n => debug_say(msg, ctx, format!("Added {} operators!", n)).await?, + }; + Ok(()) +} + +#[command] +async fn remove(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { + let db_pool = ctx.data.clone(); + let db_pool = db_pool.read().await; + let db_pool = db_pool + .get::() + .expect("No db pool in context?!"); + let mut failed_parses = false; + let mut num_added: usize = 0; + for arg in args.iter::() { + match arg { + Ok(op) => { + db_pool + .remove_operator(msg.author.id.as_u64().to_string(), op) + .await?; + num_added += 1; + } + Err(_) => { + failed_parses = true; + } + }; + } + if failed_parses { + debug_say( + msg, + ctx, + "Unable to remove some operators. Check your spelling and try again.", + ) + .await?; + } + + match num_added { + 0 => debug_say(msg, ctx, "Didn't remove any operators...").await?, + 1 => debug_say(msg, ctx, "Removed an operator!").await?, + n => debug_say(msg, ctx, format!("Removed {} operators!", n)).await?, + }; + Ok(()) +} + +#[command] +async fn missing(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { + let db_pool = ctx.data.clone(); + let db_pool = db_pool.read().await; + let db_pool = db_pool + .get::() + .expect("No db pool in context?!"); + let user_id = match args.current().map(|id| UserId::from_str(id.trim())) { + Some(Ok(user_id)) => user_id, + _ => msg.author.id, + } + .as_u64() + .to_string(); + let operators = db_pool + .get_operators(user_id) + .await? + .iter() + .map(|(op, _)| *op) + .collect::>(); + + let compare_ops = match args.single_quoted::() { + Ok(arg) if arg == "china" || arg == "cn" => get_china_ops(), + _ => get_global_ops(), + }; + + let operators = operators + .symmetric_difference(&compare_ops) + .collect::>(); + + let resp = if operators.len() > 5 { + format!( + "Missing {} and {} more...", + operators + .iter() + .take(5) + .map(|op| format!("{:?}", op)) + .collect::>() + .join(", "), + operators.len() - 5, + ) + } else { + format!( + "Missing {}", + operators + .iter() + .map(|op| format!("{:?}", op)) + .collect::>() + .join(", ") + ) + }; + + debug_say(msg, ctx, resp).await?; + + Ok(()) +} + +#[command] +async fn roll(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { + let db_pool = ctx.data.clone(); + let db_pool = db_pool.read().await; + let db_pool = db_pool + .get::() + .expect("No db pool in context?!"); + if args.current().is_none() { + args = Args::new("1", &[]); + } + + while args.current().is_some() { + match args.quoted().current() { + Some(amt) if amt.parse::().is_ok() => { + db_pool + .add_roll_count(msg.author.id.as_u64().to_string(), amt.parse().unwrap()) + .await?; + } + Some(operator) if Operator::from_str(operator).is_ok() => { + let new_op = Operator::from_str(operator).unwrap(); + db_pool + .add_operator(msg.author.id.as_u64().to_string(), new_op) + .await?; + debug_say(msg, ctx, format!("Congratulations on {:?}!", new_op)).await?; + db_pool + .reset_roll_count(msg.author.id.as_u64().to_string()) + .await?; + } + _ => (), + } + args.advance(); + } + + pity(ctx, msg, args).await?; + + Ok(()) +} diff --git a/src/commands/op/pity.rs b/src/commands/op/pity.rs new file mode 100644 index 0000000..15cc045 --- /dev/null +++ b/src/commands/op/pity.rs @@ -0,0 +1,75 @@ +use super::queries::OpCommandQueries; +use crate::{util::debug_say, DbConnPool}; +use log::warn; +use serenity::framework::standard::{macros::command, Args, CommandResult}; +use serenity::model::{channel::Message, id::UserId}; +use serenity::prelude::Context; +use std::str::FromStr; + +#[command] +#[sub_commands(reset)] +pub(super) async fn pity(ctx: &Context, msg: &Message, args: Args) -> CommandResult { + let db_pool = ctx.data.clone(); + let db_pool = db_pool.read().await; + let db_pool = db_pool + .get::() + .expect("No db pool in context?!"); + let mut other_user = false; + + let user_id = match args.current().map(|id| UserId::from_str(id.trim())) { + Some(Ok(user_id)) => { + other_user = true; + user_id + } + _ => msg.author.id, + } + .as_u64() + .to_string(); + match db_pool.get_roll_count(user_id).await { + Ok(count) => { + debug_say( + msg, + ctx, + format!( + "{}'s Current roll: {}.\n6\u{2605} chance: {}%", + msg.author, + count, + 2 + (count as u16).saturating_sub(50) * 2 + ), + ) + .await?; + } + Err(sqlx::Error::RowNotFound) => { + if !other_user { + reset(ctx, msg, Args::new("", &[])).await?; + } else { + debug_say( + msg, + ctx, + "We don't know where you're currently at. Use ~pity reset first!", + ) + .await?; + } + } + Err(_) => { + warn!("Unable to communicate with database"); + } + } + Ok(()) +} + +#[command] +async fn reset(ctx: &Context, msg: &Message) -> CommandResult { + let db_pool = ctx.data.clone(); + let db_pool = db_pool.read().await; + let db_pool = db_pool + .get::() + .expect("No db pool in context?!"); + + db_pool + .reset_roll_count(msg.author.id.as_u64().to_string()) + .await?; + + pity(ctx, msg, Args::new("", &[])).await?; + Ok(()) +} diff --git a/src/commands/op/queries.rs b/src/commands/op/queries.rs new file mode 100644 index 0000000..29cfed5 --- /dev/null +++ b/src/commands/op/queries.rs @@ -0,0 +1,93 @@ +use crate::util::{db::DbConnPool, operators::Operator}; +use serenity::async_trait; +use sqlx::Error; +use std::str::FromStr; + +#[async_trait] +pub(super) trait OpCommandQueries { + async fn get_roll_count(&self, id: String) -> Result; + async fn reset_roll_count(&self, id: String) -> Result<(), Error>; + async fn add_roll_count(&self, id: String, amount: u32) -> Result<(), Error>; + async fn get_operators(&self, id: String) -> Result, Error>; + async fn add_operator(&self, id: String, op: Operator) -> Result<(), Error>; + async fn remove_operator(&self, id: String, op: Operator) -> Result<(), Error>; +} + +#[async_trait] +impl OpCommandQueries for DbConnPool { + async fn get_roll_count(&self, id: String) -> Result { + let resp = sqlx::query!("SELECT count FROM RollCount WHERE user_id = ?", id) + .fetch_one(&self.pool) + .await?; + + Ok(resp.count) + } + + async fn reset_roll_count(&self, id: String) -> Result<(), Error> { + sqlx::query!( + "INSERT INTO RollCount (user_id, count) VALUES (?, 0) + ON CONFLICT(user_id) DO UPDATE SET count = 0", + id + ) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn add_roll_count(&self, id: String, amount: u32) -> Result<(), Error> { + sqlx::query!( + "INSERT INTO RollCount (user_id, count) VALUES (?, 0) + ON CONFLICT(user_id) DO UPDATE SET count = MIN(99, count + ?)", + id, + amount as i32 + ) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn get_operators(&self, id: String) -> Result, Error> { + Ok(sqlx::query!( + "SELECT operator, count FROM OperatorCount WHERE user_id = ?", + id + ) + .fetch_all(&self.pool) + .await? + .iter() + .map(|record| { + ( + Operator::from_str(record.operator.as_ref()).unwrap(), + record.count as u32, + ) + }) + .collect()) + } + + async fn add_operator(&self, id: String, op: Operator) -> Result<(), Error> { + sqlx::query!( + "INSERT INTO OperatorCount (user_id, operator, count) VALUES (?, ?, 1) + ON CONFLICT(user_id, operator) DO UPDATE SET count = count + 1", + id, + op + ) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn remove_operator(&self, id: String, op: Operator) -> Result<(), Error> { + sqlx::query!( + "UPDATE OperatorCount SET count = MAX(count - 1, 0) WHERE user_id = ? AND operator = ?", + id, + op + ) + .execute(&self.pool) + .await?; + + sqlx::query!("DELETE FROM OperatorCount WHERE count = 0") + .execute(&self.pool) + .await?; + + Ok(()) + } +}