diff --git a/Cargo.lock b/Cargo.lock index b107deb..f14c1c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1376,6 +1376,7 @@ dependencies = [ "dotenv", "log", "rand 0.8.3", + "rust_decimal", "serde", "serde_json", "simple_logger", diff --git a/vtse-common/src/net.rs b/vtse-common/src/net.rs index 59495ef..4422b5e 100644 --- a/vtse-common/src/net.rs +++ b/vtse-common/src/net.rs @@ -1,4 +1,4 @@ -use crate::stock::Stock; +use crate::{stock::Stock, user::User}; use serde::Serialize; use uuid::Uuid; @@ -10,11 +10,14 @@ pub enum ServerResponse { NewApiKey(Uuid), UserError(UserError), StockInfo(Stock), + UserInfo(User), } #[derive(Serialize)] pub enum UserError { + InvalidUsername, InvalidPassword, InvalidApiKey, NotAuthorized, + NotEnoughOwnedStock(usize), } diff --git a/vtse-common/src/operations.rs b/vtse-common/src/operations.rs index 887b7e1..d0bb05d 100644 --- a/vtse-common/src/operations.rs +++ b/vtse-common/src/operations.rs @@ -1,5 +1,3 @@ -use crate::stock::StockName; - pub(crate) enum MarketOperation { Buy, Sell, diff --git a/vtse-common/src/stock.rs b/vtse-common/src/stock.rs index dbb1af8..2f0abd2 100644 --- a/vtse-common/src/stock.rs +++ b/vtse-common/src/stock.rs @@ -3,11 +3,8 @@ use serde::{Deserialize, Serialize}; use std::str::FromStr; use thiserror::Error; -#[derive( - Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, sqlx::Type, -)] +#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)] #[serde(transparent)] -#[sqlx(transparent)] pub struct StockName(String); impl From for StockName { @@ -37,6 +34,20 @@ impl From for StockSymbol { } } +impl FromStr for StockSymbol { + type Err = StockNameParseError; + + fn from_str(s: &str) -> Result { + Ok(Self(s.to_string())) + } +} + +impl StockSymbol { + pub fn inner(&self) -> &str { + &self.0 + } +} + #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)] #[serde(transparent)] pub struct StockDescription(String); diff --git a/vtse-common/src/user.rs b/vtse-common/src/user.rs index 139c178..c3f1c3a 100644 --- a/vtse-common/src/user.rs +++ b/vtse-common/src/user.rs @@ -1,4 +1,6 @@ +use rust_decimal::Decimal; use serde::Deserialize; +use serde::Serialize; use uuid::Uuid; #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)] @@ -10,7 +12,8 @@ impl ApiKey { } } -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)] +#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, Serialize)] +#[serde(transparent)] pub struct Username(String); impl Username { @@ -20,6 +23,7 @@ impl Username { } #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)] +#[serde(transparent)] pub struct Password(String); impl Password { @@ -27,3 +31,40 @@ impl Password { self.0.as_bytes() } } + +#[derive(Serialize)] +#[serde(transparent)] +pub struct UserBalance(Decimal); + +impl UserBalance { + pub fn new(decimal: Decimal) -> Self { + Self(decimal) + } +} + +#[derive(Serialize)] +#[serde(transparent)] +pub struct UserDebt(Decimal); + +impl UserDebt { + pub fn new(decimal: Decimal) -> Self { + Self(decimal) + } +} + +#[derive(Serialize)] +pub struct User { + username: Username, + balance: UserBalance, + debt: UserDebt, +} + +impl User { + pub fn new(username: Username, balance: UserBalance, debt: UserDebt) -> Self { + Self { + username, + balance, + debt, + } + } +} diff --git a/vtse-server/Cargo.toml b/vtse-server/Cargo.toml index 1897d00..5db4169 100644 --- a/vtse-server/Cargo.toml +++ b/vtse-server/Cargo.toml @@ -10,6 +10,7 @@ bytes = "1" dotenv = "0.15" log = "0.4" rand = "0.8" +rust_decimal = "1.10.2" serde = { version = "1", features = ["derive"] } serde_json = "1" simple_logger = "1.11" diff --git a/vtse-server/src/main.rs b/vtse-server/src/main.rs index 46a81d2..b3c05bb 100644 --- a/vtse-server/src/main.rs +++ b/vtse-server/src/main.rs @@ -36,12 +36,8 @@ async fn main() -> Result<()> { // clone simply clones the arc reference, so this is cheap. let db_pool = db_pool.clone(); tokio::task::spawn(async { - match handle_stream(stream, db_pool).await { - Ok(_) => (), - Err(e) => { - // stream.write_all(e); - error!("{}", e); - } + if let Err(e) = handle_stream(stream, db_pool).await { + error!("{}", e); } }); } @@ -65,10 +61,11 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> { let data = buffer.split_to(bytes_read); // O(1) let parsed = serde_json::from_slice::(&data)?; debug!("Parsed operation: {:?}", parsed); + let response = match parsed { ServerOperation::Query(op) => match op { QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?, - QueryOperation::User(username) => state.user_info(username, &pool).await?, + QueryOperation::User { username } => state.user_info(username, &pool).await?, }, ServerOperation::User(op) => match op { UserOperation::Login { api_key } => state.login(api_key, &pool).await?, @@ -80,8 +77,10 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> { } }, ServerOperation::Market(op) => match op { - MarketOperation::Buy(stock_name) => state.buy(stock_name, &pool)?, - MarketOperation::Sell(stock_name) => state.sell(stock_name, &pool)?, + MarketOperation::Buy { symbol, amount } => state.buy(symbol, amount, &pool).await?, + MarketOperation::Sell { symbol, amount } => { + state.sell(symbol, amount, &pool).await? + } }, }; diff --git a/vtse-server/src/operations.rs b/vtse-server/src/operations.rs index d80c52e..b7d119c 100644 --- a/vtse-server/src/operations.rs +++ b/vtse-server/src/operations.rs @@ -1,5 +1,5 @@ use serde::Deserialize; -use vtse_common::stock::StockName; +use vtse_common::stock::StockSymbol; use vtse_common::user::{ApiKey, Password, Username}; #[derive(Deserialize, Debug, PartialEq)] @@ -13,8 +13,8 @@ pub(crate) enum ServerOperation { #[derive(Deserialize, Debug, PartialEq)] #[serde(tag = "type", rename_all = "snake_case")] pub(crate) enum QueryOperation { - StockInfo { stock: StockName }, - User(Username), + StockInfo { stock: StockSymbol }, + User { username: Username }, } #[derive(Deserialize, Debug, PartialEq)] @@ -35,8 +35,8 @@ pub(crate) enum UserOperation { #[derive(Deserialize, Debug, PartialEq)] pub(crate) enum MarketOperation { - Buy(StockName), - Sell(StockName), + Buy { symbol: StockSymbol, amount: usize }, + Sell { symbol: StockSymbol, amount: usize }, } #[cfg(test)] @@ -54,7 +54,7 @@ mod deserialize { })) .unwrap(), ServerOperation::Query(QueryOperation::StockInfo { - stock: StockName::from_str("Gura").unwrap() + stock: StockSymbol::from_str("Gura").unwrap() }) ) } @@ -68,7 +68,7 @@ mod deserialize { })) .unwrap(), QueryOperation::StockInfo { - stock: StockName::from_str("Gura").unwrap() + stock: StockSymbol::from_str("Gura").unwrap() } ) } diff --git a/vtse-server/src/state.rs b/vtse-server/src/state.rs index bbd59c8..01c877f 100644 --- a/vtse-server/src/state.rs +++ b/vtse-server/src/state.rs @@ -1,11 +1,18 @@ use crate::user::SaltedPassword; +use rust_decimal::Decimal; use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword}; use sqlx::{query, PgPool}; use std::convert::TryFrom; use thiserror::Error; -use vtse_common::net::{ServerResponse, UserError}; -use vtse_common::stock::{Stock, StockName}; -use vtse_common::user::{ApiKey, Password, Username}; +use vtse_common::{ + net::{ServerResponse, UserError}, + user::User, +}; +use vtse_common::{stock::Stock, user::UserBalance}; +use vtse_common::{ + stock::StockSymbol, + user::{ApiKey, Password, UserDebt, Username}, +}; #[derive(Error, Debug)] pub(crate) enum StateError { @@ -42,10 +49,14 @@ impl AppState { /// Query operations impl impl AppState { - pub(crate) async fn stock_info(&self, stock_name: StockName, pool: &PgPool) -> OperationResult { + pub(crate) async fn stock_info( + &self, + stock_symbol: StockSymbol, + pool: &PgPool, + ) -> OperationResult { let stock = query!( - "SELECT name, symbol, description, price FROM stocks WHERE name = $1", - stock_name as StockName + "SELECT name, symbol, description, price FROM stocks WHERE symbol = $1", + stock_symbol.inner() ) .fetch_one(pool) .await?; @@ -59,7 +70,21 @@ impl AppState { } pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult { - todo!() + let user = query!( + "SELECT balance, debt FROM users WHERE username = $1", + username.inner() + ) + .fetch_optional(pool) + .await?; + + match user { + Some(user) => Ok(ServerResponse::UserInfo(User::new( + username, + UserBalance::new(user.balance), + UserDebt::new(user.debt), + ))), + None => Ok(ServerResponse::UserError(UserError::InvalidUsername)), + } } } @@ -175,18 +200,64 @@ impl AppState { /// Market operation implementation impl AppState { - pub(crate) fn buy(&self, stock_name: StockName, pool: &PgPool) -> OperationResult { - if !matches!(self, Self::Authorized{..}) { - return Err(StateError::WrongState); + pub(crate) async fn buy( + &self, + symbol: StockSymbol, + amount: usize, + pool: &PgPool, + ) -> OperationResult { + let id = *match self { + Self::Authorized { user_id } => user_id, + _ => return Err(StateError::WrongState), + }; + let user_balance = Self::get_user_balance(id, pool).await?; + + todo!() + } + + pub(crate) async fn sell( + &self, + symbol: StockSymbol, + amount: usize, + pool: &PgPool, + ) -> OperationResult { + let id = *match self { + Self::Authorized { user_id } => user_id, + _ => return Err(StateError::WrongState), + }; + let num_owned = Self::get_owned_stock_count(id, symbol, pool).await?; + if num_owned < amount { + return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock( + num_owned, + ))); } todo!() } - pub(crate) fn sell(&self, stock_name: StockName, pool: &PgPool) -> OperationResult { - if !matches!(self, Self::Authorized{..}) { - return Err(StateError::WrongState); - } - todo!() + async fn get_user_balance(user_id: i32, pool: &PgPool) -> Result { + query!("SELECT balance FROM users where user_id = $1", user_id) + .fetch_one(pool) + .await + .map(|record| record.balance) + .map_err(StateError::from) + } + + async fn get_owned_stock_count( + user_id: i32, + symbol: StockSymbol, + pool: &PgPool, + ) -> Result { + query!( + "SELECT amount FROM users_stocks + JOIN stocks ON users_stocks.stock_id = stocks.stock_id + WHERE user_id = $1 AND symbol = $2", + user_id, + symbol.inner() + ) + .fetch_one(pool) + .await + .map(|record| record.amount as usize) + .map_err(StateError::from) } } diff --git a/vtse-server/src/user.rs b/vtse-server/src/user.rs index ae9de55..7baef96 100644 --- a/vtse-server/src/user.rs +++ b/vtse-server/src/user.rs @@ -1,7 +1,5 @@ -use sodiumoxide::crypto::pwhash::{ - argon2id13::HashedPassword, - argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE}, -}; +use sodiumoxide::crypto::pwhash::argon2id13::HashedPassword; +use sodiumoxide::crypto::pwhash::argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE}; use std::convert::TryFrom; use vtse_common::user::Password;