diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..8883ac1 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,12 @@ +{ + "cSpell.words": [ + "Gura", + "dotenv", + "icious", + "pwhash", + "sodiumoxide", + "thiserror", + "unsplit", + "vtse" + ] +} \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index f14c1c5..f09ae8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -152,6 +152,38 @@ dependencies = [ "winapi", ] +[[package]] +name = "clap" +version = "3.0.0-beta.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bd1061998a501ee7d4b6d449020df3266ca3124b941ec56cf2005c3779ca142" +dependencies = [ + "atty", + "bitflags", + "clap_derive", + "indexmap", + "lazy_static", + "os_str_bytes", + "strsim", + "termcolor", + "textwrap", + "unicode-width", + "vec_map", +] + +[[package]] +name = "clap_derive" +version = "3.0.0-beta.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "370f715b81112975b1b69db93e0b56ea4cd4e5002ac43b2da8474106a54096a1" +dependencies = [ + "heck", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "colored" version = "1.9.3" @@ -446,6 +478,16 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb1fa934250de4de8aef298d81c729a7d33d8c239daa3a7575e6b92bfc7313b" +dependencies = [ + "autocfg", + "hashbrown", +] + [[package]] name = "instant" version = "0.1.9" @@ -649,6 +691,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +[[package]] +name = "os_str_bytes" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afb2e1c3ee07430c2cf76151675e583e0f19985fa6efae47d6848a3e2c824f85" + [[package]] name = "parking_lot" version = "0.11.1" @@ -704,6 +752,30 @@ version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro-hack" version = "0.5.19" @@ -1146,6 +1218,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.4.0" @@ -1169,6 +1247,24 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36474e732d1affd3a6ed582781b3683df3d0563714c59c39591e8ff707cf078e" +[[package]] +name = "termcolor" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "textwrap" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "203008d98caf094106cfaba70acfed15e18ed3ddb7d94e49baec153a2b462789" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.23" @@ -1293,9 +1389,9 @@ dependencies = [ [[package]] name = "unicode-normalization" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a13e63ab62dbe32aeee58d1c5408d35c36c392bba5d9d3142287219721afe606" +checksum = "07fbfce1c8a97d547e8b5334978438d9d6ec8c20e38f56d4a4374d181493eaef" dependencies = [ "tinyvec", ] @@ -1306,6 +1402,12 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796" +[[package]] +name = "unicode-width" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3" + [[package]] name = "unicode-xid" version = "0.2.1" @@ -1346,6 +1448,12 @@ dependencies = [ "serde", ] +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + [[package]] name = "version_check" version = "0.9.2" @@ -1373,6 +1481,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bytes", + "clap", "dotenv", "log", "rand 0.8.3", @@ -1514,6 +1623,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/vtse-common/src/net.rs b/vtse-common/src/net.rs index 4422b5e..9802a62 100644 --- a/vtse-common/src/net.rs +++ b/vtse-common/src/net.rs @@ -1,5 +1,9 @@ -use crate::{stock::Stock, user::User}; +use crate::{ + stock::{Stock, StockSymbol}, + user::User, +}; use serde::Serialize; +use thiserror::Error; use uuid::Uuid; #[derive(Serialize)] @@ -13,11 +17,18 @@ pub enum ServerResponse { UserInfo(User), } -#[derive(Serialize)] +#[derive(Error, Serialize, Debug)] pub enum UserError { + #[error("An invalid username was provided.")] InvalidUsername, + #[error("An invalid password was provided.")] InvalidPassword, + #[error("An invalid API key was provided.")] InvalidApiKey, + #[error("This requires authorization. Please login first.")] NotAuthorized, + #[error("You don't have enough stock to sell {0} units.")] NotEnoughOwnedStock(usize), + #[error("Stock symbol {0} does not exist.")] + InvalidStock(StockSymbol), } diff --git a/vtse-common/src/stock.rs b/vtse-common/src/stock.rs index 2f0abd2..e5f35a5 100644 --- a/vtse-common/src/stock.rs +++ b/vtse-common/src/stock.rs @@ -1,6 +1,6 @@ use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; -use std::str::FromStr; +use std::{fmt::Display, str::FromStr}; use thiserror::Error; #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)] @@ -28,6 +28,12 @@ impl FromStr for StockName { #[serde(transparent)] pub struct StockSymbol(String); +impl Display for StockSymbol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "${}", self.0) + } +} + impl From for StockSymbol { fn from(s: String) -> Self { Self(s) diff --git a/vtse-server/Cargo.toml b/vtse-server/Cargo.toml index 5db4169..ccb6d1e 100644 --- a/vtse-server/Cargo.toml +++ b/vtse-server/Cargo.toml @@ -7,6 +7,7 @@ edition = "2018" [dependencies] anyhow = "1" bytes = "1" +clap = "3.0.0-beta.2" dotenv = "0.15" log = "0.4" rand = "0.8" diff --git a/vtse-server/src/args.rs b/vtse-server/src/args.rs new file mode 100644 index 0000000..9fee9f5 --- /dev/null +++ b/vtse-server/src/args.rs @@ -0,0 +1,14 @@ +use clap::Clap; +use log::LevelFilter; + +#[derive(Clap)] +pub(crate) struct Args { + #[clap(long, env("DATABASE_URL"))] + pub(crate) database_url: String, + /// What address the vtse server should listen to. + #[clap(long, env("BIND_ADDRESS"), default_value = "localhost:8080")] + pub(crate) bind_address: String, + /// Sets the logging level DB queries should report at. + #[clap(long, default_value = "off")] + pub(crate) db_log_level: LevelFilter, +} diff --git a/vtse-server/src/main.rs b/vtse-server/src/main.rs index b3c05bb..388a197 100644 --- a/vtse-server/src/main.rs +++ b/vtse-server/src/main.rs @@ -1,15 +1,20 @@ use crate::operations::ServerOperation; use anyhow::{bail, Result}; +use args::Args; use bytes::BytesMut; +use clap::Clap; use log::{debug, error, info}; +use market::Market; use operations::{MarketOperation, QueryOperation, UserOperation}; -use sqlx::PgPool; +use simple_logger::SimpleLogger; +use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool}; use state::AppState; -use std::env; +use std::{str::FromStr, sync::Arc}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; +mod args; mod market; mod operations; mod state; @@ -17,26 +22,35 @@ mod user; #[tokio::main] async fn main() -> Result<()> { + // Must init dotenv before initializing args, else args won't default from + // env properly dotenv::dotenv().ok(); - // If we can't successfully initialize our crypto library, fail immediately. - sodiumoxide::init().unwrap(); - simple_logger::SimpleLogger::default().init()?; + let args: Args = Args::parse(); - let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?) + // If we can't successfully initialize our crypto library, fail immediately. + sodiumoxide::init().expect("to initialize crypto library"); + SimpleLogger::default().init()?; + + let mut listener_stream = TcpListener::bind(&args.bind_address) .await .map(TcpListenerStream::new)?; - info!("Successfully bound to port."); - let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?; - + let db_pool = { + let mut connect_options = PgConnectOptions::from_str(&args.database_url)?; + connect_options.log_statements(args.db_log_level); + PgPool::connect_with(connect_options).await? + }; info!("Successfully established connection to database."); + let market = Arc::new(Market::new()); + while let Some(Ok(stream)) = listener_stream.next().await { // clone simply clones the arc reference, so this is cheap. let db_pool = db_pool.clone(); + let market = Arc::clone(&market); tokio::task::spawn(async { - if let Err(e) = handle_stream(stream, db_pool).await { + if let Err(e) = handle_stream(stream, db_pool, market).await { error!("{}", e); } }); @@ -47,7 +61,7 @@ async fn main() -> Result<()> { Ok(()) } -async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> { +async fn handle_stream(mut socket: TcpStream, pool: PgPool, mut market: Arc) -> Result<()> { // only accept data that can fit in 256 bytes // the assumption is that a single request should always be within n bytes, // otherwise there's a good chance that it's mal{formed,icious}. @@ -58,35 +72,52 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> { if bytes_read == 0 { bail!("Failed to read bytes, assuming socket is closed.") } + let data = buffer.split_to(bytes_read); // O(1) - let parsed = serde_json::from_slice::(&data)?; - debug!("Parsed operation: {:?}", parsed); + dbg!(&data); + let iter = serde_json::Deserializer::from_slice(&data).into_iter(); - let response = match parsed { - ServerOperation::Query(op) => match op { - QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?, - QueryOperation::User { username } => state.user_info(username, &pool).await?, - }, - ServerOperation::User(op) => match op { - UserOperation::Login { api_key } => state.login(api_key, &pool).await?, - UserOperation::Register { username, password } => { - state.register(username, password, &pool).await? - } - UserOperation::GetKey { username, password } => { - state.generate_api_key(username, password, &pool).await? - } - }, - ServerOperation::Market(op) => match op { - MarketOperation::Buy { symbol, amount } => state.buy(symbol, amount, &pool).await?, - MarketOperation::Sell { symbol, amount } => { - state.sell(symbol, amount, &pool).await? - } - }, - }; + for messages in iter { + let message = match messages { + Ok(p) => p, + Err(e) if e.is_eof() || e.is_io() => return Ok(()), + Err(e) => return Err(e.into()), + }; + + debug!("Parsed operation: {:?}", message); + + let response = match message { + ServerOperation::Query(op) => match op { + QueryOperation::StockInfo { stock } => { + AppState::stock_info(stock, &pool).await? + } + QueryOperation::User { username } => { + AppState::user_info(username, &pool).await? + } + }, + ServerOperation::User(op) => match op { + UserOperation::Login { api_key } => state.login(api_key, &pool).await?, + UserOperation::Register { username, password } => { + state.register(username, password, &pool).await? + } + UserOperation::GetKey { username, password } => { + state.generate_api_key(username, password, &pool).await? + } + }, + ServerOperation::Market(op) => match op { + MarketOperation::Buy { symbol, amount } => { + state.buy(symbol, amount, &pool, &mut market).await? + } + MarketOperation::Sell { symbol, amount } => { + state.sell(symbol, amount, &pool, &mut market).await? + } + }, + }; + socket + .write_all(serde_json::to_string(&response).unwrap().as_bytes()) + .await?; + } - socket - .write_all(serde_json::to_string(&response).unwrap().as_bytes()) - .await?; buffer.unsplit(data); // O(1) buffer.clear(); } diff --git a/vtse-server/src/market/generator.rs b/vtse-server/src/market/generator.rs index 1258d23..7195b8e 100644 --- a/vtse-server/src/market/generator.rs +++ b/vtse-server/src/market/generator.rs @@ -1,37 +1,66 @@ -use rand::rngs::ThreadRng; -use rand::thread_rng; use rand::{distributions::Uniform, prelude::Distribution}; +use rand::{rngs::StdRng, SeedableRng}; +use rust_decimal::Decimal; #[derive(Copy, Clone, PartialEq, PartialOrd, Debug)] -pub(crate) struct MarketModifier(f64); +pub(super) struct MarketModifier(Decimal); -#[derive(Clone, Debug)] -pub(crate) struct MarketGenerator { - rng: ThreadRng, - sample_range: Uniform, - volatility: f64, - old_price: MarketModifier, +impl MarketModifier { + pub(super) fn modifier(&self) -> Decimal { + self.0 + } } -impl MarketGenerator { - pub(crate) fn new(volatility: f64, initial_value: f64) -> Self { +#[derive(Clone, Debug)] +pub(super) struct MarketMultiplier { + rng: StdRng, + sample_range: Uniform, + volatility: Decimal, + old_price: MarketModifier, + buffer: Vec, +} + +impl MarketMultiplier { + pub(super) fn new(volatility: Decimal, initial_value: Decimal) -> Self { Self { - rng: thread_rng(), - sample_range: Uniform::new_inclusive(-0.5, 0.5), + rng: StdRng::from_entropy(), + sample_range: Uniform::new_inclusive(-50, 50), volatility, old_price: MarketModifier(initial_value), + buffer: vec![], } } } -impl Iterator for MarketGenerator { +impl Iterator for MarketMultiplier { type Item = MarketModifier; fn next(&mut self) -> Option { // Algorithm from https://stackoverflow.com/a/8597889 let rng_multiplier = self.sample_range.sample(&mut self.rng); - let delta = 2f64 * self.volatility * rng_multiplier; + let delta = Decimal::new(2, 0) * self.volatility * Decimal::new(rng_multiplier, 2); self.old_price.0 += self.old_price.0 * delta; Some(self.old_price) } } + +impl MarketMultiplier { + pub(super) fn next_n(&mut self, n: usize) -> Vec { + self.ensure_buffer_capacity(n); + let mut ret = self.buffer.split_off(n); + std::mem::swap(&mut ret, &mut self.buffer); + ret + } + + pub(super) fn peek_n(&mut self, n: usize) -> &[MarketModifier] { + self.ensure_buffer_capacity(n); + &self.buffer[..n] + } + + fn ensure_buffer_capacity(&mut self, n: usize) { + for _ in 0..(n - self.buffer.len()) { + let next_val = self.next().unwrap(); + self.buffer.push(next_val); + } + } +} diff --git a/vtse-server/src/market/mod.rs b/vtse-server/src/market/mod.rs index 5bade7e..5665041 100644 --- a/vtse-server/src/market/mod.rs +++ b/vtse-server/src/market/mod.rs @@ -1 +1,41 @@ +use generator::MarketMultiplier; +use rust_decimal::Decimal; +use tokio::sync::{Mutex, MutexGuard}; + mod generator; + +pub(crate) struct Market { + generator: Mutex, +} + +impl Market { + pub(crate) fn new() -> Self { + Self { + generator: Mutex::new(MarketMultiplier::new( + Decimal::new(1, 1), + Decimal::new(1, 0), + )), + } + } + + pub(crate) async fn lock(&self) -> LockedMarket<'_> { + LockedMarket(self.generator.lock().await, 0) + } +} + +pub(crate) struct LockedMarket<'a>(MutexGuard<'a, MarketMultiplier>, usize); + +impl<'a> LockedMarket<'a> { + pub(crate) fn peek_price(&mut self, initial_price: &Decimal, n: usize) -> Decimal { + let mut price = *initial_price; + for i in self.0.peek_n(n).iter() { + price *= i.modifier(); + } + self.1 = n; + price.round_dp(2) + } + + pub(crate) fn commit(mut self) { + self.0.next_n(self.1); + } +} diff --git a/vtse-server/src/operations.rs b/vtse-server/src/operations.rs index b7d119c..4f54d0c 100644 --- a/vtse-server/src/operations.rs +++ b/vtse-server/src/operations.rs @@ -34,6 +34,7 @@ pub(crate) enum UserOperation { } #[derive(Deserialize, Debug, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] pub(crate) enum MarketOperation { Buy { symbol: StockSymbol, amount: usize }, Sell { symbol: StockSymbol, amount: usize }, diff --git a/vtse-server/src/state.rs b/vtse-server/src/state.rs index ec43d26..e1ea049 100644 --- a/vtse-server/src/state.rs +++ b/vtse-server/src/state.rs @@ -1,8 +1,9 @@ -use crate::user::SaltedPassword; +use crate::{market::Market, user::SaltedPassword}; +use log::{debug, info}; use rust_decimal::Decimal; use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword}; use sqlx::{query, PgPool}; -use std::convert::TryFrom; +use std::{convert::TryFrom, sync::Arc}; use thiserror::Error; use vtse_common::net::{ServerResponse, UserError}; use vtse_common::stock::{Stock, StockSymbol}; @@ -12,12 +13,29 @@ use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username} pub(crate) enum StateError { #[error("Was not in the correct state")] WrongState, - #[error("Got SQLx error: {0}`")] + #[error("Got SQLx error: {0}")] Database(#[from] sqlx::Error), #[error("Failed to hash password")] PasswordHash, } +#[derive(Error, Debug)] +pub(crate) enum UserOrDbError { + #[error("Got SQLx error: {0}")] + Database(#[from] sqlx::Error), + #[error("Got user error: {0}")] + User(UserError), +} + +impl UserOrDbError { + fn into_operation_result(self) -> OperationResult { + match self { + UserOrDbError::Database(e) => Err(StateError::Database(e)), + UserOrDbError::User(e) => Ok(ServerResponse::UserError(e)), + } + } +} + type OperationResult = Result; #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)] @@ -43,11 +61,7 @@ impl AppState { /// Query operations impl impl AppState { - pub(crate) async fn stock_info( - &self, - stock_symbol: StockSymbol, - pool: &PgPool, - ) -> OperationResult { + pub(crate) async fn stock_info(stock_symbol: StockSymbol, pool: &PgPool) -> OperationResult { let stock = query!( "SELECT name, symbol, description, price FROM stocks WHERE symbol = $1", stock_symbol.inner() @@ -63,7 +77,7 @@ impl AppState { ))) } - pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult { + pub(crate) async fn user_info(username: Username, pool: &PgPool) -> OperationResult { let user = query!( "SELECT balance, debt FROM users WHERE username = $1", username.inner() @@ -114,9 +128,7 @@ impl AppState { let salted_password = SaltedPassword::try_from(password).map_err(|_| StateError::PasswordHash)?; query!( - "INSERT INTO users - (username, pwhash_data) - VALUES ($1, $2)", + "INSERT INTO users (username, pwhash_data) VALUES ($1, $2)", username.inner(), salted_password.as_ref(), ) @@ -143,9 +155,7 @@ impl AppState { let mut transaction = pool.begin().await?; let api_key = uuid::Uuid::new_v4(); let api_key_id = query!( - "INSERT INTO api_keys (key) - VALUES ($1) - RETURNING api_keys.key_id;", + "INSERT INTO api_keys (key) VALUES ($1) RETURNING api_keys.key_id;", &api_key, ) .fetch_one(&mut transaction) @@ -153,8 +163,7 @@ impl AppState { .key_id; query!( - "INSERT INTO user_api_keys (user_id, key_id) - VALUES ($1, $2)", + "INSERT INTO user_api_keys (user_id, key_id) VALUES ($1, $2)", user_id, api_key_id ) @@ -173,8 +182,7 @@ impl AppState { pool: &PgPool, ) -> Result, StateError> { let result = query!( - "SELECT user_id, pwhash_data FROM users - WHERE username = $1", + "SELECT user_id, pwhash_data FROM users WHERE username = $1", username.inner() ) .fetch_one(pool) @@ -199,14 +207,39 @@ impl AppState { symbol: StockSymbol, amount: usize, pool: &PgPool, + market: &mut Arc, ) -> OperationResult { let id = *match self { Self::Authorized { user_id } => user_id, _ => return Err(StateError::WrongState), }; + let stock_id = Self::get_stock_id(&symbol, pool).await?; let user_balance = Self::get_user_balance(id, pool).await?; + let mut market = market.lock().await; + let current_price = { + let response = Self::get_stock_price(symbol.clone(), pool).await; + match response { + Ok(v) => v, + Err(e) => return e.into_operation_result(), + } + }; + let cost_to_purchase = market.peek_price(¤t_price, amount); - todo!() + if user_balance >= cost_to_purchase { + let resp = self + .purchase_stock(stock_id, amount, cost_to_purchase, pool) + .await; + if resp.is_ok() { + info!( + "User {} will be purchasing {} {} securities for {}", + id, amount, symbol, cost_to_purchase + ); + market.commit(); + } + resp + } else { + todo!() + } } pub(crate) async fn sell( @@ -214,18 +247,33 @@ impl AppState { symbol: StockSymbol, amount: usize, pool: &PgPool, + market: &mut Arc, ) -> OperationResult { let id = *match self { Self::Authorized { user_id } => user_id, _ => return Err(StateError::WrongState), }; - let num_owned = Self::get_owned_stock_count(id, symbol, pool).await?; + let stock_id = Self::get_stock_id(&symbol, pool).await?; + let num_owned = Self::get_owned_stock_count(id, stock_id, pool).await?; if num_owned < amount { return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock( num_owned, ))); } - todo!() + let mut market = market.lock().await; + let current_price = { + let response = Self::get_stock_price(symbol.clone(), pool).await; + match response { + Ok(v) => v, + Err(e) => return e.into_operation_result(), + } + }; + let gains = market.peek_price(¤t_price, amount); + let response = self + .sell_stock(stock_id, gains, amount, num_owned == amount, pool) + .await; + market.commit(); + response } // todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus @@ -241,21 +289,130 @@ impl AppState { async fn get_owned_stock_count( user_id: i32, - symbol: StockSymbol, + stock_id: i32, pool: &PgPool, ) -> Result { query!( - "SELECT amount FROM users_stocks - JOIN stocks ON users_stocks.stock_id = stocks.stock_id - WHERE user_id = $1 AND symbol = $2", + "SELECT amount FROM users_stocks WHERE user_id = $1 AND stock_id = $2", user_id, - symbol.inner() + stock_id ) .fetch_one(pool) .await .map(|record| record.amount as usize) .map_err(StateError::from) } + + async fn get_stock_price(symbol: StockSymbol, pool: &PgPool) -> Result { + let query = query!("SELECT price FROM stocks WHERE symbol = $1", symbol.inner()) + .fetch_optional(pool) + .await; + match query { + Ok(Some(record)) => Ok(record.price), + Ok(None) => Err(UserOrDbError::User(UserError::InvalidStock(symbol))), + Err(e) => Err(UserOrDbError::Database(e)), + } + } + + async fn purchase_stock( + &self, + symbol_id: i32, + amount: usize, + cost: Decimal, + pool: &PgPool, + ) -> OperationResult { + let id = *match self { + Self::Authorized { user_id } => user_id, + _ => return Err(StateError::WrongState), + }; + + debug!( + "Starting transaction for id {}: {} {} at total cost {}", + id, amount, symbol_id, cost + ); + + let mut transaction = pool.begin().await?; + query!( + "UPDATE users SET balance = balance - $1 WHERE user_id = $2", + cost, + id + ) + .execute(&mut transaction) + .await?; + + query!( + "INSERT INTO users_stocks VALUES ($1, $2, $3) + ON CONFLICT ON CONSTRAINT users_stocks_user_id_stock_id_key DO + UPDATE SET amount = users_stocks.amount + $3::integer + WHERE users_stocks.user_id = $1", + id, + symbol_id, + amount as u32 + ) + .execute(&mut transaction) + .await?; + + transaction.commit().await?; + + Ok(ServerResponse::Success) + } + + async fn get_stock_id(symbol: &StockSymbol, pool: &PgPool) -> Result { + query!( + "SELECT stock_id FROM stocks WHERE symbol = $1", + symbol.inner() + ) + .fetch_one(pool) + .await + .map(|record| record.stock_id) + .map_err(StateError::from) + } + + async fn sell_stock( + &self, + stock_id: i32, + gains: Decimal, + amount_sold: usize, + sell_all: bool, + pool: &PgPool, + ) -> OperationResult { + let id = *match self { + Self::Authorized { user_id } => user_id, + _ => return Err(StateError::WrongState), + }; + let mut transaction = pool.begin().await?; + + query!( + "UPDATE users SET balance = balance + $1::numeric WHERE user_id = $2", + gains, + id + ) + .execute(&mut transaction) + .await?; + + if sell_all { + query!( + "DELETE FROM users_stocks WHERE user_id = $1 AND stock_id = $2", + id, + stock_id + ) + .execute(&mut transaction) + .await?; + } else { + query!( + "UPDATE users_stocks SET amount = amount - $1::integer + WHERE user_id = $2 AND stock_id = $3", + amount_sold as u32, + id, + stock_id, + ) + .execute(&mut transaction) + .await?; + } + + transaction.commit().await?; + Ok(ServerResponse::Success) + } } impl Default for AppState {