use crate::{market::Market, user::SaltedPassword}; use log::{debug, info}; use rust_decimal::Decimal; use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword}; use sqlx::{query, PgPool}; use std::{convert::TryFrom, sync::Arc}; use thiserror::Error; use vtse_common::net::{ServerResponse, UserError}; use vtse_common::stock::{Stock, StockSymbol}; use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username}; #[derive(Error, Debug)] pub(crate) enum StateError { #[error("Was not in the correct state")] WrongState, #[error("Got SQLx error: {0}")] Database(#[from] sqlx::Error), #[error("Failed to hash password")] PasswordHash, } #[derive(Error, Debug)] pub(crate) enum UserOrDbError { #[error("Got SQLx error: {0}")] Database(#[from] sqlx::Error), #[error("Got user error: {0}")] User(UserError), } impl UserOrDbError { fn into_operation_result(self) -> OperationResult { match self { UserOrDbError::Database(e) => Err(StateError::Database(e)), UserOrDbError::User(e) => Ok(ServerResponse::UserError(e)), } } } type OperationResult = Result; #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)] pub(crate) enum AppState { Unauthorized, Authorized { user_id: i32 }, } /// Helper functions impl AppState { pub(crate) fn new() -> Self { Self::Unauthorized } fn assert_state(&self, expected_state: AppState) -> Result<(), StateError> { if *self != expected_state { Err(StateError::WrongState) } else { Ok(()) } } } /// Query operations impl impl AppState { pub(crate) async fn stock_info(stock_symbol: StockSymbol, pool: &PgPool) -> OperationResult { let stock = query!( "SELECT name, symbol, description, price FROM stocks WHERE symbol = $1", stock_symbol.inner() ) .fetch_one(pool) .await?; Ok(ServerResponse::StockInfo(Stock::new( stock.name, stock.symbol, stock.description, stock.price, ))) } pub(crate) async fn user_info(username: Username, pool: &PgPool) -> OperationResult { let user = query!( "SELECT balance, debt FROM users WHERE username = $1", username.inner() ) .fetch_optional(pool) .await?; match user { Some(user) => Ok(ServerResponse::UserInfo(User::new( username, UserBalance::new(user.balance), UserDebt::new(user.debt), ))), None => Ok(ServerResponse::UserError(UserError::InvalidUsername)), } } } /// User operation implementation impl AppState { pub(crate) async fn login(&mut self, api_key: ApiKey, pool: &PgPool) -> OperationResult { self.assert_state(AppState::Unauthorized)?; let user_id = { let query = query!( "SELECT (user_id) FROM user_api_keys JOIN api_keys ON user_api_keys.key_id = api_keys.key_id WHERE key = $1", api_key.inner() ) .fetch_optional(pool) .await?; match query { Some(id) => id.user_id, None => return Ok(ServerResponse::UserError(UserError::InvalidApiKey)), } }; *self = AppState::Authorized { user_id }; Ok(ServerResponse::Success) } pub(crate) async fn register( &mut self, username: Username, password: Password, pool: &PgPool, ) -> OperationResult { self.assert_state(AppState::Unauthorized)?; let salted_password = SaltedPassword::try_from(password).map_err(|_| StateError::PasswordHash)?; query!( "INSERT INTO users (username, pwhash_data) VALUES ($1, $2)", username.inner(), salted_password.as_ref(), ) .execute(pool) .await?; Ok(ServerResponse::Success) } pub(crate) async fn generate_api_key( &self, username: Username, password: Password, pool: &PgPool, ) -> OperationResult { self.assert_state(AppState::Unauthorized)?; match self.validate_password(&username, &password, pool).await { Ok(None) => Ok(ServerResponse::UserError(UserError::InvalidPassword)), Ok(Some(user_id)) => self.create_api_key(user_id, pool).await, Err(_) => todo!(), } } async fn create_api_key(&self, user_id: i32, pool: &PgPool) -> OperationResult { let mut transaction = pool.begin().await?; let api_key = uuid::Uuid::new_v4(); let api_key_id = query!( "INSERT INTO api_keys (key) VALUES ($1) RETURNING api_keys.key_id;", &api_key, ) .fetch_one(&mut transaction) .await? .key_id; query!( "INSERT INTO user_api_keys (user_id, key_id) VALUES ($1, $2)", user_id, api_key_id ) .execute(&mut transaction) .await?; transaction.commit().await?; Ok(ServerResponse::NewApiKey(api_key)) } async fn validate_password( &self, username: &Username, password: &Password, pool: &PgPool, ) -> Result, StateError> { let result = query!( "SELECT user_id, pwhash_data FROM users WHERE username = $1", username.inner() ) .fetch_one(pool) .await?; let login_attempt = HashedPassword::from_slice(&result.pwhash_data) .map(|pw| pwhash_verify(&pw, password.as_bytes())) .unwrap(); if login_attempt { Ok(Some(result.user_id)) } else { Ok(None) } } } /// Market operation implementation impl AppState { pub(crate) async fn buy( &self, symbol: StockSymbol, amount: usize, pool: &PgPool, market: &mut Arc, ) -> OperationResult { let id = *match self { Self::Authorized { user_id } => user_id, _ => return Err(StateError::WrongState), }; let stock_id = Self::get_stock_id(&symbol, pool).await?; let user_balance = Self::get_user_balance(id, pool).await?; let mut market = market.lock().await; let current_price = { let response = Self::get_stock_price(symbol.clone(), pool).await; match response { Ok(v) => v, Err(e) => return e.into_operation_result(), } }; let cost_to_purchase = market.peek_price(¤t_price, amount); if user_balance >= cost_to_purchase { let resp = self .purchase_stock(stock_id, amount, cost_to_purchase, pool) .await; if resp.is_ok() { info!( "User {} will be purchasing {} {} securities for {}", id, amount, symbol, cost_to_purchase ); market.commit(); } resp } else { todo!() } } pub(crate) async fn sell( &self, symbol: StockSymbol, amount: usize, pool: &PgPool, market: &mut Arc, ) -> OperationResult { let id = *match self { Self::Authorized { user_id } => user_id, _ => return Err(StateError::WrongState), }; let stock_id = Self::get_stock_id(&symbol, pool).await?; let num_owned = Self::get_owned_stock_count(id, stock_id, pool).await?; if num_owned < amount { return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock( num_owned, ))); } let mut market = market.lock().await; let current_price = { let response = Self::get_stock_price(symbol.clone(), pool).await; match response { Ok(v) => v, Err(e) => return e.into_operation_result(), } }; let gains = market.peek_price(¤t_price, amount); let response = self .sell_stock(stock_id, gains, amount, num_owned == amount, pool) .await; market.commit(); response } // todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus // server error. async fn get_user_balance(user_id: i32, pool: &PgPool) -> Result { query!("SELECT balance FROM users where user_id = $1", user_id) .fetch_one(pool) .await .map(|record| record.balance) .map_err(StateError::from) } async fn get_owned_stock_count( user_id: i32, stock_id: i32, pool: &PgPool, ) -> Result { query!( "SELECT amount FROM users_stocks WHERE user_id = $1 AND stock_id = $2", user_id, stock_id ) .fetch_one(pool) .await .map(|record| record.amount as usize) .map_err(StateError::from) } async fn get_stock_price(symbol: StockSymbol, pool: &PgPool) -> Result { let query = query!("SELECT price FROM stocks WHERE symbol = $1", symbol.inner()) .fetch_optional(pool) .await; match query { Ok(Some(record)) => Ok(record.price), Ok(None) => Err(UserOrDbError::User(UserError::InvalidStock(symbol))), Err(e) => Err(UserOrDbError::Database(e)), } } async fn purchase_stock( &self, symbol_id: i32, amount: usize, cost: Decimal, pool: &PgPool, ) -> OperationResult { let id = *match self { Self::Authorized { user_id } => user_id, _ => return Err(StateError::WrongState), }; debug!( "Starting transaction for id {}: {} {} at total cost {}", id, amount, symbol_id, cost ); let mut transaction = pool.begin().await?; query!( "UPDATE users SET balance = balance - $1 WHERE user_id = $2", cost, id ) .execute(&mut transaction) .await?; query!( "INSERT INTO users_stocks VALUES ($1, $2, $3) ON CONFLICT ON CONSTRAINT users_stocks_user_id_stock_id_key DO UPDATE SET amount = users_stocks.amount + $3::integer WHERE users_stocks.user_id = $1", id, symbol_id, amount as u32 ) .execute(&mut transaction) .await?; transaction.commit().await?; Ok(ServerResponse::Success) } async fn get_stock_id(symbol: &StockSymbol, pool: &PgPool) -> Result { query!( "SELECT stock_id FROM stocks WHERE symbol = $1", symbol.inner() ) .fetch_one(pool) .await .map(|record| record.stock_id) .map_err(StateError::from) } async fn sell_stock( &self, stock_id: i32, gains: Decimal, amount_sold: usize, sell_all: bool, pool: &PgPool, ) -> OperationResult { let id = *match self { Self::Authorized { user_id } => user_id, _ => return Err(StateError::WrongState), }; let mut transaction = pool.begin().await?; query!( "UPDATE users SET balance = balance + $1::numeric WHERE user_id = $2", gains, id ) .execute(&mut transaction) .await?; if sell_all { query!( "DELETE FROM users_stocks WHERE user_id = $1 AND stock_id = $2", id, stock_id ) .execute(&mut transaction) .await?; } else { query!( "UPDATE users_stocks SET amount = amount - $1::integer WHERE user_id = $2 AND stock_id = $3", amount_sold as u32, id, stock_id, ) .execute(&mut transaction) .await?; } transaction.commit().await?; Ok(ServerResponse::Success) } } impl Default for AppState { fn default() -> Self { Self::new() } }