vtse/vtse-server/src/state.rs

423 lines
12 KiB
Rust

use crate::{market::Market, user::SaltedPassword};
use log::{debug, info};
use rust_decimal::Decimal;
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
use sqlx::{query, PgPool};
use std::{convert::TryFrom, sync::Arc};
use thiserror::Error;
use vtse_common::net::{ServerResponse, UserError};
use vtse_common::stock::{Stock, StockSymbol};
use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username};
#[derive(Error, Debug)]
pub(crate) enum StateError {
#[error("Was not in the correct state")]
WrongState,
#[error("Got SQLx error: {0}")]
Database(#[from] sqlx::Error),
#[error("Failed to hash password")]
PasswordHash,
}
#[derive(Error, Debug)]
pub(crate) enum UserOrDbError {
#[error("Got SQLx error: {0}")]
Database(#[from] sqlx::Error),
#[error("Got user error: {0}")]
User(UserError),
}
impl UserOrDbError {
    /// Converts this error into the result type returned by `AppState` operations.
    ///
    /// A user mistake becomes a successful `ServerResponse::UserError` reply; a database
    /// failure is propagated as `StateError::Database`.
    fn into_operation_result(self) -> OperationResult {
        match self {
            Self::User(user_error) => Ok(ServerResponse::UserError(user_error)),
            // `StateError` derives `From<sqlx::Error>` via `#[from]`, yielding `Database`.
            Self::Database(db_error) => Err(db_error.into()),
        }
    }
}
/// Outcome of a server operation: `Ok` carries the reply for the client (which may itself be
/// a `ServerResponse::UserError`), while `Err` carries a `StateError`.
type OperationResult = Result<ServerResponse, StateError>;
/// Authentication state of a client session (presumably one per connection — confirm with
/// the caller that owns it).
///
/// Variant order is significant because `Ord`/`PartialOrd` are derived:
/// `Unauthorized` sorts before `Authorized`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum AppState {
    /// No user has logged in yet.
    Unauthorized,
    /// A user has logged in; `user_id` is their key in the `users` table.
    Authorized { user_id: i32 },
}
/// Helper functions
impl AppState {
pub(crate) fn new() -> Self {
Self::Unauthorized
}
fn assert_state(&self, expected_state: AppState) -> Result<(), StateError> {
if *self != expected_state {
Err(StateError::WrongState)
} else {
Ok(())
}
}
}
/// Query operations impl
///
/// Read-only lookups; these are associated functions and do not inspect or change the
/// session's authentication state.
impl AppState {
    /// Looks up a stock by its ticker symbol.
    ///
    /// Replies with `ServerResponse::StockInfo` when the symbol exists and with
    /// `ServerResponse::UserError(UserError::InvalidStock(..))` when it does not, so an
    /// unknown symbol is reported to the client instead of surfacing as a database error.
    ///
    /// # Errors
    /// Returns `StateError::Database` if the query itself fails.
    pub(crate) async fn stock_info(stock_symbol: StockSymbol, pool: &PgPool) -> OperationResult {
        let stock = query!(
            "SELECT name, symbol, description, price FROM stocks WHERE symbol = $1",
            stock_symbol.inner()
        )
        .fetch_optional(pool)
        .await?;
        match stock {
            Some(stock) => Ok(ServerResponse::StockInfo(Stock::new(
                stock.name,
                stock.symbol,
                stock.description,
                stock.price,
            ))),
            None => Ok(ServerResponse::UserError(UserError::InvalidStock(
                stock_symbol,
            ))),
        }
    }

    /// Looks up a user's balance and debt by username.
    ///
    /// Replies with `ServerResponse::UserInfo` for a known user and with
    /// `UserError::InvalidUsername` otherwise.
    ///
    /// # Errors
    /// Returns `StateError::Database` if the query itself fails.
    pub(crate) async fn user_info(username: Username, pool: &PgPool) -> OperationResult {
        let user = query!(
            "SELECT balance, debt FROM users WHERE username = $1",
            username.inner()
        )
        .fetch_optional(pool)
        .await?;
        match user {
            Some(user) => Ok(ServerResponse::UserInfo(User::new(
                username,
                UserBalance::new(user.balance),
                UserDebt::new(user.debt),
            ))),
            None => Ok(ServerResponse::UserError(UserError::InvalidUsername)),
        }
    }
}
/// User operation implementation
impl AppState {
pub(crate) async fn login(&mut self, api_key: ApiKey, pool: &PgPool) -> OperationResult {
self.assert_state(AppState::Unauthorized)?;
let user_id = {
let query = query!(
"SELECT (user_id) FROM user_api_keys
JOIN api_keys ON user_api_keys.key_id = api_keys.key_id
WHERE key = $1",
api_key.inner()
)
.fetch_optional(pool)
.await?;
match query {
Some(id) => id.user_id,
None => return Ok(ServerResponse::UserError(UserError::InvalidApiKey)),
}
};
*self = AppState::Authorized { user_id };
Ok(ServerResponse::Success)
}
pub(crate) async fn register(
&mut self,
username: Username,
password: Password,
pool: &PgPool,
) -> OperationResult {
self.assert_state(AppState::Unauthorized)?;
let salted_password =
SaltedPassword::try_from(password).map_err(|_| StateError::PasswordHash)?;
query!(
"INSERT INTO users (username, pwhash_data) VALUES ($1, $2)",
username.inner(),
salted_password.as_ref(),
)
.execute(pool)
.await?;
Ok(ServerResponse::Success)
}
pub(crate) async fn generate_api_key(
&self,
username: Username,
password: Password,
pool: &PgPool,
) -> OperationResult {
self.assert_state(AppState::Unauthorized)?;
match self.validate_password(&username, &password, pool).await {
Ok(None) => Ok(ServerResponse::UserError(UserError::InvalidPassword)),
Ok(Some(user_id)) => self.create_api_key(user_id, pool).await,
Err(_) => todo!(),
}
}
async fn create_api_key(&self, user_id: i32, pool: &PgPool) -> OperationResult {
let mut transaction = pool.begin().await?;
let api_key = uuid::Uuid::new_v4();
let api_key_id = query!(
"INSERT INTO api_keys (key) VALUES ($1) RETURNING api_keys.key_id;",
&api_key,
)
.fetch_one(&mut transaction)
.await?
.key_id;
query!(
"INSERT INTO user_api_keys (user_id, key_id) VALUES ($1, $2)",
user_id,
api_key_id
)
.execute(&mut transaction)
.await?;
transaction.commit().await?;
Ok(ServerResponse::NewApiKey(api_key))
}
async fn validate_password(
&self,
username: &Username,
password: &Password,
pool: &PgPool,
) -> Result<Option<i32>, StateError> {
let result = query!(
"SELECT user_id, pwhash_data FROM users WHERE username = $1",
username.inner()
)
.fetch_one(pool)
.await?;
let login_attempt = HashedPassword::from_slice(&result.pwhash_data)
.map(|pw| pwhash_verify(&pw, password.as_bytes()))
.unwrap();
if login_attempt {
Ok(Some(result.user_id))
} else {
Ok(None)
}
}
}
/// Market operation implementation
///
/// Buying and selling stock for an authorized user. The shared `Market` lock is acquired
/// before a trade is priced and held until the function returns. `commit` is called on it
/// only after the database write succeeded.
impl AppState {
    /// Returns the logged-in user's id.
    ///
    /// # Errors
    /// Returns `StateError::WrongState` when the session is unauthorized.
    fn authorized_user_id(&self) -> Result<i32, StateError> {
        match *self {
            Self::Authorized { user_id } => Ok(user_id),
            Self::Unauthorized => Err(StateError::WrongState),
        }
    }

    /// Buys `amount` shares of `symbol` at the market-adjusted price.
    ///
    /// An unknown symbol yields `UserError::InvalidStock`.
    ///
    /// # Errors
    /// `StateError::WrongState` if unauthorized; `StateError::Database` on query failure.
    ///
    /// # Panics
    /// Still panics (`todo!()`) when the user cannot afford the purchase.
    /// NOTE(review): this needs a dedicated `UserError` variant in `vtse_common`.
    ///
    /// NOTE(review): the balance is read outside the purchase transaction, so two
    /// concurrent buys by the same user may both pass the check.
    pub(crate) async fn buy(
        &self,
        symbol: StockSymbol,
        amount: usize,
        pool: &PgPool,
        market: &mut Arc<Market>,
    ) -> OperationResult {
        let id = self.authorized_user_id()?;
        let stock_id = match Self::get_stock_id(&symbol, pool).await {
            Ok(stock_id) => stock_id,
            Err(e) => return e.into_operation_result(),
        };
        let user_balance = Self::get_user_balance(id, pool).await?;
        let mut market = market.lock().await;
        let current_price = match Self::get_stock_price(symbol.clone(), pool).await {
            Ok(price) => price,
            Err(e) => return e.into_operation_result(),
        };
        let cost_to_purchase = market.peek_price(&current_price, amount);
        if user_balance < cost_to_purchase {
            todo!()
        }
        let resp = self
            .purchase_stock(stock_id, amount, cost_to_purchase, pool)
            .await;
        // Only move the market once the purchase has actually been recorded.
        if resp.is_ok() {
            info!(
                "User {} will be purchasing {} {} securities for {}",
                id, amount, symbol, cost_to_purchase
            );
            market.commit();
        }
        resp
    }

    /// Sells `amount` shares of `symbol` at the market-adjusted price.
    ///
    /// An unknown symbol yields `UserError::InvalidStock`. Owning fewer than `amount` shares
    /// (including none at all) yields `UserError::NotEnoughOwnedStock(owned)`.
    ///
    /// # Errors
    /// `StateError::WrongState` if unauthorized; `StateError::Database` on query failure.
    ///
    /// NOTE(review): ownership is read outside the sale transaction, so concurrent sells
    /// by the same user may race.
    pub(crate) async fn sell(
        &self,
        symbol: StockSymbol,
        amount: usize,
        pool: &PgPool,
        market: &mut Arc<Market>,
    ) -> OperationResult {
        let id = self.authorized_user_id()?;
        let stock_id = match Self::get_stock_id(&symbol, pool).await {
            Ok(stock_id) => stock_id,
            Err(e) => return e.into_operation_result(),
        };
        let num_owned = Self::get_owned_stock_count(id, stock_id, pool).await?;
        if num_owned < amount {
            return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock(
                num_owned,
            )));
        }
        let mut market = market.lock().await;
        let current_price = match Self::get_stock_price(symbol.clone(), pool).await {
            Ok(price) => price,
            Err(e) => return e.into_operation_result(),
        };
        let gains = market.peek_price(&current_price, amount);
        let response = self
            .sell_stock(stock_id, gains, amount, num_owned == amount, pool)
            .await;
        // Mirror `buy`: a failed sale must not move the market.
        if response.is_ok() {
            market.commit();
        }
        response
    }

    /// Reads a user's current balance.
    ///
    /// Uses `fetch_one`: `user_id` comes from an authorized session, so a missing row is
    /// treated as a server-side `StateError::Database` rather than a user error.
    async fn get_user_balance(user_id: i32, pool: &PgPool) -> Result<Decimal, StateError> {
        query!("SELECT balance FROM users where user_id = $1", user_id)
            .fetch_one(pool)
            .await
            .map(|record| record.balance)
            .map_err(StateError::from)
    }

    /// Returns how many shares of `stock_id` the user holds; 0 when there is no holding row.
    async fn get_owned_stock_count(
        user_id: i32,
        stock_id: i32,
        pool: &PgPool,
    ) -> Result<usize, StateError> {
        let record = query!(
            "SELECT amount FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
            user_id,
            stock_id
        )
        .fetch_optional(pool)
        .await?;
        Ok(record.map_or(0, |record| record.amount as usize))
    }

    /// Reads the current listed price of `symbol`; an unknown symbol is a user error.
    async fn get_stock_price(symbol: StockSymbol, pool: &PgPool) -> Result<Decimal, UserOrDbError> {
        let query = query!("SELECT price FROM stocks WHERE symbol = $1", symbol.inner())
            .fetch_optional(pool)
            .await;
        match query {
            Ok(Some(record)) => Ok(record.price),
            Ok(None) => Err(UserOrDbError::User(UserError::InvalidStock(symbol))),
            Err(e) => Err(UserOrDbError::Database(e)),
        }
    }

    /// Debits `cost` from the user and adds `amount` shares of `symbol_id` to their
    /// holdings, in a single transaction.
    async fn purchase_stock(
        &self,
        symbol_id: i32,
        amount: usize,
        cost: Decimal,
        pool: &PgPool,
    ) -> OperationResult {
        let id = self.authorized_user_id()?;
        debug!(
            "Starting transaction for id {}: {} {} at total cost {}",
            id, amount, symbol_id, cost
        );
        let mut transaction = pool.begin().await?;
        query!(
            "UPDATE users SET balance = balance - $1 WHERE user_id = $2",
            cost,
            id
        )
        .execute(&mut transaction)
        .await?;
        query!(
            "INSERT INTO users_stocks VALUES ($1, $2, $3)
ON CONFLICT ON CONSTRAINT users_stocks_user_id_stock_id_key DO
UPDATE SET amount = users_stocks.amount + $3::integer
WHERE users_stocks.user_id = $1",
            id,
            symbol_id,
            amount as u32
        )
        .execute(&mut transaction)
        .await?;
        transaction.commit().await?;
        Ok(ServerResponse::Success)
    }

    /// Resolves a ticker symbol to its `stock_id`; an unknown symbol is a user error.
    async fn get_stock_id(symbol: &StockSymbol, pool: &PgPool) -> Result<i32, UserOrDbError> {
        let record = query!(
            "SELECT stock_id FROM stocks WHERE symbol = $1",
            symbol.inner()
        )
        .fetch_optional(pool)
        .await?;
        match record {
            Some(record) => Ok(record.stock_id),
            None => Err(UserOrDbError::User(UserError::InvalidStock(symbol.clone()))),
        }
    }

    /// Credits `gains` to the user and removes `amount_sold` shares of `stock_id`, in a
    /// single transaction. When `sell_all` is set, the holding row is deleted instead of
    /// decremented.
    async fn sell_stock(
        &self,
        stock_id: i32,
        gains: Decimal,
        amount_sold: usize,
        sell_all: bool,
        pool: &PgPool,
    ) -> OperationResult {
        let id = self.authorized_user_id()?;
        let mut transaction = pool.begin().await?;
        query!(
            "UPDATE users SET balance = balance + $1::numeric WHERE user_id = $2",
            gains,
            id
        )
        .execute(&mut transaction)
        .await?;
        if sell_all {
            query!(
                "DELETE FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
                id,
                stock_id
            )
            .execute(&mut transaction)
            .await?;
        } else {
            query!(
                "UPDATE users_stocks SET amount = amount - $1::integer
WHERE user_id = $2 AND stock_id = $3",
                amount_sold as u32,
                id,
                stock_id,
            )
            .execute(&mut transaction)
            .await?;
        }
        transaction.commit().await?;
        Ok(ServerResponse::Success)
    }
}
impl Default for AppState {
fn default() -> Self {
Self::new()
}
}