implement partial buy and sell
This commit is contained in:
parent
e6ff1576f1
commit
07bb068a2b
11 changed files with 504 additions and 84 deletions
12
.vscode/settings.json
vendored
Normal file
12
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"cSpell.words": [
|
||||
"Gura",
|
||||
"dotenv",
|
||||
"icious",
|
||||
"pwhash",
|
||||
"sodiumoxide",
|
||||
"thiserror",
|
||||
"unsplit",
|
||||
"vtse"
|
||||
]
|
||||
}
|
122
Cargo.lock
generated
122
Cargo.lock
generated
|
@ -152,6 +152,38 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "3.0.0-beta.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4bd1061998a501ee7d4b6d449020df3266ca3124b941ec56cf2005c3779ca142"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"bitflags",
|
||||
"clap_derive",
|
||||
"indexmap",
|
||||
"lazy_static",
|
||||
"os_str_bytes",
|
||||
"strsim",
|
||||
"termcolor",
|
||||
"textwrap",
|
||||
"unicode-width",
|
||||
"vec_map",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "3.0.0-beta.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "370f715b81112975b1b69db93e0b56ea4cd4e5002ac43b2da8474106a54096a1"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "1.9.3"
|
||||
|
@ -446,6 +478,16 @@ dependencies = [
|
|||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4fb1fa934250de4de8aef298d81c729a7d33d8c239daa3a7575e6b92bfc7313b"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "instant"
|
||||
version = "0.1.9"
|
||||
|
@ -649,6 +691,12 @@ version = "0.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
|
||||
|
||||
[[package]]
|
||||
name = "os_str_bytes"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "afb2e1c3ee07430c2cf76151675e583e0f19985fa6efae47d6848a3e2c824f85"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.11.1"
|
||||
|
@ -704,6 +752,30 @@ version = "0.2.10"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||
dependencies = [
|
||||
"proc-macro-error-attr",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error-attr"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-hack"
|
||||
version = "0.5.19"
|
||||
|
@ -1146,6 +1218,12 @@ dependencies = [
|
|||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.4.0"
|
||||
|
@ -1169,6 +1247,24 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "36474e732d1affd3a6ed582781b3683df3d0563714c59c39591e8ff707cf078e"
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "203008d98caf094106cfaba70acfed15e18ed3ddb7d94e49baec153a2b462789"
|
||||
dependencies = [
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.23"
|
||||
|
@ -1293,9 +1389,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "unicode-normalization"
|
||||
version = "0.1.16"
|
||||
version = "0.1.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a13e63ab62dbe32aeee58d1c5408d35c36c392bba5d9d3142287219721afe606"
|
||||
checksum = "07fbfce1c8a97d547e8b5334978438d9d6ec8c20e38f56d4a4374d181493eaef"
|
||||
dependencies = [
|
||||
"tinyvec",
|
||||
]
|
||||
|
@ -1306,6 +1402,12 @@ version = "1.7.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.1"
|
||||
|
@ -1346,6 +1448,12 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vec_map"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.2"
|
||||
|
@ -1373,6 +1481,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"clap",
|
||||
"dotenv",
|
||||
"log",
|
||||
"rand 0.8.3",
|
||||
|
@ -1514,6 +1623,15 @@ version = "0.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
use crate::{stock::Stock, user::User};
|
||||
use crate::{
|
||||
stock::{Stock, StockSymbol},
|
||||
user::User,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
@ -13,11 +17,18 @@ pub enum ServerResponse {
|
|||
UserInfo(User),
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[derive(Error, Serialize, Debug)]
|
||||
pub enum UserError {
|
||||
#[error("An invalid username was provided.")]
|
||||
InvalidUsername,
|
||||
#[error("An invalid password was provided.")]
|
||||
InvalidPassword,
|
||||
#[error("An invalid API key was provided.")]
|
||||
InvalidApiKey,
|
||||
#[error("This requires authorization. Please login first.")]
|
||||
NotAuthorized,
|
||||
#[error("You don't have enough stock to sell {0} units.")]
|
||||
NotEnoughOwnedStock(usize),
|
||||
#[error("Stock symbol {0} does not exist.")]
|
||||
InvalidStock(StockSymbol),
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use rust_decimal::Decimal;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::str::FromStr;
|
||||
use std::{fmt::Display, str::FromStr};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
|
||||
|
@ -28,6 +28,12 @@ impl FromStr for StockName {
|
|||
#[serde(transparent)]
|
||||
pub struct StockSymbol(String);
|
||||
|
||||
impl Display for StockSymbol {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "${}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for StockSymbol {
|
||||
fn from(s: String) -> Self {
|
||||
Self(s)
|
||||
|
|
|
@ -7,6 +7,7 @@ edition = "2018"
|
|||
[dependencies]
|
||||
anyhow = "1"
|
||||
bytes = "1"
|
||||
clap = "3.0.0-beta.2"
|
||||
dotenv = "0.15"
|
||||
log = "0.4"
|
||||
rand = "0.8"
|
||||
|
|
14
vtse-server/src/args.rs
Normal file
14
vtse-server/src/args.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use clap::Clap;
|
||||
use log::LevelFilter;
|
||||
|
||||
#[derive(Clap)]
|
||||
pub(crate) struct Args {
|
||||
#[clap(long, env("DATABASE_URL"))]
|
||||
pub(crate) database_url: String,
|
||||
/// What address the vtse server should listen to.
|
||||
#[clap(long, env("BIND_ADDRESS"), default_value = "localhost:8080")]
|
||||
pub(crate) bind_address: String,
|
||||
/// Sets the logging level DB queries should report at.
|
||||
#[clap(long, default_value = "off")]
|
||||
pub(crate) db_log_level: LevelFilter,
|
||||
}
|
|
@ -1,15 +1,20 @@
|
|||
use crate::operations::ServerOperation;
|
||||
use anyhow::{bail, Result};
|
||||
use args::Args;
|
||||
use bytes::BytesMut;
|
||||
use clap::Clap;
|
||||
use log::{debug, error, info};
|
||||
use market::Market;
|
||||
use operations::{MarketOperation, QueryOperation, UserOperation};
|
||||
use sqlx::PgPool;
|
||||
use simple_logger::SimpleLogger;
|
||||
use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool};
|
||||
use state::AppState;
|
||||
use std::env;
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
|
||||
|
||||
mod args;
|
||||
mod market;
|
||||
mod operations;
|
||||
mod state;
|
||||
|
@ -17,26 +22,35 @@ mod user;
|
|||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Must init dotenv before initializing args, else args won't default from
|
||||
// env properly
|
||||
dotenv::dotenv().ok();
|
||||
// If we can't successfully initialize our crypto library, fail immediately.
|
||||
sodiumoxide::init().unwrap();
|
||||
simple_logger::SimpleLogger::default().init()?;
|
||||
let args: Args = Args::parse();
|
||||
|
||||
let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?)
|
||||
// If we can't successfully initialize our crypto library, fail immediately.
|
||||
sodiumoxide::init().expect("to initialize crypto library");
|
||||
SimpleLogger::default().init()?;
|
||||
|
||||
let mut listener_stream = TcpListener::bind(&args.bind_address)
|
||||
.await
|
||||
.map(TcpListenerStream::new)?;
|
||||
|
||||
info!("Successfully bound to port.");
|
||||
|
||||
let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?;
|
||||
|
||||
let db_pool = {
|
||||
let mut connect_options = PgConnectOptions::from_str(&args.database_url)?;
|
||||
connect_options.log_statements(args.db_log_level);
|
||||
PgPool::connect_with(connect_options).await?
|
||||
};
|
||||
info!("Successfully established connection to database.");
|
||||
|
||||
let market = Arc::new(Market::new());
|
||||
|
||||
while let Some(Ok(stream)) = listener_stream.next().await {
|
||||
// clone simply clones the arc reference, so this is cheap.
|
||||
let db_pool = db_pool.clone();
|
||||
let market = Arc::clone(&market);
|
||||
tokio::task::spawn(async {
|
||||
if let Err(e) = handle_stream(stream, db_pool).await {
|
||||
if let Err(e) = handle_stream(stream, db_pool, market).await {
|
||||
error!("{}", e);
|
||||
}
|
||||
});
|
||||
|
@ -47,7 +61,7 @@ async fn main() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
|
||||
async fn handle_stream(mut socket: TcpStream, pool: PgPool, mut market: Arc<Market>) -> Result<()> {
|
||||
// only accept data that can fit in 256 bytes
|
||||
// the assumption is that a single request should always be within n bytes,
|
||||
// otherwise there's a good chance that it's mal{formed,icious}.
|
||||
|
@ -58,14 +72,28 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
|
|||
if bytes_read == 0 {
|
||||
bail!("Failed to read bytes, assuming socket is closed.")
|
||||
}
|
||||
let data = buffer.split_to(bytes_read); // O(1)
|
||||
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
|
||||
debug!("Parsed operation: {:?}", parsed);
|
||||
|
||||
let response = match parsed {
|
||||
let data = buffer.split_to(bytes_read); // O(1)
|
||||
dbg!(&data);
|
||||
let iter = serde_json::Deserializer::from_slice(&data).into_iter();
|
||||
|
||||
for messages in iter {
|
||||
let message = match messages {
|
||||
Ok(p) => p,
|
||||
Err(e) if e.is_eof() || e.is_io() => return Ok(()),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
debug!("Parsed operation: {:?}", message);
|
||||
|
||||
let response = match message {
|
||||
ServerOperation::Query(op) => match op {
|
||||
QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?,
|
||||
QueryOperation::User { username } => state.user_info(username, &pool).await?,
|
||||
QueryOperation::StockInfo { stock } => {
|
||||
AppState::stock_info(stock, &pool).await?
|
||||
}
|
||||
QueryOperation::User { username } => {
|
||||
AppState::user_info(username, &pool).await?
|
||||
}
|
||||
},
|
||||
ServerOperation::User(op) => match op {
|
||||
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
|
||||
|
@ -77,16 +105,19 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
|
|||
}
|
||||
},
|
||||
ServerOperation::Market(op) => match op {
|
||||
MarketOperation::Buy { symbol, amount } => state.buy(symbol, amount, &pool).await?,
|
||||
MarketOperation::Buy { symbol, amount } => {
|
||||
state.buy(symbol, amount, &pool, &mut market).await?
|
||||
}
|
||||
MarketOperation::Sell { symbol, amount } => {
|
||||
state.sell(symbol, amount, &pool).await?
|
||||
state.sell(symbol, amount, &pool, &mut market).await?
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
socket
|
||||
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
|
||||
.await?;
|
||||
}
|
||||
|
||||
buffer.unsplit(data); // O(1)
|
||||
buffer.clear();
|
||||
}
|
||||
|
|
|
@ -1,37 +1,66 @@
|
|||
use rand::rngs::ThreadRng;
|
||||
use rand::thread_rng;
|
||||
use rand::{distributions::Uniform, prelude::Distribution};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use rust_decimal::Decimal;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, PartialOrd, Debug)]
|
||||
pub(crate) struct MarketModifier(f64);
|
||||
pub(super) struct MarketModifier(Decimal);
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct MarketGenerator {
|
||||
rng: ThreadRng,
|
||||
sample_range: Uniform<f64>,
|
||||
volatility: f64,
|
||||
old_price: MarketModifier,
|
||||
impl MarketModifier {
|
||||
pub(super) fn modifier(&self) -> Decimal {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl MarketGenerator {
|
||||
pub(crate) fn new(volatility: f64, initial_value: f64) -> Self {
|
||||
#[derive(Clone, Debug)]
|
||||
pub(super) struct MarketMultiplier {
|
||||
rng: StdRng,
|
||||
sample_range: Uniform<i64>,
|
||||
volatility: Decimal,
|
||||
old_price: MarketModifier,
|
||||
buffer: Vec<MarketModifier>,
|
||||
}
|
||||
|
||||
impl MarketMultiplier {
|
||||
pub(super) fn new(volatility: Decimal, initial_value: Decimal) -> Self {
|
||||
Self {
|
||||
rng: thread_rng(),
|
||||
sample_range: Uniform::new_inclusive(-0.5, 0.5),
|
||||
rng: StdRng::from_entropy(),
|
||||
sample_range: Uniform::new_inclusive(-50, 50),
|
||||
volatility,
|
||||
old_price: MarketModifier(initial_value),
|
||||
buffer: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for MarketGenerator {
|
||||
impl Iterator for MarketMultiplier {
|
||||
type Item = MarketModifier;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// Algorithm from https://stackoverflow.com/a/8597889
|
||||
let rng_multiplier = self.sample_range.sample(&mut self.rng);
|
||||
let delta = 2f64 * self.volatility * rng_multiplier;
|
||||
let delta = Decimal::new(2, 0) * self.volatility * Decimal::new(rng_multiplier, 2);
|
||||
self.old_price.0 += self.old_price.0 * delta;
|
||||
Some(self.old_price)
|
||||
}
|
||||
}
|
||||
|
||||
impl MarketMultiplier {
|
||||
pub(super) fn next_n(&mut self, n: usize) -> Vec<MarketModifier> {
|
||||
self.ensure_buffer_capacity(n);
|
||||
let mut ret = self.buffer.split_off(n);
|
||||
std::mem::swap(&mut ret, &mut self.buffer);
|
||||
ret
|
||||
}
|
||||
|
||||
pub(super) fn peek_n(&mut self, n: usize) -> &[MarketModifier] {
|
||||
self.ensure_buffer_capacity(n);
|
||||
&self.buffer[..n]
|
||||
}
|
||||
|
||||
fn ensure_buffer_capacity(&mut self, n: usize) {
|
||||
for _ in 0..(n - self.buffer.len()) {
|
||||
let next_val = self.next().unwrap();
|
||||
self.buffer.push(next_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1 +1,41 @@
|
|||
use generator::MarketMultiplier;
|
||||
use rust_decimal::Decimal;
|
||||
use tokio::sync::{Mutex, MutexGuard};
|
||||
|
||||
mod generator;
|
||||
|
||||
pub(crate) struct Market {
|
||||
generator: Mutex<MarketMultiplier>,
|
||||
}
|
||||
|
||||
impl Market {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
generator: Mutex::new(MarketMultiplier::new(
|
||||
Decimal::new(1, 1),
|
||||
Decimal::new(1, 0),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn lock(&self) -> LockedMarket<'_> {
|
||||
LockedMarket(self.generator.lock().await, 0)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LockedMarket<'a>(MutexGuard<'a, MarketMultiplier>, usize);
|
||||
|
||||
impl<'a> LockedMarket<'a> {
|
||||
pub(crate) fn peek_price(&mut self, initial_price: &Decimal, n: usize) -> Decimal {
|
||||
let mut price = *initial_price;
|
||||
for i in self.0.peek_n(n).iter() {
|
||||
price *= i.modifier();
|
||||
}
|
||||
self.1 = n;
|
||||
price.round_dp(2)
|
||||
}
|
||||
|
||||
pub(crate) fn commit(mut self) {
|
||||
self.0.next_n(self.1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ pub(crate) enum UserOperation {
|
|||
}
|
||||
|
||||
#[derive(Deserialize, Debug, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub(crate) enum MarketOperation {
|
||||
Buy { symbol: StockSymbol, amount: usize },
|
||||
Sell { symbol: StockSymbol, amount: usize },
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use crate::user::SaltedPassword;
|
||||
use crate::{market::Market, user::SaltedPassword};
|
||||
use log::{debug, info};
|
||||
use rust_decimal::Decimal;
|
||||
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
|
||||
use sqlx::{query, PgPool};
|
||||
use std::convert::TryFrom;
|
||||
use std::{convert::TryFrom, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use vtse_common::net::{ServerResponse, UserError};
|
||||
use vtse_common::stock::{Stock, StockSymbol};
|
||||
|
@ -12,12 +13,29 @@ use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username}
|
|||
pub(crate) enum StateError {
|
||||
#[error("Was not in the correct state")]
|
||||
WrongState,
|
||||
#[error("Got SQLx error: {0}`")]
|
||||
#[error("Got SQLx error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error("Failed to hash password")]
|
||||
PasswordHash,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub(crate) enum UserOrDbError {
|
||||
#[error("Got SQLx error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error("Got user error: {0}")]
|
||||
User(UserError),
|
||||
}
|
||||
|
||||
impl UserOrDbError {
|
||||
fn into_operation_result(self) -> OperationResult {
|
||||
match self {
|
||||
UserOrDbError::Database(e) => Err(StateError::Database(e)),
|
||||
UserOrDbError::User(e) => Ok(ServerResponse::UserError(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type OperationResult = Result<ServerResponse, StateError>;
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
|
||||
|
@ -43,11 +61,7 @@ impl AppState {
|
|||
|
||||
/// Query operations impl
|
||||
impl AppState {
|
||||
pub(crate) async fn stock_info(
|
||||
&self,
|
||||
stock_symbol: StockSymbol,
|
||||
pool: &PgPool,
|
||||
) -> OperationResult {
|
||||
pub(crate) async fn stock_info(stock_symbol: StockSymbol, pool: &PgPool) -> OperationResult {
|
||||
let stock = query!(
|
||||
"SELECT name, symbol, description, price FROM stocks WHERE symbol = $1",
|
||||
stock_symbol.inner()
|
||||
|
@ -63,7 +77,7 @@ impl AppState {
|
|||
)))
|
||||
}
|
||||
|
||||
pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult {
|
||||
pub(crate) async fn user_info(username: Username, pool: &PgPool) -> OperationResult {
|
||||
let user = query!(
|
||||
"SELECT balance, debt FROM users WHERE username = $1",
|
||||
username.inner()
|
||||
|
@ -114,9 +128,7 @@ impl AppState {
|
|||
let salted_password =
|
||||
SaltedPassword::try_from(password).map_err(|_| StateError::PasswordHash)?;
|
||||
query!(
|
||||
"INSERT INTO users
|
||||
(username, pwhash_data)
|
||||
VALUES ($1, $2)",
|
||||
"INSERT INTO users (username, pwhash_data) VALUES ($1, $2)",
|
||||
username.inner(),
|
||||
salted_password.as_ref(),
|
||||
)
|
||||
|
@ -143,9 +155,7 @@ impl AppState {
|
|||
let mut transaction = pool.begin().await?;
|
||||
let api_key = uuid::Uuid::new_v4();
|
||||
let api_key_id = query!(
|
||||
"INSERT INTO api_keys (key)
|
||||
VALUES ($1)
|
||||
RETURNING api_keys.key_id;",
|
||||
"INSERT INTO api_keys (key) VALUES ($1) RETURNING api_keys.key_id;",
|
||||
&api_key,
|
||||
)
|
||||
.fetch_one(&mut transaction)
|
||||
|
@ -153,8 +163,7 @@ impl AppState {
|
|||
.key_id;
|
||||
|
||||
query!(
|
||||
"INSERT INTO user_api_keys (user_id, key_id)
|
||||
VALUES ($1, $2)",
|
||||
"INSERT INTO user_api_keys (user_id, key_id) VALUES ($1, $2)",
|
||||
user_id,
|
||||
api_key_id
|
||||
)
|
||||
|
@ -173,8 +182,7 @@ impl AppState {
|
|||
pool: &PgPool,
|
||||
) -> Result<Option<i32>, StateError> {
|
||||
let result = query!(
|
||||
"SELECT user_id, pwhash_data FROM users
|
||||
WHERE username = $1",
|
||||
"SELECT user_id, pwhash_data FROM users WHERE username = $1",
|
||||
username.inner()
|
||||
)
|
||||
.fetch_one(pool)
|
||||
|
@ -199,33 +207,73 @@ impl AppState {
|
|||
symbol: StockSymbol,
|
||||
amount: usize,
|
||||
pool: &PgPool,
|
||||
market: &mut Arc<Market>,
|
||||
) -> OperationResult {
|
||||
let id = *match self {
|
||||
Self::Authorized { user_id } => user_id,
|
||||
_ => return Err(StateError::WrongState),
|
||||
};
|
||||
let stock_id = Self::get_stock_id(&symbol, pool).await?;
|
||||
let user_balance = Self::get_user_balance(id, pool).await?;
|
||||
let mut market = market.lock().await;
|
||||
let current_price = {
|
||||
let response = Self::get_stock_price(symbol.clone(), pool).await;
|
||||
match response {
|
||||
Ok(v) => v,
|
||||
Err(e) => return e.into_operation_result(),
|
||||
}
|
||||
};
|
||||
let cost_to_purchase = market.peek_price(¤t_price, amount);
|
||||
|
||||
if user_balance >= cost_to_purchase {
|
||||
let resp = self
|
||||
.purchase_stock(stock_id, amount, cost_to_purchase, pool)
|
||||
.await;
|
||||
if resp.is_ok() {
|
||||
info!(
|
||||
"User {} will be purchasing {} {} securities for {}",
|
||||
id, amount, symbol, cost_to_purchase
|
||||
);
|
||||
market.commit();
|
||||
}
|
||||
resp
|
||||
} else {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn sell(
|
||||
&self,
|
||||
symbol: StockSymbol,
|
||||
amount: usize,
|
||||
pool: &PgPool,
|
||||
market: &mut Arc<Market>,
|
||||
) -> OperationResult {
|
||||
let id = *match self {
|
||||
Self::Authorized { user_id } => user_id,
|
||||
_ => return Err(StateError::WrongState),
|
||||
};
|
||||
let num_owned = Self::get_owned_stock_count(id, symbol, pool).await?;
|
||||
let stock_id = Self::get_stock_id(&symbol, pool).await?;
|
||||
let num_owned = Self::get_owned_stock_count(id, stock_id, pool).await?;
|
||||
if num_owned < amount {
|
||||
return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock(
|
||||
num_owned,
|
||||
)));
|
||||
}
|
||||
todo!()
|
||||
let mut market = market.lock().await;
|
||||
let current_price = {
|
||||
let response = Self::get_stock_price(symbol.clone(), pool).await;
|
||||
match response {
|
||||
Ok(v) => v,
|
||||
Err(e) => return e.into_operation_result(),
|
||||
}
|
||||
};
|
||||
let gains = market.peek_price(¤t_price, amount);
|
||||
let response = self
|
||||
.sell_stock(stock_id, gains, amount, num_owned == amount, pool)
|
||||
.await;
|
||||
market.commit();
|
||||
response
|
||||
}
|
||||
|
||||
// todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus
|
||||
|
@ -241,21 +289,130 @@ impl AppState {
|
|||
|
||||
async fn get_owned_stock_count(
|
||||
user_id: i32,
|
||||
symbol: StockSymbol,
|
||||
stock_id: i32,
|
||||
pool: &PgPool,
|
||||
) -> Result<usize, StateError> {
|
||||
query!(
|
||||
"SELECT amount FROM users_stocks
|
||||
JOIN stocks ON users_stocks.stock_id = stocks.stock_id
|
||||
WHERE user_id = $1 AND symbol = $2",
|
||||
"SELECT amount FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
|
||||
user_id,
|
||||
symbol.inner()
|
||||
stock_id
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map(|record| record.amount as usize)
|
||||
.map_err(StateError::from)
|
||||
}
|
||||
|
||||
async fn get_stock_price(symbol: StockSymbol, pool: &PgPool) -> Result<Decimal, UserOrDbError> {
|
||||
let query = query!("SELECT price FROM stocks WHERE symbol = $1", symbol.inner())
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
match query {
|
||||
Ok(Some(record)) => Ok(record.price),
|
||||
Ok(None) => Err(UserOrDbError::User(UserError::InvalidStock(symbol))),
|
||||
Err(e) => Err(UserOrDbError::Database(e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn purchase_stock(
|
||||
&self,
|
||||
symbol_id: i32,
|
||||
amount: usize,
|
||||
cost: Decimal,
|
||||
pool: &PgPool,
|
||||
) -> OperationResult {
|
||||
let id = *match self {
|
||||
Self::Authorized { user_id } => user_id,
|
||||
_ => return Err(StateError::WrongState),
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Starting transaction for id {}: {} {} at total cost {}",
|
||||
id, amount, symbol_id, cost
|
||||
);
|
||||
|
||||
let mut transaction = pool.begin().await?;
|
||||
query!(
|
||||
"UPDATE users SET balance = balance - $1 WHERE user_id = $2",
|
||||
cost,
|
||||
id
|
||||
)
|
||||
.execute(&mut transaction)
|
||||
.await?;
|
||||
|
||||
query!(
|
||||
"INSERT INTO users_stocks VALUES ($1, $2, $3)
|
||||
ON CONFLICT ON CONSTRAINT users_stocks_user_id_stock_id_key DO
|
||||
UPDATE SET amount = users_stocks.amount + $3::integer
|
||||
WHERE users_stocks.user_id = $1",
|
||||
id,
|
||||
symbol_id,
|
||||
amount as u32
|
||||
)
|
||||
.execute(&mut transaction)
|
||||
.await?;
|
||||
|
||||
transaction.commit().await?;
|
||||
|
||||
Ok(ServerResponse::Success)
|
||||
}
|
||||
|
||||
async fn get_stock_id(symbol: &StockSymbol, pool: &PgPool) -> Result<i32, StateError> {
|
||||
query!(
|
||||
"SELECT stock_id FROM stocks WHERE symbol = $1",
|
||||
symbol.inner()
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map(|record| record.stock_id)
|
||||
.map_err(StateError::from)
|
||||
}
|
||||
|
||||
async fn sell_stock(
|
||||
&self,
|
||||
stock_id: i32,
|
||||
gains: Decimal,
|
||||
amount_sold: usize,
|
||||
sell_all: bool,
|
||||
pool: &PgPool,
|
||||
) -> OperationResult {
|
||||
let id = *match self {
|
||||
Self::Authorized { user_id } => user_id,
|
||||
_ => return Err(StateError::WrongState),
|
||||
};
|
||||
let mut transaction = pool.begin().await?;
|
||||
|
||||
query!(
|
||||
"UPDATE users SET balance = balance + $1::numeric WHERE user_id = $2",
|
||||
gains,
|
||||
id
|
||||
)
|
||||
.execute(&mut transaction)
|
||||
.await?;
|
||||
|
||||
if sell_all {
|
||||
query!(
|
||||
"DELETE FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
|
||||
id,
|
||||
stock_id
|
||||
)
|
||||
.execute(&mut transaction)
|
||||
.await?;
|
||||
} else {
|
||||
query!(
|
||||
"UPDATE users_stocks SET amount = amount - $1::integer
|
||||
WHERE user_id = $2 AND stock_id = $3",
|
||||
amount_sold as u32,
|
||||
id,
|
||||
stock_id,
|
||||
)
|
||||
.execute(&mut transaction)
|
||||
.await?;
|
||||
}
|
||||
|
||||
transaction.commit().await?;
|
||||
Ok(ServerResponse::Success)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
|
|
Loading…
Reference in a new issue