423 lines
12 KiB
Rust
423 lines
12 KiB
Rust
use crate::{market::Market, user::SaltedPassword};
|
|
use log::{debug, info};
|
|
use rust_decimal::Decimal;
|
|
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
|
|
use sqlx::{query, PgPool};
|
|
use std::{convert::TryFrom, sync::Arc};
|
|
use thiserror::Error;
|
|
use vtse_common::net::{ServerResponse, UserError};
|
|
use vtse_common::stock::{Stock, StockSymbol};
|
|
use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username};
|
|
|
|
#[derive(Error, Debug)]
|
|
pub(crate) enum StateError {
|
|
#[error("Was not in the correct state")]
|
|
WrongState,
|
|
#[error("Got SQLx error: {0}")]
|
|
Database(#[from] sqlx::Error),
|
|
#[error("Failed to hash password")]
|
|
PasswordHash,
|
|
}
|
|
|
|
#[derive(Error, Debug)]
|
|
pub(crate) enum UserOrDbError {
|
|
#[error("Got SQLx error: {0}")]
|
|
Database(#[from] sqlx::Error),
|
|
#[error("Got user error: {0}")]
|
|
User(UserError),
|
|
}
|
|
|
|
impl UserOrDbError {
|
|
fn into_operation_result(self) -> OperationResult {
|
|
match self {
|
|
UserOrDbError::Database(e) => Err(StateError::Database(e)),
|
|
UserOrDbError::User(e) => Ok(ServerResponse::UserError(e)),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Outcome of a client-facing operation.
///
/// `Ok` carries the response to send, including user-level errors. `Err` is
/// reserved for server-side failures.
type OperationResult = std::result::Result<ServerResponse, StateError>;
|
|
|
|
/// Authentication state of a client session.
///
/// Variant order matters: the derived `Ord` sorts `Unauthorized` before any
/// `Authorized` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum AppState {
    /// No user has logged in yet.
    Unauthorized,
    /// Logged in as the user whose database id is `user_id`.
    Authorized { user_id: i32 },
}
|
|
|
|
/// Helper functions
|
|
impl AppState {
|
|
pub(crate) fn new() -> Self {
|
|
Self::Unauthorized
|
|
}
|
|
|
|
fn assert_state(&self, expected_state: AppState) -> Result<(), StateError> {
|
|
if *self != expected_state {
|
|
Err(StateError::WrongState)
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Query operations impl
|
|
impl AppState {
|
|
pub(crate) async fn stock_info(stock_symbol: StockSymbol, pool: &PgPool) -> OperationResult {
|
|
let stock = query!(
|
|
"SELECT name, symbol, description, price FROM stocks WHERE symbol = $1",
|
|
stock_symbol.inner()
|
|
)
|
|
.fetch_one(pool)
|
|
.await?;
|
|
|
|
Ok(ServerResponse::StockInfo(Stock::new(
|
|
stock.name,
|
|
stock.symbol,
|
|
stock.description,
|
|
stock.price,
|
|
)))
|
|
}
|
|
|
|
pub(crate) async fn user_info(username: Username, pool: &PgPool) -> OperationResult {
|
|
let user = query!(
|
|
"SELECT balance, debt FROM users WHERE username = $1",
|
|
username.inner()
|
|
)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
match user {
|
|
Some(user) => Ok(ServerResponse::UserInfo(User::new(
|
|
username,
|
|
UserBalance::new(user.balance),
|
|
UserDebt::new(user.debt),
|
|
))),
|
|
None => Ok(ServerResponse::UserError(UserError::InvalidUsername)),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// User operation implementation
|
|
impl AppState {
|
|
pub(crate) async fn login(&mut self, api_key: ApiKey, pool: &PgPool) -> OperationResult {
|
|
self.assert_state(AppState::Unauthorized)?;
|
|
let user_id = {
|
|
let query = query!(
|
|
"SELECT (user_id) FROM user_api_keys
|
|
JOIN api_keys ON user_api_keys.key_id = api_keys.key_id
|
|
WHERE key = $1",
|
|
api_key.inner()
|
|
)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
match query {
|
|
Some(id) => id.user_id,
|
|
None => return Ok(ServerResponse::UserError(UserError::InvalidApiKey)),
|
|
}
|
|
};
|
|
*self = AppState::Authorized { user_id };
|
|
Ok(ServerResponse::Success)
|
|
}
|
|
|
|
pub(crate) async fn register(
|
|
&mut self,
|
|
username: Username,
|
|
password: Password,
|
|
pool: &PgPool,
|
|
) -> OperationResult {
|
|
self.assert_state(AppState::Unauthorized)?;
|
|
let salted_password =
|
|
SaltedPassword::try_from(password).map_err(|_| StateError::PasswordHash)?;
|
|
query!(
|
|
"INSERT INTO users (username, pwhash_data) VALUES ($1, $2)",
|
|
username.inner(),
|
|
salted_password.as_ref(),
|
|
)
|
|
.execute(pool)
|
|
.await?;
|
|
Ok(ServerResponse::Success)
|
|
}
|
|
|
|
pub(crate) async fn generate_api_key(
|
|
&self,
|
|
username: Username,
|
|
password: Password,
|
|
pool: &PgPool,
|
|
) -> OperationResult {
|
|
self.assert_state(AppState::Unauthorized)?;
|
|
match self.validate_password(&username, &password, pool).await {
|
|
Ok(None) => Ok(ServerResponse::UserError(UserError::InvalidPassword)),
|
|
Ok(Some(user_id)) => self.create_api_key(user_id, pool).await,
|
|
Err(_) => todo!(),
|
|
}
|
|
}
|
|
|
|
async fn create_api_key(&self, user_id: i32, pool: &PgPool) -> OperationResult {
|
|
let mut transaction = pool.begin().await?;
|
|
let api_key = uuid::Uuid::new_v4();
|
|
let api_key_id = query!(
|
|
"INSERT INTO api_keys (key) VALUES ($1) RETURNING api_keys.key_id;",
|
|
&api_key,
|
|
)
|
|
.fetch_one(&mut transaction)
|
|
.await?
|
|
.key_id;
|
|
|
|
query!(
|
|
"INSERT INTO user_api_keys (user_id, key_id) VALUES ($1, $2)",
|
|
user_id,
|
|
api_key_id
|
|
)
|
|
.execute(&mut transaction)
|
|
.await?;
|
|
|
|
transaction.commit().await?;
|
|
|
|
Ok(ServerResponse::NewApiKey(api_key))
|
|
}
|
|
|
|
async fn validate_password(
|
|
&self,
|
|
username: &Username,
|
|
password: &Password,
|
|
pool: &PgPool,
|
|
) -> Result<Option<i32>, StateError> {
|
|
let result = query!(
|
|
"SELECT user_id, pwhash_data FROM users WHERE username = $1",
|
|
username.inner()
|
|
)
|
|
.fetch_one(pool)
|
|
.await?;
|
|
|
|
let login_attempt = HashedPassword::from_slice(&result.pwhash_data)
|
|
.map(|pw| pwhash_verify(&pw, password.as_bytes()))
|
|
.unwrap();
|
|
|
|
if login_attempt {
|
|
Ok(Some(result.user_id))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Market operation implementation
|
|
impl AppState {
|
|
pub(crate) async fn buy(
|
|
&self,
|
|
symbol: StockSymbol,
|
|
amount: usize,
|
|
pool: &PgPool,
|
|
market: &mut Arc<Market>,
|
|
) -> OperationResult {
|
|
let id = *match self {
|
|
Self::Authorized { user_id } => user_id,
|
|
_ => return Err(StateError::WrongState),
|
|
};
|
|
let stock_id = Self::get_stock_id(&symbol, pool).await?;
|
|
let user_balance = Self::get_user_balance(id, pool).await?;
|
|
let mut market = market.lock().await;
|
|
let current_price = {
|
|
let response = Self::get_stock_price(symbol.clone(), pool).await;
|
|
match response {
|
|
Ok(v) => v,
|
|
Err(e) => return e.into_operation_result(),
|
|
}
|
|
};
|
|
let cost_to_purchase = market.peek_price(¤t_price, amount);
|
|
|
|
if user_balance >= cost_to_purchase {
|
|
let resp = self
|
|
.purchase_stock(stock_id, amount, cost_to_purchase, pool)
|
|
.await;
|
|
if resp.is_ok() {
|
|
info!(
|
|
"User {} will be purchasing {} {} securities for {}",
|
|
id, amount, symbol, cost_to_purchase
|
|
);
|
|
market.commit();
|
|
}
|
|
resp
|
|
} else {
|
|
todo!()
|
|
}
|
|
}
|
|
|
|
pub(crate) async fn sell(
|
|
&self,
|
|
symbol: StockSymbol,
|
|
amount: usize,
|
|
pool: &PgPool,
|
|
market: &mut Arc<Market>,
|
|
) -> OperationResult {
|
|
let id = *match self {
|
|
Self::Authorized { user_id } => user_id,
|
|
_ => return Err(StateError::WrongState),
|
|
};
|
|
let stock_id = Self::get_stock_id(&symbol, pool).await?;
|
|
let num_owned = Self::get_owned_stock_count(id, stock_id, pool).await?;
|
|
if num_owned < amount {
|
|
return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock(
|
|
num_owned,
|
|
)));
|
|
}
|
|
let mut market = market.lock().await;
|
|
let current_price = {
|
|
let response = Self::get_stock_price(symbol.clone(), pool).await;
|
|
match response {
|
|
Ok(v) => v,
|
|
Err(e) => return e.into_operation_result(),
|
|
}
|
|
};
|
|
let gains = market.peek_price(¤t_price, amount);
|
|
let response = self
|
|
.sell_stock(stock_id, gains, amount, num_owned == amount, pool)
|
|
.await;
|
|
market.commit();
|
|
response
|
|
}
|
|
|
|
// todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus
|
|
// server error.
|
|
|
|
async fn get_user_balance(user_id: i32, pool: &PgPool) -> Result<Decimal, StateError> {
|
|
query!("SELECT balance FROM users where user_id = $1", user_id)
|
|
.fetch_one(pool)
|
|
.await
|
|
.map(|record| record.balance)
|
|
.map_err(StateError::from)
|
|
}
|
|
|
|
async fn get_owned_stock_count(
|
|
user_id: i32,
|
|
stock_id: i32,
|
|
pool: &PgPool,
|
|
) -> Result<usize, StateError> {
|
|
query!(
|
|
"SELECT amount FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
|
|
user_id,
|
|
stock_id
|
|
)
|
|
.fetch_one(pool)
|
|
.await
|
|
.map(|record| record.amount as usize)
|
|
.map_err(StateError::from)
|
|
}
|
|
|
|
async fn get_stock_price(symbol: StockSymbol, pool: &PgPool) -> Result<Decimal, UserOrDbError> {
|
|
let query = query!("SELECT price FROM stocks WHERE symbol = $1", symbol.inner())
|
|
.fetch_optional(pool)
|
|
.await;
|
|
match query {
|
|
Ok(Some(record)) => Ok(record.price),
|
|
Ok(None) => Err(UserOrDbError::User(UserError::InvalidStock(symbol))),
|
|
Err(e) => Err(UserOrDbError::Database(e)),
|
|
}
|
|
}
|
|
|
|
async fn purchase_stock(
|
|
&self,
|
|
symbol_id: i32,
|
|
amount: usize,
|
|
cost: Decimal,
|
|
pool: &PgPool,
|
|
) -> OperationResult {
|
|
let id = *match self {
|
|
Self::Authorized { user_id } => user_id,
|
|
_ => return Err(StateError::WrongState),
|
|
};
|
|
|
|
debug!(
|
|
"Starting transaction for id {}: {} {} at total cost {}",
|
|
id, amount, symbol_id, cost
|
|
);
|
|
|
|
let mut transaction = pool.begin().await?;
|
|
query!(
|
|
"UPDATE users SET balance = balance - $1 WHERE user_id = $2",
|
|
cost,
|
|
id
|
|
)
|
|
.execute(&mut transaction)
|
|
.await?;
|
|
|
|
query!(
|
|
"INSERT INTO users_stocks VALUES ($1, $2, $3)
|
|
ON CONFLICT ON CONSTRAINT users_stocks_user_id_stock_id_key DO
|
|
UPDATE SET amount = users_stocks.amount + $3::integer
|
|
WHERE users_stocks.user_id = $1",
|
|
id,
|
|
symbol_id,
|
|
amount as u32
|
|
)
|
|
.execute(&mut transaction)
|
|
.await?;
|
|
|
|
transaction.commit().await?;
|
|
|
|
Ok(ServerResponse::Success)
|
|
}
|
|
|
|
async fn get_stock_id(symbol: &StockSymbol, pool: &PgPool) -> Result<i32, StateError> {
|
|
query!(
|
|
"SELECT stock_id FROM stocks WHERE symbol = $1",
|
|
symbol.inner()
|
|
)
|
|
.fetch_one(pool)
|
|
.await
|
|
.map(|record| record.stock_id)
|
|
.map_err(StateError::from)
|
|
}
|
|
|
|
async fn sell_stock(
|
|
&self,
|
|
stock_id: i32,
|
|
gains: Decimal,
|
|
amount_sold: usize,
|
|
sell_all: bool,
|
|
pool: &PgPool,
|
|
) -> OperationResult {
|
|
let id = *match self {
|
|
Self::Authorized { user_id } => user_id,
|
|
_ => return Err(StateError::WrongState),
|
|
};
|
|
let mut transaction = pool.begin().await?;
|
|
|
|
query!(
|
|
"UPDATE users SET balance = balance + $1::numeric WHERE user_id = $2",
|
|
gains,
|
|
id
|
|
)
|
|
.execute(&mut transaction)
|
|
.await?;
|
|
|
|
if sell_all {
|
|
query!(
|
|
"DELETE FROM users_stocks WHERE user_id = $1 AND stock_id = $2",
|
|
id,
|
|
stock_id
|
|
)
|
|
.execute(&mut transaction)
|
|
.await?;
|
|
} else {
|
|
query!(
|
|
"UPDATE users_stocks SET amount = amount - $1::integer
|
|
WHERE user_id = $2 AND stock_id = $3",
|
|
amount_sold as u32,
|
|
id,
|
|
stock_id,
|
|
)
|
|
.execute(&mut transaction)
|
|
.await?;
|
|
}
|
|
|
|
transaction.commit().await?;
|
|
Ok(ServerResponse::Success)
|
|
}
|
|
}
|
|
|
|
impl Default for AppState {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|