implement partial buy and sell

master
Edward Shen 2021-02-11 13:55:28 -05:00
parent e6ff1576f1
commit 07bb068a2b
Signed by: edward
GPG Key ID: 19182661E818369F
11 changed files with 504 additions and 84 deletions

12
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,12 @@
{
"cSpell.words": [
"Gura",
"dotenv",
"icious",
"pwhash",
"sodiumoxide",
"thiserror",
"unsplit",
"vtse"
]
}

122
Cargo.lock generated
View File

@ -152,6 +152,38 @@ dependencies = [
"winapi",
]
[[package]]
name = "clap"
version = "3.0.0-beta.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bd1061998a501ee7d4b6d449020df3266ca3124b941ec56cf2005c3779ca142"
dependencies = [
"atty",
"bitflags",
"clap_derive",
"indexmap",
"lazy_static",
"os_str_bytes",
"strsim",
"termcolor",
"textwrap",
"unicode-width",
"vec_map",
]
[[package]]
name = "clap_derive"
version = "3.0.0-beta.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "370f715b81112975b1b69db93e0b56ea4cd4e5002ac43b2da8474106a54096a1"
dependencies = [
"heck",
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "colored"
version = "1.9.3"
@ -446,6 +478,16 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "indexmap"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fb1fa934250de4de8aef298d81c729a7d33d8c239daa3a7575e6b92bfc7313b"
dependencies = [
"autocfg",
"hashbrown",
]
[[package]]
name = "instant"
version = "0.1.9"
@ -649,6 +691,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "os_str_bytes"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afb2e1c3ee07430c2cf76151675e583e0f19985fa6efae47d6848a3e2c824f85"
[[package]]
name = "parking_lot"
version = "0.11.1"
@ -704,6 +752,30 @@ version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.19"
@ -1146,6 +1218,12 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subtle"
version = "2.4.0"
@ -1169,6 +1247,24 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36474e732d1affd3a6ed582781b3683df3d0563714c59c39591e8ff707cf078e"
[[package]]
name = "termcolor"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4"
dependencies = [
"winapi-util",
]
[[package]]
name = "textwrap"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "203008d98caf094106cfaba70acfed15e18ed3ddb7d94e49baec153a2b462789"
dependencies = [
"unicode-width",
]
[[package]]
name = "thiserror"
version = "1.0.23"
@ -1293,9 +1389,9 @@ dependencies = [
[[package]]
name = "unicode-normalization"
version = "0.1.16"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a13e63ab62dbe32aeee58d1c5408d35c36c392bba5d9d3142287219721afe606"
checksum = "07fbfce1c8a97d547e8b5334978438d9d6ec8c20e38f56d4a4374d181493eaef"
dependencies = [
"tinyvec",
]
@ -1306,6 +1402,12 @@ version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb0d2e7be6ae3a5fa87eed5fb451aff96f2573d2694942e40543ae0bbe19c796"
[[package]]
name = "unicode-width"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3"
[[package]]
name = "unicode-xid"
version = "0.2.1"
@ -1346,6 +1448,12 @@ dependencies = [
"serde",
]
[[package]]
name = "vec_map"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
[[package]]
name = "version_check"
version = "0.9.2"
@ -1373,6 +1481,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"clap",
"dotenv",
"log",
"rand 0.8.3",
@ -1514,6 +1623,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
dependencies = [
"winapi",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"

View File

@ -1,5 +1,9 @@
use crate::{stock::Stock, user::User};
use crate::{
stock::{Stock, StockSymbol},
user::User,
};
use serde::Serialize;
use thiserror::Error;
use uuid::Uuid;
#[derive(Serialize)]
@ -13,11 +17,18 @@ pub enum ServerResponse {
UserInfo(User),
}
#[derive(Serialize)]
#[derive(Error, Serialize, Debug)]
pub enum UserError {
#[error("An invalid username was provided.")]
InvalidUsername,
#[error("An invalid password was provided.")]
InvalidPassword,
#[error("An invalid API key was provided.")]
InvalidApiKey,
#[error("This requires authorization. Please login first.")]
NotAuthorized,
#[error("You don't have enough stock to sell {0} units.")]
NotEnoughOwnedStock(usize),
#[error("Stock symbol {0} does not exist.")]
InvalidStock(StockSymbol),
}

View File

@ -1,6 +1,6 @@
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::{fmt::Display, str::FromStr};
use thiserror::Error;
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
@ -28,6 +28,12 @@ impl FromStr for StockName {
#[serde(transparent)]
pub struct StockSymbol(String);
impl Display for StockSymbol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "${}", self.0)
}
}
impl From<String> for StockSymbol {
fn from(s: String) -> Self {
Self(s)

View File

@ -7,6 +7,7 @@ edition = "2018"
[dependencies]
anyhow = "1"
bytes = "1"
clap = "3.0.0-beta.2"
dotenv = "0.15"
log = "0.4"
rand = "0.8"

14
vtse-server/src/args.rs Normal file
View File

@ -0,0 +1,14 @@
use clap::Clap;
use log::LevelFilter;
#[derive(Clap)]
pub(crate) struct Args {
#[clap(long, env("DATABASE_URL"))]
pub(crate) database_url: String,
/// What address the vtse server should listen to.
#[clap(long, env("BIND_ADDRESS"), default_value = "localhost:8080")]
pub(crate) bind_address: String,
/// Sets the logging level DB queries should report at.
#[clap(long, default_value = "off")]
pub(crate) db_log_level: LevelFilter,
}

View File

@ -1,15 +1,20 @@
use crate::operations::ServerOperation;
use anyhow::{bail, Result};
use args::Args;
use bytes::BytesMut;
use clap::Clap;
use log::{debug, error, info};
use market::Market;
use operations::{MarketOperation, QueryOperation, UserOperation};
use sqlx::PgPool;
use simple_logger::SimpleLogger;
use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool};
use state::AppState;
use std::env;
use std::{str::FromStr, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
mod args;
mod market;
mod operations;
mod state;
@ -17,26 +22,35 @@ mod user;
#[tokio::main]
async fn main() -> Result<()> {
// Must init dotenv before initializing args, else args won't default from
// env properly
dotenv::dotenv().ok();
// If we can't successfully initialize our crypto library, fail immediately.
sodiumoxide::init().unwrap();
simple_logger::SimpleLogger::default().init()?;
let args: Args = Args::parse();
let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?)
// If we can't successfully initialize our crypto library, fail immediately.
sodiumoxide::init().expect("to initialize crypto library");
SimpleLogger::default().init()?;
let mut listener_stream = TcpListener::bind(&args.bind_address)
.await
.map(TcpListenerStream::new)?;
info!("Successfully bound to port.");
let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?;
let db_pool = {
let mut connect_options = PgConnectOptions::from_str(&args.database_url)?;
connect_options.log_statements(args.db_log_level);
PgPool::connect_with(connect_options).await?
};
info!("Successfully established connection to database.");
let market = Arc::new(Market::new());
while let Some(Ok(stream)) = listener_stream.next().await {
// clone simply clones the arc reference, so this is cheap.
let db_pool = db_pool.clone();
let market = Arc::clone(&market);
tokio::task::spawn(async {
if let Err(e) = handle_stream(stream, db_pool).await {
if let Err(e) = handle_stream(stream, db_pool, market).await {
error!("{}", e);
}
});
@ -47,7 +61,7 @@ async fn main() -> Result<()> {
Ok(())
}
async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
async fn handle_stream(mut socket: TcpStream, pool: PgPool, mut market: Arc<Market>) -> Result<()> {
// only accept data that can fit in 256 bytes
// the assumption is that a single request should always be within n bytes,
// otherwise there's a good chance that it's mal{formed,icious}.
@ -58,35 +72,52 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
if bytes_read == 0 {
bail!("Failed to read bytes, assuming socket is closed.")
}
let data = buffer.split_to(bytes_read); // O(1)
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
debug!("Parsed operation: {:?}", parsed);
dbg!(&data);
let iter = serde_json::Deserializer::from_slice(&data).into_iter();
let response = match parsed {
ServerOperation::Query(op) => match op {
QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?,
QueryOperation::User { username } => state.user_info(username, &pool).await?,
},
ServerOperation::User(op) => match op {
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
UserOperation::Register { username, password } => {
state.register(username, password, &pool).await?
}
UserOperation::GetKey { username, password } => {
state.generate_api_key(username, password, &pool).await?
}
},
ServerOperation::Market(op) => match op {
MarketOperation::Buy { symbol, amount } => state.buy(symbol, amount, &pool).await?,
MarketOperation::Sell { symbol, amount } => {
state.sell(symbol, amount, &pool).await?
}
},
};
for messages in iter {
let message = match messages {
Ok(p) => p,
Err(e) if e.is_eof() || e.is_io() => return Ok(()),
Err(e) => return Err(e.into()),
};
debug!("Parsed operation: {:?}", message);
let response = match message {
ServerOperation::Query(op) => match op {
QueryOperation::StockInfo { stock } => {
AppState::stock_info(stock, &pool).await?
}
QueryOperation::User { username } => {
AppState::user_info(username, &pool).await?
}
},
ServerOperation::User(op) => match op {
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
UserOperation::Register { username, password } => {
state.register(username, password, &pool).await?
}
UserOperation::GetKey { username, password } => {
state.generate_api_key(username, password, &pool).await?
}
},
ServerOperation::Market(op) => match op {
MarketOperation::Buy { symbol, amount } => {
state.buy(symbol, amount, &pool, &mut market).await?
}
MarketOperation::Sell { symbol, amount } => {
state.sell(symbol, amount, &pool, &mut market).await?
}
},
};
socket
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
.await?;
}
socket
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
.await?;
buffer.unsplit(data); // O(1)
buffer.clear();
}

View File

@ -1,37 +1,66 @@
use rand::rngs::ThreadRng;
use rand::thread_rng;
use rand::{distributions::Uniform, prelude::Distribution};
use rand::{rngs::StdRng, SeedableRng};
use rust_decimal::Decimal;
#[derive(Copy, Clone, PartialEq, PartialOrd, Debug)]
pub(crate) struct MarketModifier(f64);
pub(super) struct MarketModifier(Decimal);
#[derive(Clone, Debug)]
pub(crate) struct MarketGenerator {
rng: ThreadRng,
sample_range: Uniform<f64>,
volatility: f64,
old_price: MarketModifier,
impl MarketModifier {
pub(super) fn modifier(&self) -> Decimal {
self.0
}
}
impl MarketGenerator {
pub(crate) fn new(volatility: f64, initial_value: f64) -> Self {
#[derive(Clone, Debug)]
pub(super) struct MarketMultiplier {
rng: StdRng,
sample_range: Uniform<i64>,
volatility: Decimal,
old_price: MarketModifier,
buffer: Vec<MarketModifier>,
}
impl MarketMultiplier {
pub(super) fn new(volatility: Decimal, initial_value: Decimal) -> Self {
Self {
rng: thread_rng(),
sample_range: Uniform::new_inclusive(-0.5, 0.5),
rng: StdRng::from_entropy(),
sample_range: Uniform::new_inclusive(-50, 50),
volatility,
old_price: MarketModifier(initial_value),
buffer: vec![],
}
}
}
impl Iterator for MarketGenerator {
impl Iterator for MarketMultiplier {
type Item = MarketModifier;
fn next(&mut self) -> Option<Self::Item> {
// Algorithm from https://stackoverflow.com/a/8597889
let rng_multiplier = self.sample_range.sample(&mut self.rng);
let delta = 2f64 * self.volatility * rng_multiplier;
let delta = Decimal::new(2, 0) * self.volatility * Decimal::new(rng_multiplier, 2);
self.old_price.0 += self.old_price.0 * delta;
Some(self.old_price)
}
}
impl MarketMultiplier {
pub(super) fn next_n(&mut self, n: usize) -> Vec<MarketModifier> {
self.ensure_buffer_capacity(n);
let mut ret = self.buffer.split_off(n);
std::mem::swap(&mut ret, &mut self.buffer);
ret
}
pub(super) fn peek_n(&mut self, n: usize) -> &[MarketModifier] {
self.ensure_buffer_capacity(n);
&self.buffer[..n]
}
fn ensure_buffer_capacity(&mut self, n: usize) {
for _ in 0..(n - self.buffer.len()) {
let next_val = self.next().unwrap();
self.buffer.push(next_val);
}
}
}

View File

@ -1 +1,41 @@
use generator::MarketMultiplier;
use rust_decimal::Decimal;
use tokio::sync::{Mutex, MutexGuard};
mod generator;
pub(crate) struct Market {
generator: Mutex<MarketMultiplier>,
}
impl Market {
pub(crate) fn new() -> Self {
Self {
generator: Mutex::new(MarketMultiplier::new(
Decimal::new(1, 1),
Decimal::new(1, 0),
)),
}
}
pub(crate) async fn lock(&self) -> LockedMarket<'_> {
LockedMarket(self.generator.lock().await, 0)
}
}
pub(crate) struct LockedMarket<'a>(MutexGuard<'a, MarketMultiplier>, usize);
impl<'a> LockedMarket<'a> {
pub(crate) fn peek_price(&mut self, initial_price: &Decimal, n: usize) -> Decimal {
let mut price = *initial_price;
for i in self.0.peek_n(n).iter() {
price *= i.modifier();
}
self.1 = n;
price.round_dp(2)
}
pub(crate) fn commit(mut self) {
self.0.next_n(self.1);
}
}

View File

@ -34,6 +34,7 @@ pub(crate) enum UserOperation {
}
#[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum MarketOperation {
Buy { symbol: StockSymbol, amount: usize },
Sell { symbol: StockSymbol, amount: usize },

View File

@ -1,8 +1,9 @@
use crate::user::SaltedPassword;
use crate::{market::Market, user::SaltedPassword};
use log::{debug, info};
use rust_decimal::Decimal;
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
use sqlx::{query, PgPool};
use std::convert::TryFrom;
use std::{convert::TryFrom, sync::Arc};
use thiserror::Error;
use vtse_common::net::{ServerResponse, UserError};
use vtse_common::stock::{Stock, StockSymbol};
@ -12,12 +13,29 @@ use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username}
pub(crate) enum StateError {
#[error("Was not in the correct state")]
WrongState,
#[error("Got SQLx error: {0}`")]
#[error("Got SQLx error: {0}")]
Database(#[from] sqlx::Error),
#[error("Failed to hash password")]
PasswordHash,
}
#[derive(Error, Debug)]
pub(crate) enum UserOrDbError {
#[error("Got SQLx error: {0}")]
Database(#[from] sqlx::Error),
#[error("Got user error: {0}")]
User(UserError),
}
impl UserOrDbError {
fn into_operation_result(self) -> OperationResult {
match self {
UserOrDbError::Database(e) => Err(StateError::Database(e)),
UserOrDbError::User(e) => Ok(ServerResponse::UserError(e)),
}
}
}
type OperationResult = Result<ServerResponse, StateError>;
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
@ -43,11 +61,7 @@ impl AppState {
/// Query operations impl
impl AppState {
pub(crate) async fn stock_info(
&self,
stock_symbol: StockSymbol,
pool: &PgPool,
) -> OperationResult {
pub(crate) async fn stock_info(stock_symbol: StockSymbol, pool: &PgPool) -> OperationResult {
let stock = query!(
"SELECT name, symbol, description, price FROM stocks WHERE symbol = $1",
stock_symbol.inner()
@ -63,7 +77,7 @@ impl AppState {
)))
}
pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult {
pub(crate) async fn user_info(username: Username, pool: &PgPool) -> OperationResult {
let user = query!(
"SELECT balance, debt FROM users WHERE username = $1",
username.inner()
@ -114,9 +128,7 @@ impl AppState {
let salted_password =
SaltedPassword::try_from(password).map_err(|_| StateError::PasswordHash)?;
query!(
"INSERT INTO users
(username, pwhash_data)
VALUES ($1, $2)",
"INSERT INTO users (username, pwhash_data) VALUES ($1, $2)",
username.inner(),
salted_password.as_ref(),
)
@ -143,9 +155,7 @@ impl AppState {
let mut transaction = pool.begin().await?;
let api_key = uuid::Uuid::new_v4();
let api_key_id = query!(
"INSERT INTO api_keys (key)
VALUES ($1)
RETURNING api_keys.key_id;",
"INSERT INTO api_keys (key) VALUES ($1) RETURNING api_keys.key_id;",
&api_key,
)
.fetch_one(&mut transaction)
@ -153,8 +163,7 @@ impl AppState {
.key_id;
query!(
"INSERT INTO user_api_keys (user_id, key_id)
VALUES ($1, $2)",
"INSERT INTO user_api_keys (user_id, key_id) VALUES ($1, $2)",
user_id,
api_key_id
)
@ -173,8 +182,7 @@ impl AppState {
pool: &PgPool,
) -> Result<Option<i32>, StateError> {
let result = query!(
"SELECT user_id, pwhash_data FROM users
WHERE username = $1",
"SELECT user_id, pwhash_data FROM users WHERE username = $1",
username.inner()
)
.fetch_one(pool)
@ -199,14 +207,39 @@ impl AppState {
symbol: StockSymbol,
amount: usize,
pool: &PgPool,
market: &mut Arc<Market>,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let stock_id = Self::get_stock_id(&symbol, pool).await?;
let user_balance = Self::get_user_balance(id, pool).await?;
let mut market = market.lock().await;
let current_price = {
let response = Self::get_stock_price(symbol.clone(), pool).await;
match response {
Ok(v) => v,
Err(e) => return e.into_operation_result(),
}
};
let cost_to_purchase = market.peek_price(&current_price, amount);
todo!()
if user_balance >= cost_to_purchase {
let resp = self
.purchase_stock(stock_id, amount, cost_to_purchase, pool)
.await;
if resp.is_ok() {
info!(
"User {} will be purchasing {} {} securities for {}",
id, amount, symbol, cost_to_purchase
);
market.commit();
}
resp
} else {
todo!()
}
}
pub(crate) async fn sell(
@ -214,18 +247,33 @@ impl AppState {
symbol: StockSymbol,
amount: usize,
pool: &PgPool,
market: &mut Arc<Market>,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let num_owned = Self::get_owned_stock_count(id, symbol, pool).await?;
let stock_id = Self::get_stock_id(&symbol, pool).await?;
let num_owned = Self::get_owned_stock_count(id, stock_id, pool).await?;
if num_owned < amount {
return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock(
num_owned,
)));
}
todo!()
let mut market = market.lock().await;
let current_price = {
let response = Self::get_stock_price(symbol.clone(), pool).await;
match response {
Ok(v) => v,
Err(e) => return e.into_operation_result(),
}
};
let gains = market.peek_price(&current_price, amount);
let response = self
.sell_stock(stock_id, gains, amount, num_owned == amount, pool)
.await;
market.commit();
response
}
// todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus
@ -241,21 +289,130 @@ impl AppState {
async fn get_owned_stock_count(
user_id: i32,
symbol: StockSymbol,
stock_id: i32,
pool: &PgPool,
) -> Result<usize, StateError> {
query!(
"SELECT amount FROM users_stocks
JOIN stocks ON users_stocks.stock_id = stocks.stock_id
WHERE user_id = $1 AND symbol = $2",
"SELECT amount FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
user_id,
symbol.inner()
stock_id
)
.fetch_one(pool)
.await
.map(|record| record.amount as usize)
.map_err(StateError::from)
}
async fn get_stock_price(symbol: StockSymbol, pool: &PgPool) -> Result<Decimal, UserOrDbError> {
let query = query!("SELECT price FROM stocks WHERE symbol = $1", symbol.inner())
.fetch_optional(pool)
.await;
match query {
Ok(Some(record)) => Ok(record.price),
Ok(None) => Err(UserOrDbError::User(UserError::InvalidStock(symbol))),
Err(e) => Err(UserOrDbError::Database(e)),
}
}
async fn purchase_stock(
&self,
symbol_id: i32,
amount: usize,
cost: Decimal,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
debug!(
"Starting transaction for id {}: {} {} at total cost {}",
id, amount, symbol_id, cost
);
let mut transaction = pool.begin().await?;
query!(
"UPDATE users SET balance = balance - $1 WHERE user_id = $2",
cost,
id
)
.execute(&mut transaction)
.await?;
query!(
"INSERT INTO users_stocks VALUES ($1, $2, $3)
ON CONFLICT ON CONSTRAINT users_stocks_user_id_stock_id_key DO
UPDATE SET amount = users_stocks.amount + $3::integer
WHERE users_stocks.user_id = $1",
id,
symbol_id,
amount as u32
)
.execute(&mut transaction)
.await?;
transaction.commit().await?;
Ok(ServerResponse::Success)
}
async fn get_stock_id(symbol: &StockSymbol, pool: &PgPool) -> Result<i32, StateError> {
query!(
"SELECT stock_id FROM stocks WHERE symbol = $1",
symbol.inner()
)
.fetch_one(pool)
.await
.map(|record| record.stock_id)
.map_err(StateError::from)
}
async fn sell_stock(
&self,
stock_id: i32,
gains: Decimal,
amount_sold: usize,
sell_all: bool,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let mut transaction = pool.begin().await?;
query!(
"UPDATE users SET balance = balance + $1::numeric WHERE user_id = $2",
gains,
id
)
.execute(&mut transaction)
.await?;
if sell_all {
query!(
"DELETE FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
id,
stock_id
)
.execute(&mut transaction)
.await?;
} else {
query!(
"UPDATE users_stocks SET amount = amount - $1::integer
WHERE user_id = $2 AND stock_id = $3",
amount_sold as u32,
id,
stock_id,
)
.execute(&mut transaction)
.await?;
}
transaction.commit().await?;
Ok(ServerResponse::Success)
}
}
impl Default for AppState {