get user login working

This commit is contained in:
Edward Shen 2021-02-08 12:33:24 -05:00
parent c446508829
commit 0bb86c79e3
Signed by: edward
GPG key ID: 19182661E818369F
12 changed files with 1365 additions and 44 deletions

1
.gitignore vendored
View file

@ -1 +1,2 @@
**/target **/target
**/.env

1015
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -7,3 +7,4 @@ edition = "2018"
[dependencies] [dependencies]
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
thiserror = "1" thiserror = "1"
uuid = { version = "0.8", features = ["serde", "v4"] }

View file

@ -1,3 +1,4 @@
pub mod net;
pub mod operations; pub mod operations;
pub mod stock; pub mod stock;

18
vtse-common/src/net.rs Normal file
View file

@ -0,0 +1,18 @@
use serde::Serialize;
use uuid::Uuid;
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ServerResponse {
/// Generic success
Success,
NewApiKey(Uuid),
UserError(UserError),
}
#[derive(Serialize)]
pub enum UserError {
InvalidPassword,
InvalidApiKey,
NotAuthorized,
}

View file

@ -5,13 +5,18 @@ authors = ["Edward Shen <code@eddie.sh>"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
vtse-common = { path = "../vtse-common" }
log = "0.4"
simple_logger = "1.11"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["net"] }
anyhow = "1" anyhow = "1"
uuid = { version = "0.8", features = ["v4"] } bytes = "1"
dotenv = "0.15"
log = "0.4"
rand = "0.8"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
bytes = "1" simple_logger = "1.11"
sodiumoxide = "0.2"
sqlx = { version = "0.5", features = [ "runtime-tokio-rustls", "macros", "uuid", "postgres" ] }
thiserror = "1"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["net"] }
uuid = { version = "0.8", features = ["v4"] }
vtse-common = { path = "../vtse-common" }

View file

@ -1,30 +1,47 @@
use crate::operations::ServerOperation; use crate::operations::ServerOperation;
use anyhow::anyhow; use anyhow::{bail, Result};
use bytes::BytesMut; use bytes::BytesMut;
use log::{error, info}; use log::{debug, error, info};
use tokio::io::AsyncReadExt; use operations::{MarketOperation, QueryOperation, UserOperation};
use sqlx::PgPool;
use state::AppState;
use std::env;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
mod market;
mod operations; mod operations;
mod state;
pub(crate) type Result<T> = std::result::Result<T, anyhow::Error>; mod user;
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
dotenv::dotenv().ok();
// If we can't successfully initialize our crypto library, fail immediately.
sodiumoxide::init().unwrap();
simple_logger::SimpleLogger::default().init()?; simple_logger::SimpleLogger::default().init()?;
let mut listener_stream = TcpListener::bind("localhost:8080") let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?)
.await .await
.map(TcpListenerStream::new)?; .map(TcpListenerStream::new)?;
info!("Successfully bound to port"); info!("Successfully bound to port.");
let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?;
info!("Successfully established connection to database.");
while let Some(Ok(stream)) = listener_stream.next().await { while let Some(Ok(stream)) = listener_stream.next().await {
// clone simply clones the arc reference, so this is cheap.
let db_pool = db_pool.clone();
tokio::task::spawn(async { tokio::task::spawn(async {
match handle_stream(stream).await { match handle_stream(stream, db_pool).await {
Ok(_) => (), Ok(_) => (),
Err(e) => error!("{}", e), Err(e) => {
// stream.write_all(e);
error!("{}", e);
}
} }
}); });
} }
@ -34,27 +51,44 @@ async fn main() -> Result<()> {
Ok(()) Ok(())
} }
async fn handle_stream(mut socket: TcpStream) -> Result<()> { async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
// only accept data that can fit in 256 bytes // only accept data that can fit in 256 bytes
// the assumption is that a single request should always be within n bytes,
// otherwise there's a good chance that it's mal{formed,icious}.
let mut buffer = BytesMut::with_capacity(256); let mut buffer = BytesMut::with_capacity(256);
let mut state = AppState::new();
loop { loop {
let bytes_read = socket.read_buf(&mut buffer).await?; let bytes_read = socket.read_buf(&mut buffer).await?;
match bytes_read { if bytes_read == 0 {
0 => return Err(anyhow!("Failed to read bytes, assuming socket is closed.")), bail!("Failed to read bytes, assuming socket is closed.")
n => {
let data = buffer.split_to(n); // O(1)
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
match parsed {
ServerOperation::Query(op) => {
dbg!(op);
}
ServerOperation::Meta(op) => {
dbg!(op);
}
}
buffer.unsplit(data); // O(1)
}
} }
let data = buffer.split_to(bytes_read); // O(1)
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
debug!("Parsed operation: {:?}", parsed);
let response = match parsed {
ServerOperation::Query(op) => match op {
QueryOperation::StockInfo { stock } => todo!(),
QueryOperation::User(_) => todo!(),
},
ServerOperation::User(op) => match op {
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
UserOperation::Register { username, password } => {
state.register(username, password, &pool).await?
}
UserOperation::GetKey { username, password } => {
state.generate_api_key(username, password, &pool).await?
}
},
ServerOperation::Market(op) => match op {
MarketOperation::Buy(stock_name) => state.buy(stock_name)?,
MarketOperation::Sell(stock_name) => state.sell(stock_name)?,
},
};
socket
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
.await?;
buffer.unsplit(data); // O(1)
buffer.clear(); buffer.clear();
} }
} }

View file

@ -0,0 +1,37 @@
use rand::rngs::ThreadRng;
use rand::thread_rng;
use rand::{distributions::Uniform, prelude::Distribution};
#[derive(Copy, Clone, PartialEq, PartialOrd, Debug)]
pub(crate) struct MarketModifier(f64);
#[derive(Clone, Debug)]
pub(crate) struct MarketGenerator {
rng: ThreadRng,
sample_range: Uniform<f64>,
volatility: f64,
old_price: MarketModifier,
}
impl MarketGenerator {
pub(crate) fn new(volatility: f64, initial_value: f64) -> Self {
Self {
rng: thread_rng(),
sample_range: Uniform::new_inclusive(-0.5, 0.5),
volatility,
old_price: MarketModifier(initial_value),
}
}
}
impl Iterator for MarketGenerator {
type Item = MarketModifier;
fn next(&mut self) -> Option<Self::Item> {
// Algorithm from https://stackoverflow.com/a/8597889
let rng_multiplier = self.sample_range.sample(&mut self.rng);
let delta = 2f64 * self.volatility * rng_multiplier;
self.old_price.0 += self.old_price.0 * delta;
Some(self.old_price)
}
}

View file

@ -0,0 +1 @@
mod generator;

View file

@ -1,23 +1,43 @@
use serde::Deserialize; use serde::Deserialize;
use vtse_common::stock::StockName; use vtse_common::stock::StockName;
use crate::user::{ApiKey, Password, Username};
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(untagged)] #[serde(untagged)]
pub(crate) enum ServerOperation { pub(crate) enum ServerOperation {
Query(QueryOperation), Query(QueryOperation),
Meta(MetaOperation), User(UserOperation),
Market(MarketOperation),
} }
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum QueryOperation { pub(crate) enum QueryOperation {
StockInfo { stock: StockName }, StockInfo { stock: StockName },
User(Username),
} }
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum MetaOperation { pub(crate) enum UserOperation {
Register { username: String, password: String }, Login {
api_key: ApiKey,
},
Register {
username: Username,
password: Password,
},
GetKey {
username: Username,
password: Password,
},
}
#[derive(Deserialize, Debug, PartialEq)]
pub(crate) enum MarketOperation {
Buy(StockName),
Sell(StockName),
} }
#[cfg(test)] #[cfg(test)]

167
vtse-server/src/state.rs Normal file
View file

@ -0,0 +1,167 @@
use crate::user::{ApiKey, Password, Username};
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
use sqlx::{query, PgPool};
use thiserror::Error;
use vtse_common::{
net::{ServerResponse, UserError},
stock::StockName,
};
#[derive(Error, Debug)]
pub(crate) enum StateError {
#[error("Was not in the correct state")]
WrongState,
#[error("Got SQLx error: {0}`")]
Database(#[from] sqlx::Error),
#[error("Failed to hash password")]
PasswordHash,
}
type OperationResult = Result<ServerResponse, StateError>;
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub(crate) enum AppState {
Unauthorized,
Authorized { user_id: i32 },
}
impl AppState {
pub(crate) fn new() -> Self {
Self::Unauthorized
}
fn assert_state(&self, expected_state: AppState) -> Result<(), StateError> {
if *self != expected_state {
Err(StateError::WrongState)
} else {
Ok(())
}
}
}
/// User operation implementation
impl AppState {
pub(crate) async fn login(&mut self, api_key: ApiKey, pool: &PgPool) -> OperationResult {
self.assert_state(AppState::Unauthorized)?;
let user_id = {
let query = query!(
"SELECT (user_id) FROM user_api_keys
JOIN api_keys ON user_api_keys.key_id = api_keys.key_id
WHERE key = $1",
api_key.0
)
.fetch_optional(pool)
.await?;
match query {
Some(id) => id.user_id,
None => return Ok(ServerResponse::UserError(UserError::InvalidApiKey)),
}
};
*self = AppState::Authorized { user_id };
Ok(ServerResponse::Success)
}
pub(crate) async fn register(
&mut self,
username: Username,
password: Password,
pool: &PgPool,
) -> OperationResult {
self.assert_state(AppState::Unauthorized)?;
query!(
"INSERT INTO users
(username, pwhash_data)
VALUES ($1::varchar, $2::bytea)",
username as Username,
password.salt().map_err(|_| StateError::PasswordHash)?.0,
)
.execute(pool)
.await?;
Ok(ServerResponse::Success)
}
pub(crate) async fn generate_api_key(
&self,
username: Username,
password: Password,
pool: &PgPool,
) -> OperationResult {
self.assert_state(AppState::Unauthorized)?;
match self.validate_password(&username, &password, pool).await {
Ok(None) => Ok(ServerResponse::UserError(UserError::InvalidPassword)),
Ok(Some(user_id)) => self.create_api_key(user_id, pool).await,
Err(_) => todo!(),
}
}
async fn create_api_key(&self, user_id: i32, pool: &PgPool) -> OperationResult {
let mut transaction = pool.begin().await?;
let api_key = uuid::Uuid::new_v4();
let api_key_id = query!(
"INSERT INTO api_keys (key)
VALUES ($1)
RETURNING api_keys.key_id;",
&api_key,
)
.fetch_one(&mut transaction)
.await?
.key_id;
query!(
"INSERT INTO user_api_keys (user_id, key_id)
VALUES ($1, $2)",
user_id,
api_key_id
)
.execute(&mut transaction)
.await?;
transaction.commit().await?;
Ok(ServerResponse::NewApiKey(api_key))
}
async fn validate_password(
&self,
username: &Username,
password: &Password,
pool: &PgPool,
) -> Result<Option<i32>, StateError> {
let result = query!(
"SELECT user_id, pwhash_data FROM users
WHERE username = $1::varchar",
username as &Username
)
.fetch_one(pool)
.await?;
let login_attempt = HashedPassword::from_slice(&result.pwhash_data)
.map(|pw| pwhash_verify(&pw, password.as_bytes()))
.unwrap();
if login_attempt {
Ok(Some(result.user_id))
} else {
Ok(None)
}
}
}
/// Market operation implementation
impl AppState {
pub(crate) fn buy(&self, stock_name: StockName) -> OperationResult {
// self.assert_state(AppState::Authorized)?;
todo!()
}
pub(crate) fn sell(&self, stock_name: StockName) -> OperationResult {
// self.assert_state(AppState::Authorized)?;
todo!()
}
}
impl Default for AppState {
fn default() -> Self {
Self::new()
}
}

View file

@ -1 +1,36 @@
pub(crate) struct ApiKey(String); use serde::Deserialize;
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE};
use uuid::Uuid;
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, sqlx::Type)]
#[sqlx(transparent)]
pub(crate) struct ApiKey(pub(crate) Uuid);
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, sqlx::Type)]
#[sqlx(transparent)]
pub(crate) struct Username(String);
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
pub(crate) struct Password(String);
#[derive(sqlx::Type)]
#[sqlx(transparent)]
pub(crate) struct SaltedPasswordData(pub(crate) Vec<u8>);
impl Password {
pub(crate) fn salt(self) -> Result<SaltedPasswordData, ()> {
Ok(SaltedPasswordData(
pwhash(
self.0.as_bytes(),
OPSLIMIT_INTERACTIVE,
MEMLIMIT_INTERACTIVE,
)?
.as_ref()
.to_vec(),
))
}
pub(crate) fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}
}