Compare commits

..

2 commits

Author SHA1 Message Date
e6ff1576f1
add fetch_one note 2021-02-08 20:59:35 -05:00
f560236f8f
more work 2021-02-08 20:54:00 -05:00
10 changed files with 162 additions and 42 deletions

1
Cargo.lock generated
View file

@ -1376,6 +1376,7 @@ dependencies = [
"dotenv",
"log",
"rand 0.8.3",
"rust_decimal",
"serde",
"serde_json",
"simple_logger",

View file

@ -1,4 +1,4 @@
use crate::stock::Stock;
use crate::{stock::Stock, user::User};
use serde::Serialize;
use uuid::Uuid;
@ -10,11 +10,14 @@ pub enum ServerResponse {
NewApiKey(Uuid),
UserError(UserError),
StockInfo(Stock),
UserInfo(User),
}
#[derive(Serialize)]
pub enum UserError {
InvalidUsername,
InvalidPassword,
InvalidApiKey,
NotAuthorized,
NotEnoughOwnedStock(usize),
}

View file

@ -1,5 +1,3 @@
use crate::stock::StockName;
pub(crate) enum MarketOperation {
Buy,
Sell,

View file

@ -3,11 +3,8 @@ use serde::{Deserialize, Serialize};
use std::str::FromStr;
use thiserror::Error;
#[derive(
Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, sqlx::Type,
)]
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
#[serde(transparent)]
#[sqlx(transparent)]
pub struct StockName(String);
impl From<String> for StockName {
@ -37,6 +34,20 @@ impl From<String> for StockSymbol {
}
}
impl FromStr for StockSymbol {
type Err = StockNameParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.to_string()))
}
}
impl StockSymbol {
pub fn inner(&self) -> &str {
&self.0
}
}
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
#[serde(transparent)]
pub struct StockDescription(String);

View file

@ -1,4 +1,6 @@
use rust_decimal::Decimal;
use serde::Deserialize;
use serde::Serialize;
use uuid::Uuid;
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
@ -10,7 +12,8 @@ impl ApiKey {
}
}
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Username(String);
impl Username {
@ -20,6 +23,7 @@ impl Username {
}
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
#[serde(transparent)]
pub struct Password(String);
impl Password {
@ -27,3 +31,40 @@ impl Password {
self.0.as_bytes()
}
}
#[derive(Serialize)]
#[serde(transparent)]
pub struct UserBalance(Decimal);
impl UserBalance {
pub fn new(decimal: Decimal) -> Self {
Self(decimal)
}
}
#[derive(Serialize)]
#[serde(transparent)]
pub struct UserDebt(Decimal);
impl UserDebt {
pub fn new(decimal: Decimal) -> Self {
Self(decimal)
}
}
#[derive(Serialize)]
pub struct User {
username: Username,
balance: UserBalance,
debt: UserDebt,
}
impl User {
pub fn new(username: Username, balance: UserBalance, debt: UserDebt) -> Self {
Self {
username,
balance,
debt,
}
}
}

View file

@ -10,6 +10,7 @@ bytes = "1"
dotenv = "0.15"
log = "0.4"
rand = "0.8"
rust_decimal = "1.10.2"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
simple_logger = "1.11"

View file

@ -36,13 +36,9 @@ async fn main() -> Result<()> {
// clone simply clones the arc reference, so this is cheap.
let db_pool = db_pool.clone();
tokio::task::spawn(async {
match handle_stream(stream, db_pool).await {
Ok(_) => (),
Err(e) => {
// stream.write_all(e);
if let Err(e) = handle_stream(stream, db_pool).await {
error!("{}", e);
}
}
});
}
@ -65,10 +61,11 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
let data = buffer.split_to(bytes_read); // O(1)
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
debug!("Parsed operation: {:?}", parsed);
let response = match parsed {
ServerOperation::Query(op) => match op {
QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?,
QueryOperation::User(username) => state.user_info(username, &pool).await?,
QueryOperation::User { username } => state.user_info(username, &pool).await?,
},
ServerOperation::User(op) => match op {
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
@ -80,8 +77,10 @@ async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
}
},
ServerOperation::Market(op) => match op {
MarketOperation::Buy(stock_name) => state.buy(stock_name, &pool)?,
MarketOperation::Sell(stock_name) => state.sell(stock_name, &pool)?,
MarketOperation::Buy { symbol, amount } => state.buy(symbol, amount, &pool).await?,
MarketOperation::Sell { symbol, amount } => {
state.sell(symbol, amount, &pool).await?
}
},
};

View file

@ -1,5 +1,5 @@
use serde::Deserialize;
use vtse_common::stock::StockName;
use vtse_common::stock::StockSymbol;
use vtse_common::user::{ApiKey, Password, Username};
#[derive(Deserialize, Debug, PartialEq)]
@ -13,8 +13,8 @@ pub(crate) enum ServerOperation {
#[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum QueryOperation {
StockInfo { stock: StockName },
User(Username),
StockInfo { stock: StockSymbol },
User { username: Username },
}
#[derive(Deserialize, Debug, PartialEq)]
@ -35,8 +35,8 @@ pub(crate) enum UserOperation {
#[derive(Deserialize, Debug, PartialEq)]
pub(crate) enum MarketOperation {
Buy(StockName),
Sell(StockName),
Buy { symbol: StockSymbol, amount: usize },
Sell { symbol: StockSymbol, amount: usize },
}
#[cfg(test)]
@ -54,7 +54,7 @@ mod deserialize {
}))
.unwrap(),
ServerOperation::Query(QueryOperation::StockInfo {
stock: StockName::from_str("Gura").unwrap()
stock: StockSymbol::from_str("Gura").unwrap()
})
)
}
@ -68,7 +68,7 @@ mod deserialize {
}))
.unwrap(),
QueryOperation::StockInfo {
stock: StockName::from_str("Gura").unwrap()
stock: StockSymbol::from_str("Gura").unwrap()
}
)
}

View file

@ -1,11 +1,12 @@
use crate::user::SaltedPassword;
use rust_decimal::Decimal;
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
use sqlx::{query, PgPool};
use std::convert::TryFrom;
use thiserror::Error;
use vtse_common::net::{ServerResponse, UserError};
use vtse_common::stock::{Stock, StockName};
use vtse_common::user::{ApiKey, Password, Username};
use vtse_common::stock::{Stock, StockSymbol};
use vtse_common::user::{ApiKey, Password, User, UserBalance, UserDebt, Username};
#[derive(Error, Debug)]
pub(crate) enum StateError {
@ -42,10 +43,14 @@ impl AppState {
/// Query operations impl
impl AppState {
pub(crate) async fn stock_info(&self, stock_name: StockName, pool: &PgPool) -> OperationResult {
pub(crate) async fn stock_info(
&self,
stock_symbol: StockSymbol,
pool: &PgPool,
) -> OperationResult {
let stock = query!(
"SELECT name, symbol, description, price FROM stocks WHERE name = $1",
stock_name as StockName
"SELECT name, symbol, description, price FROM stocks WHERE symbol = $1",
stock_symbol.inner()
)
.fetch_one(pool)
.await?;
@ -59,7 +64,21 @@ impl AppState {
}
pub(crate) async fn user_info(&self, username: Username, pool: &PgPool) -> OperationResult {
todo!()
let user = query!(
"SELECT balance, debt FROM users WHERE username = $1",
username.inner()
)
.fetch_optional(pool)
.await?;
match user {
Some(user) => Ok(ServerResponse::UserInfo(User::new(
username,
UserBalance::new(user.balance),
UserDebt::new(user.debt),
))),
None => Ok(ServerResponse::UserError(UserError::InvalidUsername)),
}
}
}
@ -175,18 +194,67 @@ impl AppState {
/// Market operation implementation
impl AppState {
pub(crate) fn buy(&self, stock_name: StockName, pool: &PgPool) -> OperationResult {
if !matches!(self, Self::Authorized{..}) {
return Err(StateError::WrongState);
pub(crate) async fn buy(
&self,
symbol: StockSymbol,
amount: usize,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let user_balance = Self::get_user_balance(id, pool).await?;
todo!()
}
pub(crate) async fn sell(
&self,
symbol: StockSymbol,
amount: usize,
pool: &PgPool,
) -> OperationResult {
let id = *match self {
Self::Authorized { user_id } => user_id,
_ => return Err(StateError::WrongState),
};
let num_owned = Self::get_owned_stock_count(id, symbol, pool).await?;
if num_owned < amount {
return Ok(ServerResponse::UserError(UserError::NotEnoughOwnedStock(
num_owned,
)));
}
todo!()
}
pub(crate) fn sell(&self, stock_name: StockName, pool: &PgPool) -> OperationResult {
if !matches!(self, Self::Authorized{..}) {
return Err(StateError::WrongState);
// todo: fetch_one needs to turn into fetch_optional, else we can't discriminate user versus
// server error.
async fn get_user_balance(user_id: i32, pool: &PgPool) -> Result<Decimal, StateError> {
query!("SELECT balance FROM users where user_id = $1", user_id)
.fetch_one(pool)
.await
.map(|record| record.balance)
.map_err(StateError::from)
}
todo!()
async fn get_owned_stock_count(
user_id: i32,
symbol: StockSymbol,
pool: &PgPool,
) -> Result<usize, StateError> {
query!(
"SELECT amount FROM users_stocks
JOIN stocks ON users_stocks.stock_id = stocks.stock_id
WHERE user_id = $1 AND symbol = $2",
user_id,
symbol.inner()
)
.fetch_one(pool)
.await
.map(|record| record.amount as usize)
.map_err(StateError::from)
}
}

View file

@ -1,7 +1,5 @@
use sodiumoxide::crypto::pwhash::{
argon2id13::HashedPassword,
argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE},
};
use sodiumoxide::crypto::pwhash::argon2id13::HashedPassword;
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE};
use std::convert::TryFrom;
use vtse_common::user::Password;