Compare commits

..

No commits in common. "e6ff1576f118d0146e5ad53b2ac980502e1fe428" and "bca934ec50cce0488a1bcff8e1ddb40e27aef658" have entirely different histories.

10 changed files with 42 additions and 162 deletions

1
Cargo.lock generated
View file

@ -1376,7 +1376,6 @@ dependencies = [
"dotenv", "dotenv",
"log", "log",
"rand 0.8.3", "rand 0.8.3",
"rust_decimal",
"serde", "serde",
"serde_json", "serde_json",
"simple_logger", "simple_logger",

View file

@ -1,4 +1,4 @@
use crate::{stock::Stock, user::User}; use crate::stock::Stock;
use serde::Serialize; use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
@ -10,14 +10,11 @@ pub enum ServerResponse {
NewApiKey(Uuid), NewApiKey(Uuid),
UserError(UserError), UserError(UserError),
StockInfo(Stock), StockInfo(Stock),
UserInfo(User),
} }
#[derive(Serialize)] #[derive(Serialize)]
pub enum UserError { pub enum UserError {
InvalidUsername,
InvalidPassword, InvalidPassword,
InvalidApiKey, InvalidApiKey,
NotAuthorized, NotAuthorized,
NotEnoughOwnedStock(usize),
} }

View file

@ -1,3 +1,5 @@
use crate::stock::StockName;
pub(crate) enum MarketOperation { pub(crate) enum MarketOperation {
Buy, Buy,
Sell, Sell,

View file

@ -3,8 +3,11 @@ use serde::{Deserialize, Serialize};
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error; use thiserror::Error;
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)] #[derive(
Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, sqlx::Type,
)]
#[serde(transparent)] #[serde(transparent)]
#[sqlx(transparent)]
pub struct StockName(String); pub struct StockName(String);
impl From<String> for StockName { impl From<String> for StockName {
@ -34,20 +37,6 @@ impl From<String> for StockSymbol {
} }
} }
impl FromStr for StockSymbol {
type Err = StockNameParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.to_string()))
}
}
impl StockSymbol {
pub fn inner(&self) -> &str {
&self.0
}
}
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)] #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
#[serde(transparent)] #[serde(transparent)]
pub struct StockDescription(String); pub struct StockDescription(String);

View file

@ -1,6 +1,4 @@
use rust_decimal::Decimal;
use serde::Deserialize; use serde::Deserialize;
use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)] #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
@ -12,8 +10,7 @@ impl ApiKey {
} }
} }
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, Serialize)] #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
#[serde(transparent)]
pub struct Username(String); pub struct Username(String);
impl Username { impl Username {
@ -23,7 +20,6 @@ impl Username {
} }
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)] #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
#[serde(transparent)]
pub struct Password(String); pub struct Password(String);
impl Password { impl Password {
@ -31,40 +27,3 @@ impl Password {
self.0.as_bytes() self.0.as_bytes()
} }
} }
#[derive(Serialize)]
#[serde(transparent)]
pub struct UserBalance(Decimal);
impl UserBalance {
pub fn new(decimal: Decimal) -> Self {
Self(decimal)
}
}
#[derive(Serialize)]
#[serde(transparent)]
pub struct UserDebt(Decimal);
impl UserDebt {
pub fn new(decimal: Decimal) -> Self {
Self(decimal)
}
}
#[derive(Serialize)]
pub struct User {
username: Username,
balance: UserBalance,
debt: UserDebt,
}
impl User {
pub fn new(username: Username, balance: UserBalance, debt: UserDebt) -> Self {
Self {
username,
balance,
debt,
}
}
}

View file

@ -10,7 +10,6 @@ bytes = "1"
dotenv = "0.15" dotenv = "0.15"
log = "0.4" log = "0.4"
rand = "0.8" rand = "0.8"
rust_decimal = "1.10.2"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
simple_logger = "1.11" simple_logger = "1.11"

View file

@ -36,9 +36,13 @@ async fn main() -> Result<()> {
// clone simply clones the arc reference, so this is cheap. // clone simply clones the arc reference, so this is cheap.
let db_pool = db_pool.clone(); let db_pool = db_pool.clone();
tokio::task::spawn(async { tokio::task::spawn(async {
if let Err(e) = handle_stream(stream, db_pool).await { match handle_stream(stream, db_pool).await {
Ok(_) => (),
Err(e) => {
// stream.write_all(e);
error!("{}", e); error!("{}", e);
} }
}
}); });
} }
@ -61,11 +65,10 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
let data = buffer.split_to(bytes_read); // O(1) let data = buffer.split_to(bytes_read); // O(1)
let parsed = serde_json::from_slice::<ServerOperation>(&data)?; let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
debug!("Parsed operation: {:?}", parsed); debug!("Parsed operation: {:?}", parsed);
let response = match parsed { let response = match parsed {
ServerOperation::Query(op) => match op { ServerOperation::Query(op) => match op {
QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?, QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?,
QueryOperation::User { username } => state.user_info(username, &pool).await?, QueryOperation::User(username) => state.user_info(username, &pool).await?,
}, },
ServerOperation::User(op) => match op { ServerOperation::User(op) => match op {
UserOperation::Login { api_key } => state.login(api_key, &pool).await?, UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
@ -77,10 +80,8 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
} }
}, },
ServerOperation::Market(op) => match op { ServerOperation::Market(op) => match op {
MarketOperation::Buy { symbol, amount } => state.buy(symbol, amount, &pool).await?, MarketOperation::Buy(stock_name) => state.buy(stock_name, &pool)?,
MarketOperation::Sell { symbol, amount } => { MarketOperation::Sell(stock_name) => state.sell(stock_name, &pool)?,
state.sell(symbol, amount, &pool).await?
}
}, },
}; };

View file

@ -1,5 +1,5 @@
use serde::Deserialize; use serde::Deserialize;
use vtse_common::stock::StockSymbol; use vtse_common::stock::StockName;
use vtse_common::user::{ApiKey, Password, Username}; use vtse_common::user::{ApiKey, Password, Username};
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
@ -13,8 +13,8 @@ pub(crate) enum ServerOperation {
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum QueryOperation { pub(crate) enum QueryOperation {
StockInfo { stock: StockSymbol }, StockInfo { stock: StockName },
User { username: Username }, User(Username),
} }
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
@ -35,8 +35,8 @@ pub(crate) enum UserOperation {
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
pub(crate) enum MarketOperation { pub(crate) enum MarketOperation {
Buy { symbol: StockSymbol, amount: usize }, Buy(StockName),
Sell { symbol: StockSymbol, amount: usize }, Sell(StockName),
} }
#[cfg(test)] #[cfg(test)]
@ -54,7 +54,7 @@ mod deserialize {
})) }))
.unwrap(), .unwrap(),
ServerOperation::Query(QueryOperation::StockInfo { ServerOperation::Query(QueryOperation::StockInfo {
stock: StockSymbol::from_str("Gura").unwrap() stock: StockName::from_str("Gura").unwrap()
}) })
) )
} }
@ -68,7 +68,7 @@ mod deserialize {
})) }))
.unwrap(), .unwrap(),
QueryOperation::StockInfo { QueryOperation::StockInfo {
stock: StockSymbol::from_str("Gura").unwrap() stock: StockName::from_str("Gura").unwrap()
} }
) )
} }

View file

@ -1,12 +1,11 @@
use crate::user::SaltedPassword; use crate::user::SaltedPassword;
use rust_decimal::Decimal;
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword}; use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
use sqlx::{query, PgPool}; use sqlx::{query, PgPool};
use std::convert::TryFrom; use std::convert::TryFrom;
use thiserror::Error; use thiserror::Error;
use vtse_common::net::{ServerResponse, UserError}; use vtse_common::net::{ServerResponse, UserError};
use vtse_common::stock::{Stock, StockSymbol}; use vtse_common::stock::{Stock, StockName};
use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username}; use vtse_common::user::{ApiKey, Password, Username};
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub(crate) enum StateError { pub(crate) enum StateError {
@ -43,14 +42,10 @@ impl AppState {
/// Query operations impl /// Query operations impl
impl AppState { impl AppState {
pub(crate) async fn stock_info( pub(crate) async fn stock_info(&self, stock_name: StockName, pool: &PgPool) -> OperationResult {
&self,
stock_symbol: StockSymbol,
pool: &PgPool,
) -> OperationResult {
let stock = query!( let stock = query!(
"SELECT name, symbol, description, price FROM stocks WHERE symbol = $1", "SELECT name, symbol, description, price FROM stocks WHERE name = $1",
stock_symbol.inner() stock_name as StockName
) )
.fetch_one(pool) .fetch_one(pool)
.await?; .await?;
@ -64,21 +59,7 @@ impl AppState {
} }
pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult { pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult {
let user = query!( todo!()
"SELECT balance, debt FROM users WHERE username = $1",
username.inner()
)
.fetch_optional(pool)
.await?;
match user {
Some(user) => Ok(ServerResponse::UserInfo(User::new(
username,
UserBalance::new(user.balance),
UserDebt::new(user.debt),
))),
None => Ok(ServerResponse::UserError(UserError::InvalidUsername)),
}
} }
} }
@ -194,67 +175,18 @@ impl AppState {
/// Market operation implementation /// Market operation implementation
impl AppState { impl AppState {
pub(crate) async fn buy( pub(crate) fn buy(&self, stock_name: StockName, pool: &PgPool) -> OperationResult {
&self, if !matches!(self, Self::Authorized{..}) {
symbol: StockSymbol, return Err(StateError::WrongState);
amount: usize,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let user_balance = Self::get_user_balance(id, pool).await?;
todo!()
}
pub(crate) async fn sell(
&self,
symbol: StockSymbol,
amount: usize,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let num_owned = Self::get_owned_stock_count(id, symbol, pool).await?;
if num_owned < amount {
return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock(
num_owned,
)));
} }
todo!() todo!()
} }
// todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus pub(crate) fn sell(&self, stock_name: StockName, pool: &PgPool) -> OperationResult {
// server error. if !matches!(self, Self::Authorized{..}) {
return Err(StateError::WrongState);
async fn get_user_balance(user_id: i32, pool: &PgPool) -> Result<Decimal, StateError> {
query!("SELECT balance FROM users where user_id = $1", user_id)
.fetch_one(pool)
.await
.map(|record| record.balance)
.map_err(StateError::from)
} }
todo!()
async fn get_owned_stock_count(
user_id: i32,
symbol: StockSymbol,
pool: &PgPool,
) -> Result<usize, StateError> {
query!(
"SELECT amount FROM users_stocks
JOIN stocks ON users_stocks.stock_id = stocks.stock_id
WHERE user_id = $1 AND symbol = $2",
user_id,
symbol.inner()
)
.fetch_one(pool)
.await
.map(|record| record.amount as usize)
.map_err(StateError::from)
} }
} }

View file

@ -1,5 +1,7 @@
use sodiumoxide::crypto::pwhash::argon2id13::HashedPassword; use sodiumoxide::crypto::pwhash::{
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE}; argon2id13::HashedPassword,
argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE},
};
use std::convert::TryFrom; use std::convert::TryFrom;
use vtse_common::user::Password; use vtse_common::user::Password;