Compare commits

...

2 commits

Author SHA1 Message Date
e6ff1576f1
add fetch_one note 2021-02-08 20:59:35 -05:00
f560236f8f
more work 2021-02-08 20:54:00 -05:00
10 changed files with 162 additions and 42 deletions

1
Cargo.lock generated
View file

@ -1376,6 +1376,7 @@ dependencies = [
"dotenv", "dotenv",
"log", "log",
"rand 0.8.3", "rand 0.8.3",
"rust_decimal",
"serde", "serde",
"serde_json", "serde_json",
"simple_logger", "simple_logger",

View file

@ -1,4 +1,4 @@
use crate::stock::Stock; use crate::{stock::Stock, user::User};
use serde::Serialize; use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
@ -10,11 +10,14 @@ pub enum ServerResponse {
NewApiKey(Uuid), NewApiKey(Uuid),
UserError(UserError), UserError(UserError),
StockInfo(Stock), StockInfo(Stock),
UserInfo(User),
} }
#[derive(Serialize)] #[derive(Serialize)]
pub enum UserError { pub enum UserError {
InvalidUsername,
InvalidPassword, InvalidPassword,
InvalidApiKey, InvalidApiKey,
NotAuthorized, NotAuthorized,
NotEnoughOwnedStock(usize),
} }

View file

@ -1,5 +1,3 @@
use crate::stock::StockName;
pub(crate) enum MarketOperation { pub(crate) enum MarketOperation {
Buy, Buy,
Sell, Sell,

View file

@ -3,11 +3,8 @@ use serde::{Deserialize, Serialize};
use std::str::FromStr; use std::str::FromStr;
use thiserror::Error; use thiserror::Error;
#[derive( #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, sqlx::Type,
)]
#[serde(transparent)] #[serde(transparent)]
#[sqlx(transparent)]
pub struct StockName(String); pub struct StockName(String);
impl From<String> for StockName { impl From<String> for StockName {
@ -37,6 +34,20 @@ impl From<String> for StockSymbol {
} }
} }
impl FromStr for StockSymbol {
type Err = StockNameParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.to_string()))
}
}
impl StockSymbol {
pub fn inner(&self) -> &str {
&self.0
}
}
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)] #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
#[serde(transparent)] #[serde(transparent)]
pub struct StockDescription(String); pub struct StockDescription(String);

View file

@ -1,4 +1,6 @@
use rust_decimal::Decimal;
use serde::Deserialize; use serde::Deserialize;
use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)] #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
@ -10,7 +12,8 @@ impl ApiKey {
} }
} }
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)] #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Username(String); pub struct Username(String);
impl Username { impl Username {
@ -20,6 +23,7 @@ impl Username {
} }
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)] #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
#[serde(transparent)]
pub struct Password(String); pub struct Password(String);
impl Password { impl Password {
@ -27,3 +31,40 @@ impl Password {
self.0.as_bytes() self.0.as_bytes()
} }
} }
#[derive(Serialize)]
#[serde(transparent)]
pub struct UserBalance(Decimal);
impl UserBalance {
pub fn new(decimal: Decimal) -> Self {
Self(decimal)
}
}
#[derive(Serialize)]
#[serde(transparent)]
pub struct UserDebt(Decimal);
impl UserDebt {
pub fn new(decimal: Decimal) -> Self {
Self(decimal)
}
}
#[derive(Serialize)]
pub struct User {
username: Username,
balance: UserBalance,
debt: UserDebt,
}
impl User {
pub fn new(username: Username, balance: UserBalance, debt: UserDebt) -> Self {
Self {
username,
balance,
debt,
}
}
}

View file

@ -10,6 +10,7 @@ bytes = "1"
dotenv = "0.15" dotenv = "0.15"
log = "0.4" log = "0.4"
rand = "0.8" rand = "0.8"
rust_decimal = "1.10.2"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
simple_logger = "1.11" simple_logger = "1.11"

View file

@ -36,13 +36,9 @@ async fn main() -> Result<()> {
// clone simply clones the arc reference, so this is cheap. // clone simply clones the arc reference, so this is cheap.
let db_pool = db_pool.clone(); let db_pool = db_pool.clone();
tokio::task::spawn(async { tokio::task::spawn(async {
match handle_stream(stream, db_pool).await { if let Err(e) = handle_stream(stream, db_pool).await {
Ok(_) => (),
Err(e) => {
// stream.write_all(e);
error!("{}", e); error!("{}", e);
} }
}
}); });
} }
@ -65,10 +61,11 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
let data = buffer.split_to(bytes_read); // O(1) let data = buffer.split_to(bytes_read); // O(1)
let parsed = serde_json::from_slice::<ServerOperation>(&data)?; let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
debug!("Parsed operation: {:?}", parsed); debug!("Parsed operation: {:?}", parsed);
let response = match parsed { let response = match parsed {
ServerOperation::Query(op) => match op { ServerOperation::Query(op) => match op {
QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?, QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?,
QueryOperation::User(username) => state.user_info(username, &pool).await?, QueryOperation::User { username } => state.user_info(username, &pool).await?,
}, },
ServerOperation::User(op) => match op { ServerOperation::User(op) => match op {
UserOperation::Login { api_key } => state.login(api_key, &pool).await?, UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
@ -80,8 +77,10 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
} }
}, },
ServerOperation::Market(op) => match op { ServerOperation::Market(op) => match op {
MarketOperation::Buy(stock_name) => state.buy(stock_name, &pool)?, MarketOperation::Buy { symbol, amount } => state.buy(symbol, amount, &pool).await?,
MarketOperation::Sell(stock_name) => state.sell(stock_name, &pool)?, MarketOperation::Sell { symbol, amount } => {
state.sell(symbol, amount, &pool).await?
}
}, },
}; };

View file

@ -1,5 +1,5 @@
use serde::Deserialize; use serde::Deserialize;
use vtse_common::stock::StockName; use vtse_common::stock::StockSymbol;
use vtse_common::user::{ApiKey, Password, Username}; use vtse_common::user::{ApiKey, Password, Username};
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
@ -13,8 +13,8 @@ pub(crate) enum ServerOperation {
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum QueryOperation { pub(crate) enum QueryOperation {
StockInfo { stock: StockName }, StockInfo { stock: StockSymbol },
User(Username), User { username: Username },
} }
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
@ -35,8 +35,8 @@ pub(crate) enum UserOperation {
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
pub(crate) enum MarketOperation { pub(crate) enum MarketOperation {
Buy(StockName), Buy { symbol: StockSymbol, amount: usize },
Sell(StockName), Sell { symbol: StockSymbol, amount: usize },
} }
#[cfg(test)] #[cfg(test)]
@ -54,7 +54,7 @@ mod deserialize {
})) }))
.unwrap(), .unwrap(),
ServerOperation::Query(QueryOperation::StockInfo { ServerOperation::Query(QueryOperation::StockInfo {
stock: StockName::from_str("Gura").unwrap() stock: StockSymbol::from_str("Gura").unwrap()
}) })
) )
} }
@ -68,7 +68,7 @@ mod deserialize {
})) }))
.unwrap(), .unwrap(),
QueryOperation::StockInfo { QueryOperation::StockInfo {
stock: StockName::from_str("Gura").unwrap() stock: StockSymbol::from_str("Gura").unwrap()
} }
) )
} }

View file

@ -1,11 +1,12 @@
use crate::user::SaltedPassword; use crate::user::SaltedPassword;
use rust_decimal::Decimal;
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword}; use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
use sqlx::{query, PgPool}; use sqlx::{query, PgPool};
use std::convert::TryFrom; use std::convert::TryFrom;
use thiserror::Error; use thiserror::Error;
use vtse_common::net::{ServerResponse, UserError}; use vtse_common::net::{ServerResponse, UserError};
use vtse_common::stock::{Stock, StockName}; use vtse_common::stock::{Stock, StockSymbol};
use vtse_common::user::{ApiKey, Password, Username}; use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username};
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub(crate) enum StateError { pub(crate) enum StateError {
@ -42,10 +43,14 @@ impl AppState {
/// Query operations impl /// Query operations impl
impl AppState { impl AppState {
pub(crate) async fn stock_info(&self, stock_name: StockName, pool: &PgPool) -> OperationResult { pub(crate) async fn stock_info(
&self,
stock_symbol: StockSymbol,
pool: &PgPool,
) -> OperationResult {
let stock = query!( let stock = query!(
"SELECT name, symbol, description, price FROM stocks WHERE name = $1", "SELECT name, symbol, description, price FROM stocks WHERE symbol = $1",
stock_name as StockName stock_symbol.inner()
) )
.fetch_one(pool) .fetch_one(pool)
.await?; .await?;
@ -59,7 +64,21 @@ impl AppState {
} }
pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult { pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult {
todo!() let user = query!(
"SELECT balance, debt FROM users WHERE username = $1",
username.inner()
)
.fetch_optional(pool)
.await?;
match user {
Some(user) => Ok(ServerResponse::UserInfo(User::new(
username,
UserBalance::new(user.balance),
UserDebt::new(user.debt),
))),
None => Ok(ServerResponse::UserError(UserError::InvalidUsername)),
}
} }
} }
@ -175,18 +194,67 @@ impl AppState {
/// Market operation implementation /// Market operation implementation
impl AppState { impl AppState {
pub(crate) fn buy(&self, stock_name: StockName, pool: &PgPool) -> OperationResult { pub(crate) async fn buy(
if !matches!(self, Self::Authorized{..}) { &self,
return Err(StateError::WrongState); symbol: StockSymbol,
amount: usize,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let user_balance = Self::get_user_balance(id, pool).await?;
todo!()
}
pub(crate) async fn sell(
&self,
symbol: StockSymbol,
amount: usize,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let num_owned = Self::get_owned_stock_count(id, symbol, pool).await?;
if num_owned < amount {
return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock(
num_owned,
)));
} }
todo!() todo!()
} }
pub(crate) fn sell(&self, stock_name: StockName, pool: &PgPool) -> OperationResult { // todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus
if !matches!(self, Self::Authorized{..}) { // server error.
return Err(StateError::WrongState);
async fn get_user_balance(user_id: i32, pool: &PgPool) -> Result<Decimal, StateError> {
query!("SELECT balance FROM users where user_id = $1", user_id)
.fetch_one(pool)
.await
.map(|record| record.balance)
.map_err(StateError::from)
} }
todo!()
async fn get_owned_stock_count(
user_id: i32,
symbol: StockSymbol,
pool: &PgPool,
) -> Result<usize, StateError> {
query!(
"SELECT amount FROM users_stocks
JOIN stocks ON users_stocks.stock_id = stocks.stock_id
WHERE user_id = $1 AND symbol = $2",
user_id,
symbol.inner()
)
.fetch_one(pool)
.await
.map(|record| record.amount as usize)
.map_err(StateError::from)
} }
} }

View file

@ -1,7 +1,5 @@
use sodiumoxide::crypto::pwhash::{ use sodiumoxide::crypto::pwhash::argon2id13::HashedPassword;
argon2id13::HashedPassword, use sodiumoxide::crypto::pwhash::argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE};
argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE},
};
use std::convert::TryFrom; use std::convert::TryFrom;
use vtse_common::user::Password; use vtse_common::user::Password;