implement partial buy and sell

This commit is contained in:
Edward Shen 2021-02-11 13:55:28 -05:00
parent e6ff1576f1
commit 07bb068a2b
Signed by: edward
GPG key ID: 19182661E818369F
11 changed files with 504 additions and 84 deletions

12
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,12 @@
{
"cSpell.words": [
"Gura",
"dotenv",
"icious",
"pwhash",
"sodiumoxide",
"thiserror",
"unsplit",
"vtse"
]
}

122
Cargo.lock generated
View file

@ -152,6 +152,38 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "clap"
version = "3.0.0-beta.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bd1061998a501ee7d4b6d449020df3266ca3124b941ec56cf2005c3779ca142"
dependencies = [
"atty",
"bitflags",
"clap_derive",
"indexmap",
"lazy_static",
"os_str_bytes",
"strsim",
"termcolor",
"textwrap",
"unicode-width",
"vec_map",
]
[[package]]
name = "clap_derive"
version = "3.0.0-beta.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "370f715b81112975b1b69db93e0b56ea4cd4e5002ac43b2da8474106a54096a1"
dependencies = [
"heck",
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "colored" name = "colored"
version = "1.9.3" version = "1.9.3"
@ -446,6 +478,16 @@ dependencies = [
"unicode-normalization", "unicode-normalization",
] ]
[[package]]
name = "indexmap"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fb1fa934250de4de8aef298d81c729a7d33d8c239daa3a7575e6b92bfc7313b"
dependencies = [
"autocfg",
"hashbrown",
]
[[package]] [[package]]
name = "instant" name = "instant"
version = "0.1.9" version = "0.1.9"
@ -649,6 +691,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "os_str_bytes"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afb2e1c3ee07430c2cf76151675e583e0f19985fa6efae47d6848a3e2c824f85"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.11.1" version = "0.11.1"
@ -704,6 +752,30 @@ version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]] [[package]]
name = "proc-macro-hack" name = "proc-macro-hack"
version = "0.5.19" version = "0.5.19"
@ -1146,6 +1218,12 @@ dependencies = [
"unicode-normalization", "unicode-normalization",
] ]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.4.0" version = "2.4.0"
@ -1169,6 +1247,24 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36474e732d1affd3a6ed582781b3683df3d0563714c59c39591e8ff707cf078e" checksum = "36474e732d1affd3a6ed582781b3683df3d0563714c59c39591e8ff707cf078e"
[[package]]
name = "termcolor"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4"
dependencies = [
"winapi-util",
]
[[package]]
name = "textwrap"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "203008d98caf094106cfaba70acfed15e18ed3ddb7d94e49baec153a2b462789"
dependencies = [
"unicode-width",
]
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.23" version = "1.0.23"
@ -1293,9 +1389,9 @@ dependencies = [
[[package]] [[package]]
name = "unicode-normalization" name = "unicode-normalization"
version = "0.1.16" version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a13e63ab62dbe32aeee58d1c5408d35c36c392bba5d9d3142287219721afe606" checksum = "07fbfce1c8a97d547e8b5334978438d9d6ec8c20e38f56d4a4374d181493eaef"
dependencies = [ dependencies = [
"tinyvec", "tinyvec",
] ]
@ -1306,6 +1402,12 @@ version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796" checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796"
[[package]]
name = "unicode-width"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3"
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
version = "0.2.1" version = "0.2.1"
@ -1346,6 +1448,12 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "vec_map"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.2" version = "0.9.2"
@ -1373,6 +1481,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes", "bytes",
"clap",
"dotenv", "dotenv",
"log", "log",
"rand 0.8.3", "rand 0.8.3",
@ -1514,6 +1623,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
dependencies = [
"winapi",
]
[[package]] [[package]]
name = "winapi-x86_64-pc-windows-gnu" name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0" version = "0.4.0"

View file

@ -1,5 +1,9 @@
use crate::{stock::Stock, user::User}; use crate::{
stock::{Stock, StockSymbol},
user::User,
};
use serde::Serialize; use serde::Serialize;
use thiserror::Error;
use uuid::Uuid; use uuid::Uuid;
#[derive(Serialize)] #[derive(Serialize)]
@ -13,11 +17,18 @@ pub enum ServerResponse {
UserInfo(User), UserInfo(User),
} }
#[derive(Serialize)] #[derive(Error, Serialize, Debug)]
pub enum UserError { pub enum UserError {
#[error("An invalid username was provided.")]
InvalidUsername, InvalidUsername,
#[error("An invalid password was provided.")]
InvalidPassword, InvalidPassword,
#[error("An invalid API key was provided.")]
InvalidApiKey, InvalidApiKey,
#[error("This requires authorization. Please login first.")]
NotAuthorized, NotAuthorized,
#[error("You don't have enough stock to sell {0} units.")]
NotEnoughOwnedStock(usize), NotEnoughOwnedStock(usize),
#[error("Stock symbol {0} does not exist.")]
InvalidStock(StockSymbol),
} }

View file

@ -1,6 +1,6 @@
use rust_decimal::Decimal; use rust_decimal::Decimal;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::str::FromStr; use std::{fmt::Display, str::FromStr};
use thiserror::Error; use thiserror::Error;
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)] #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
@ -28,6 +28,12 @@ impl FromStr for StockName {
#[serde(transparent)] #[serde(transparent)]
pub struct StockSymbol(String); pub struct StockSymbol(String);
impl Display for StockSymbol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "${}", self.0)
}
}
impl From<String> for StockSymbol { impl From<String> for StockSymbol {
fn from(s: String) -> Self { fn from(s: String) -> Self {
Self(s) Self(s)

View file

@ -7,6 +7,7 @@ edition = "2018"
[dependencies] [dependencies]
anyhow = "1" anyhow = "1"
bytes = "1" bytes = "1"
clap = "3.0.0-beta.2"
dotenv = "0.15" dotenv = "0.15"
log = "0.4" log = "0.4"
rand = "0.8" rand = "0.8"

14
vtse-server/src/args.rs Normal file
View file

@ -0,0 +1,14 @@
use clap::Clap;
use log::LevelFilter;
#[derive(Clap)]
pub(crate) struct Args {
#[clap(long, env("DATABASE_URL"))]
pub(crate) database_url: String,
/// What address the vtse server should listen to.
#[clap(long, env("BIND_ADDRESS"), default_value = "localhost:8080")]
pub(crate) bind_address: String,
/// Sets the logging level DB queries should report at.
#[clap(long, default_value = "off")]
pub(crate) db_log_level: LevelFilter,
}

View file

@ -1,15 +1,20 @@
use crate::operations::ServerOperation; use crate::operations::ServerOperation;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use args::Args;
use bytes::BytesMut; use bytes::BytesMut;
use clap::Clap;
use log::{debug, error, info}; use log::{debug, error, info};
use market::Market;
use operations::{MarketOperation, QueryOperation, UserOperation}; use operations::{MarketOperation, QueryOperation, UserOperation};
use sqlx::PgPool; use simple_logger::SimpleLogger;
use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool};
use state::AppState; use state::AppState;
use std::env; use std::{str::FromStr, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
mod args;
mod market; mod market;
mod operations; mod operations;
mod state; mod state;
@ -17,26 +22,35 @@ mod user;
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
// Must init dotenv before initializing args, else args won't default from
// env properly
dotenv::dotenv().ok(); dotenv::dotenv().ok();
// If we can't successfully initialize our crypto library, fail immediately. let args: Args = Args::parse();
sodiumoxide::init().unwrap();
simple_logger::SimpleLogger::default().init()?;
let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?) // If we can't successfully initialize our crypto library, fail immediately.
sodiumoxide::init().expect("to initialize crypto library");
SimpleLogger::default().init()?;
let mut listener_stream = TcpListener::bind(&args.bind_address)
.await .await
.map(TcpListenerStream::new)?; .map(TcpListenerStream::new)?;
info!("Successfully bound to port."); info!("Successfully bound to port.");
let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?; let db_pool = {
let mut connect_options = PgConnectOptions::from_str(&args.database_url)?;
connect_options.log_statements(args.db_log_level);
PgPool::connect_with(connect_options).await?
};
info!("Successfully established connection to database."); info!("Successfully established connection to database.");
let market = Arc::new(Market::new());
while let Some(Ok(stream)) = listener_stream.next().await { while let Some(Ok(stream)) = listener_stream.next().await {
// clone simply clones the arc reference, so this is cheap. // clone simply clones the arc reference, so this is cheap.
let db_pool = db_pool.clone(); let db_pool = db_pool.clone();
let market = Arc::clone(&market);
tokio::task::spawn(async { tokio::task::spawn(async {
if let Err(e) = handle_stream(stream, db_pool).await { if let Err(e) = handle_stream(stream, db_pool, market).await {
error!("{}", e); error!("{}", e);
} }
}); });
@ -47,7 +61,7 @@ async fn main() -> Result<()> {
Ok(()) Ok(())
} }
async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> { async fn handle_stream(mut socket: TcpStream, pool: PgPool, mut market: Arc<Market>) -> Result<()> {
// only accept data that can fit in 256 bytes // only accept data that can fit in 256 bytes
// the assumption is that a single request should always be within n bytes, // the assumption is that a single request should always be within n bytes,
// otherwise there's a good chance that it's mal{formed,icious}. // otherwise there's a good chance that it's mal{formed,icious}.
@ -58,14 +72,28 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
if bytes_read == 0 { if bytes_read == 0 {
bail!("Failed to read bytes, assuming socket is closed.") bail!("Failed to read bytes, assuming socket is closed.")
} }
let data = buffer.split_to(bytes_read); // O(1)
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
debug!("Parsed operation: {:?}", parsed);
let response = match parsed { let data = buffer.split_to(bytes_read); // O(1)
dbg!(&data);
let iter = serde_json::Deserializer::from_slice(&data).into_iter();
for messages in iter {
let message = match messages {
Ok(p) => p,
Err(e) if e.is_eof() || e.is_io() => return Ok(()),
Err(e) => return Err(e.into()),
};
debug!("Parsed operation: {:?}", message);
let response = match message {
ServerOperation::Query(op) => match op { ServerOperation::Query(op) => match op {
QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?, QueryOperation::StockInfo { stock } => {
QueryOperation::User { username } => state.user_info(username, &pool).await?, AppState::stock_info(stock, &pool).await?
}
QueryOperation::User { username } => {
AppState::user_info(username, &pool).await?
}
}, },
ServerOperation::User(op) => match op { ServerOperation::User(op) => match op {
UserOperation::Login { api_key } => state.login(api_key, &pool).await?, UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
@ -77,16 +105,19 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
} }
}, },
ServerOperation::Market(op) => match op { ServerOperation::Market(op) => match op {
MarketOperation::Buy { symbol, amount } => state.buy(symbol, amount, &pool).await?, MarketOperation::Buy { symbol, amount } => {
state.buy(symbol, amount, &pool, &mut market).await?
}
MarketOperation::Sell { symbol, amount } => { MarketOperation::Sell { symbol, amount } => {
state.sell(symbol, amount, &pool).await? state.sell(symbol, amount, &pool, &mut market).await?
} }
}, },
}; };
socket socket
.write_all(serde_json::to_string(&response).unwrap().as_bytes()) .write_all(serde_json::to_string(&response).unwrap().as_bytes())
.await?; .await?;
}
buffer.unsplit(data); // O(1) buffer.unsplit(data); // O(1)
buffer.clear(); buffer.clear();
} }

View file

@ -1,37 +1,66 @@
use rand::rngs::ThreadRng;
use rand::thread_rng;
use rand::{distributions::Uniform, prelude::Distribution}; use rand::{distributions::Uniform, prelude::Distribution};
use rand::{rngs::StdRng, SeedableRng};
use rust_decimal::Decimal;
#[derive(Copy, Clone, PartialEq, PartialOrd, Debug)] #[derive(Copy, Clone, PartialEq, PartialOrd, Debug)]
pub(crate) struct MarketModifier(f64); pub(super) struct MarketModifier(Decimal);
#[derive(Clone, Debug)] impl MarketModifier {
pub(crate) struct MarketGenerator { pub(super) fn modifier(&self) -> Decimal {
rng: ThreadRng, self.0
sample_range: Uniform<f64>, }
volatility: f64,
old_price: MarketModifier,
} }
impl MarketGenerator { #[derive(Clone, Debug)]
pub(crate) fn new(volatility: f64, initial_value: f64) -> Self { pub(super) struct MarketMultiplier {
rng: StdRng,
sample_range: Uniform<i64>,
volatility: Decimal,
old_price: MarketModifier,
buffer: Vec<MarketModifier>,
}
impl MarketMultiplier {
pub(super) fn new(volatility: Decimal, initial_value: Decimal) -> Self {
Self { Self {
rng: thread_rng(), rng: StdRng::from_entropy(),
sample_range: Uniform::new_inclusive(-0.5, 0.5), sample_range: Uniform::new_inclusive(-50, 50),
volatility, volatility,
old_price: MarketModifier(initial_value), old_price: MarketModifier(initial_value),
buffer: vec![],
} }
} }
} }
impl Iterator for MarketGenerator { impl Iterator for MarketMultiplier {
type Item = MarketModifier; type Item = MarketModifier;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
// Algorithm from https://stackoverflow.com/a/8597889 // Algorithm from https://stackoverflow.com/a/8597889
let rng_multiplier = self.sample_range.sample(&mut self.rng); let rng_multiplier = self.sample_range.sample(&mut self.rng);
let delta = 2f64 * self.volatility * rng_multiplier; let delta = Decimal::new(2, 0) * self.volatility * Decimal::new(rng_multiplier, 2);
self.old_price.0 += self.old_price.0 * delta; self.old_price.0 += self.old_price.0 * delta;
Some(self.old_price) Some(self.old_price)
} }
} }
impl MarketMultiplier {
pub(super) fn next_n(&mut self, n: usize) -> Vec<MarketModifier> {
self.ensure_buffer_capacity(n);
let mut ret = self.buffer.split_off(n);
std::mem::swap(&mut ret, &mut self.buffer);
ret
}
pub(super) fn peek_n(&mut self, n: usize) -> &[MarketModifier] {
self.ensure_buffer_capacity(n);
&self.buffer[..n]
}
fn ensure_buffer_capacity(&mut self, n: usize) {
for _ in 0..(n - self.buffer.len()) {
let next_val = self.next().unwrap();
self.buffer.push(next_val);
}
}
}

View file

@ -1 +1,41 @@
use generator::MarketMultiplier;
use rust_decimal::Decimal;
use tokio::sync::{Mutex, MutexGuard};
mod generator; mod generator;
pub(crate) struct Market {
generator: Mutex<MarketMultiplier>,
}
impl Market {
pub(crate) fn new() -> Self {
Self {
generator: Mutex::new(MarketMultiplier::new(
Decimal::new(1, 1),
Decimal::new(1, 0),
)),
}
}
pub(crate) async fn lock(&self) -> LockedMarket<'_> {
LockedMarket(self.generator.lock().await, 0)
}
}
pub(crate) struct LockedMarket<'a>(MutexGuard<'a, MarketMultiplier>, usize);
impl<'a> LockedMarket<'a> {
pub(crate) fn peek_price(&mut self, initial_price: &Decimal, n: usize) -> Decimal {
let mut price = *initial_price;
for i in self.0.peek_n(n).iter() {
price *= i.modifier();
}
self.1 = n;
price.round_dp(2)
}
pub(crate) fn commit(mut self) {
self.0.next_n(self.1);
}
}

View file

@ -34,6 +34,7 @@ pub(crate) enum UserOperation {
} }
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum MarketOperation { pub(crate) enum MarketOperation {
Buy { symbol: StockSymbol, amount: usize }, Buy { symbol: StockSymbol, amount: usize },
Sell { symbol: StockSymbol, amount: usize }, Sell { symbol: StockSymbol, amount: usize },

View file

@ -1,8 +1,9 @@
use crate::user::SaltedPassword; use crate::{market::Market, user::SaltedPassword};
use log::{debug, info};
use rust_decimal::Decimal; use rust_decimal::Decimal;
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword}; use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
use sqlx::{query, PgPool}; use sqlx::{query, PgPool};
use std::convert::TryFrom; use std::{convert::TryFrom, sync::Arc};
use thiserror::Error; use thiserror::Error;
use vtse_common::net::{ServerResponse, UserError}; use vtse_common::net::{ServerResponse, UserError};
use vtse_common::stock::{Stock, StockSymbol}; use vtse_common::stock::{Stock, StockSymbol};
@ -12,12 +13,29 @@ use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username}
pub(crate) enum StateError { pub(crate) enum StateError {
#[error("Was not in the correct state")] #[error("Was not in the correct state")]
WrongState, WrongState,
#[error("Got SQLx error: {0}`")] #[error("Got SQLx error: {0}")]
Database(#[from] sqlx::Error), Database(#[from] sqlx::Error),
#[error("Failed to hash password")] #[error("Failed to hash password")]
PasswordHash, PasswordHash,
} }
#[derive(Error, Debug)]
pub(crate) enum UserOrDbError {
#[error("Got SQLx error: {0}")]
Database(#[from] sqlx::Error),
#[error("Got user error: {0}")]
User(UserError),
}
impl UserOrDbError {
fn into_operation_result(self) -> OperationResult {
match self {
UserOrDbError::Database(e) => Err(StateError::Database(e)),
UserOrDbError::User(e) => Ok(ServerResponse::UserError(e)),
}
}
}
type OperationResult = Result<ServerResponse, StateError>; type OperationResult = Result<ServerResponse, StateError>;
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)] #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
@ -43,11 +61,7 @@ impl AppState {
/// Query operations impl /// Query operations impl
impl AppState { impl AppState {
pub(crate) async fn stock_info( pub(crate) async fn stock_info(stock_symbol: StockSymbol, pool: &PgPool) -> OperationResult {
&self,
stock_symbol: StockSymbol,
pool: &PgPool,
) -> OperationResult {
let stock = query!( let stock = query!(
"SELECT name, symbol, description, price FROM stocks WHERE symbol = $1", "SELECT name, symbol, description, price FROM stocks WHERE symbol = $1",
stock_symbol.inner() stock_symbol.inner()
@ -63,7 +77,7 @@ impl AppState {
))) )))
} }
pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult { pub(crate) async fn user_info(username: Username, pool: &PgPool) -> OperationResult {
let user = query!( let user = query!(
"SELECT balance, debt FROM users WHERE username = $1", "SELECT balance, debt FROM users WHERE username = $1",
username.inner() username.inner()
@ -114,9 +128,7 @@ impl AppState {
let salted_password = let salted_password =
SaltedPassword::try_from(password).map_err(|_| StateError::PasswordHash)?; SaltedPassword::try_from(password).map_err(|_| StateError::PasswordHash)?;
query!( query!(
"INSERT INTO users "INSERT INTO users (username, pwhash_data) VALUES ($1, $2)",
(username, pwhash_data)
VALUES ($1, $2)",
username.inner(), username.inner(),
salted_password.as_ref(), salted_password.as_ref(),
) )
@ -143,9 +155,7 @@ impl AppState {
let mut transaction = pool.begin().await?; let mut transaction = pool.begin().await?;
let api_key = uuid::Uuid::new_v4(); let api_key = uuid::Uuid::new_v4();
let api_key_id = query!( let api_key_id = query!(
"INSERT INTO api_keys (key) "INSERT INTO api_keys (key) VALUES ($1) RETURNING api_keys.key_id;",
VALUES ($1)
RETURNING api_keys.key_id;",
&api_key, &api_key,
) )
.fetch_one(&mut transaction) .fetch_one(&mut transaction)
@ -153,8 +163,7 @@ impl AppState {
.key_id; .key_id;
query!( query!(
"INSERT INTO user_api_keys (user_id, key_id) "INSERT INTO user_api_keys (user_id, key_id) VALUES ($1, $2)",
VALUES ($1, $2)",
user_id, user_id,
api_key_id api_key_id
) )
@ -173,8 +182,7 @@ impl AppState {
pool: &PgPool, pool: &PgPool,
) -> Result<Option<i32>, StateError> { ) -> Result<Option<i32>, StateError> {
let result = query!( let result = query!(
"SELECT user_id, pwhash_data FROM users "SELECT user_id, pwhash_data FROM users WHERE username = $1",
WHERE username = $1",
username.inner() username.inner()
) )
.fetch_one(pool) .fetch_one(pool)
@ -199,33 +207,73 @@ impl AppState {
symbol: StockSymbol, symbol: StockSymbol,
amount: usize, amount: usize,
pool: &PgPool, pool: &PgPool,
market: &mut Arc<Market>,
) -> OperationResult { ) -> OperationResult {
let id = *match self { let id = *match self {
Self::Authorized { user_id } => user_id, Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState), _ => return Err(StateError::WrongState),
}; };
let stock_id = Self::get_stock_id(&symbol, pool).await?;
let user_balance = Self::get_user_balance(id, pool).await?; let user_balance = Self::get_user_balance(id, pool).await?;
let mut market = market.lock().await;
let current_price = {
let response = Self::get_stock_price(symbol.clone(), pool).await;
match response {
Ok(v) => v,
Err(e) => return e.into_operation_result(),
}
};
let cost_to_purchase = market.peek_price(&current_price, amount);
if user_balance >= cost_to_purchase {
let resp = self
.purchase_stock(stock_id, amount, cost_to_purchase, pool)
.await;
if resp.is_ok() {
info!(
"User {} will be purchasing {} {} securities for {}",
id, amount, symbol, cost_to_purchase
);
market.commit();
}
resp
} else {
todo!() todo!()
} }
}
pub(crate) async fn sell( pub(crate) async fn sell(
&self, &self,
symbol: StockSymbol, symbol: StockSymbol,
amount: usize, amount: usize,
pool: &PgPool, pool: &PgPool,
market: &mut Arc<Market>,
) -> OperationResult { ) -> OperationResult {
let id = *match self { let id = *match self {
Self::Authorized { user_id } => user_id, Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState), _ => return Err(StateError::WrongState),
}; };
let num_owned = Self::get_owned_stock_count(id, symbol, pool).await?; let stock_id = Self::get_stock_id(&symbol, pool).await?;
let num_owned = Self::get_owned_stock_count(id, stock_id, pool).await?;
if num_owned < amount { if num_owned < amount {
return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock( return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock(
num_owned, num_owned,
))); )));
} }
todo!() let mut market = market.lock().await;
let current_price = {
let response = Self::get_stock_price(symbol.clone(), pool).await;
match response {
Ok(v) => v,
Err(e) => return e.into_operation_result(),
}
};
let gains = market.peek_price(&current_price, amount);
let response = self
.sell_stock(stock_id, gains, amount, num_owned == amount, pool)
.await;
market.commit();
response
} }
// todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus // todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus
@ -241,21 +289,130 @@ impl AppState {
async fn get_owned_stock_count( async fn get_owned_stock_count(
user_id: i32, user_id: i32,
symbol: StockSymbol, stock_id: i32,
pool: &PgPool, pool: &PgPool,
) -> Result<usize, StateError> { ) -> Result<usize, StateError> {
query!( query!(
"SELECT amount FROM users_stocks "SELECT amount FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
JOIN stocks ON users_stocks.stock_id = stocks.stock_id
WHERE user_id = $1 AND symbol = $2",
user_id, user_id,
symbol.inner() stock_id
) )
.fetch_one(pool) .fetch_one(pool)
.await .await
.map(|record| record.amount as usize) .map(|record| record.amount as usize)
.map_err(StateError::from) .map_err(StateError::from)
} }
async fn get_stock_price(symbol: StockSymbol, pool: &PgPool) -> Result<Decimal, UserOrDbError> {
let query = query!("SELECT price FROM stocks WHERE symbol = $1", symbol.inner())
.fetch_optional(pool)
.await;
match query {
Ok(Some(record)) => Ok(record.price),
Ok(None) => Err(UserOrDbError::User(UserError::InvalidStock(symbol))),
Err(e) => Err(UserOrDbError::Database(e)),
}
}
async fn purchase_stock(
&self,
symbol_id: i32,
amount: usize,
cost: Decimal,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
debug!(
"Starting transaction for id {}: {} {} at total cost {}",
id, amount, symbol_id, cost
);
let mut transaction = pool.begin().await?;
query!(
"UPDATE users SET balance = balance - $1 WHERE user_id = $2",
cost,
id
)
.execute(&mut transaction)
.await?;
query!(
"INSERT INTO users_stocks VALUES ($1, $2, $3)
ON CONFLICT ON CONSTRAINT users_stocks_user_id_stock_id_key DO
UPDATE SET amount = users_stocks.amount + $3::integer
WHERE users_stocks.user_id = $1",
id,
symbol_id,
amount as u32
)
.execute(&mut transaction)
.await?;
transaction.commit().await?;
Ok(ServerResponse::Success)
}
async fn get_stock_id(symbol: &StockSymbol, pool: &PgPool) -> Result<i32, StateError> {
query!(
"SELECT stock_id FROM stocks WHERE symbol = $1",
symbol.inner()
)
.fetch_one(pool)
.await
.map(|record| record.stock_id)
.map_err(StateError::from)
}
async fn sell_stock(
&self,
stock_id: i32,
gains: Decimal,
amount_sold: usize,
sell_all: bool,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let mut transaction = pool.begin().await?;
query!(
"UPDATE users SET balance = balance + $1::numeric WHERE user_id = $2",
gains,
id
)
.execute(&mut transaction)
.await?;
if sell_all {
query!(
"DELETE FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
id,
stock_id
)
.execute(&mut transaction)
.await?;
} else {
query!(
"UPDATE users_stocks SET amount = amount - $1::integer
WHERE user_id = $2 AND stock_id = $3",
amount_sold as u32,
id,
stock_id,
)
.execute(&mut transaction)
.await?;
}
transaction.commit().await?;
Ok(ServerResponse::Success)
}
} }
impl Default for AppState { impl Default for AppState {