get user login working
This commit is contained in:
parent
c446508829
commit
0bb86c79e3
12 changed files with 1365 additions and 44 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1 +1,2 @@
|
||||||
**/target
|
**/target
|
||||||
|
**/.env
|
1015
Cargo.lock
generated
1015
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -7,3 +7,4 @@ edition = "2018"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
|
uuid = { version = "0.8", features = ["serde", "v4"] }
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod net;
|
||||||
pub mod operations;
|
pub mod operations;
|
||||||
pub mod stock;
|
pub mod stock;
|
||||||
|
|
||||||
|
|
18
vtse-common/src/net.rs
Normal file
18
vtse-common/src/net.rs
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
use serde::Serialize;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ServerResponse {
|
||||||
|
/// Generic success
|
||||||
|
Success,
|
||||||
|
NewApiKey(Uuid),
|
||||||
|
UserError(UserError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub enum UserError {
|
||||||
|
InvalidPassword,
|
||||||
|
InvalidApiKey,
|
||||||
|
NotAuthorized,
|
||||||
|
}
|
|
@ -5,13 +5,18 @@ authors = ["Edward Shen <code@eddie.sh>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
vtse-common = { path = "../vtse-common" }
|
|
||||||
log = "0.4"
|
|
||||||
simple_logger = "1.11"
|
|
||||||
tokio = { version = "1", features = ["full"] }
|
|
||||||
tokio-stream = { version = "0.1", features = ["net"] }
|
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
uuid = { version = "0.8", features = ["v4"] }
|
bytes = "1"
|
||||||
|
dotenv = "0.15"
|
||||||
|
log = "0.4"
|
||||||
|
rand = "0.8"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
bytes = "1"
|
simple_logger = "1.11"
|
||||||
|
sodiumoxide = "0.2"
|
||||||
|
sqlx = { version = "0.5", features = [ "runtime-tokio-rustls", "macros", "uuid", "postgres" ] }
|
||||||
|
thiserror = "1"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
tokio-stream = { version = "0.1", features = ["net"] }
|
||||||
|
uuid = { version = "0.8", features = ["v4"] }
|
||||||
|
vtse-common = { path = "../vtse-common" }
|
|
@ -1,30 +1,47 @@
|
||||||
use crate::operations::ServerOperation;
|
use crate::operations::ServerOperation;
|
||||||
use anyhow::anyhow;
|
use anyhow::{bail, Result};
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
use log::{error, info};
|
use log::{debug, error, info};
|
||||||
use tokio::io::AsyncReadExt;
|
use operations::{MarketOperation, QueryOperation, UserOperation};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use state::AppState;
|
||||||
|
use std::env;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
|
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
|
||||||
|
|
||||||
|
mod market;
|
||||||
mod operations;
|
mod operations;
|
||||||
|
mod state;
|
||||||
pub(crate) type Result<T> = std::result::Result<T, anyhow::Error>;
|
mod user;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
|
dotenv::dotenv().ok();
|
||||||
|
// If we can't successfully initialize our crypto library, fail immediately.
|
||||||
|
sodiumoxide::init().unwrap();
|
||||||
simple_logger::SimpleLogger::default().init()?;
|
simple_logger::SimpleLogger::default().init()?;
|
||||||
|
|
||||||
let mut listener_stream = TcpListener::bind("localhost:8080")
|
let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?)
|
||||||
.await
|
.await
|
||||||
.map(TcpListenerStream::new)?;
|
.map(TcpListenerStream::new)?;
|
||||||
|
|
||||||
info!("Successfully bound to port");
|
info!("Successfully bound to port.");
|
||||||
|
|
||||||
|
let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?;
|
||||||
|
|
||||||
|
info!("Successfully established connection to database.");
|
||||||
|
|
||||||
while let Some(Ok(stream)) = listener_stream.next().await {
|
while let Some(Ok(stream)) = listener_stream.next().await {
|
||||||
|
// clone simply clones the arc reference, so this is cheap.
|
||||||
|
let db_pool = db_pool.clone();
|
||||||
tokio::task::spawn(async {
|
tokio::task::spawn(async {
|
||||||
match handle_stream(stream).await {
|
match handle_stream(stream, db_pool).await {
|
||||||
Ok(_) => (),
|
Ok(_) => (),
|
||||||
Err(e) => error!("{}", e),
|
Err(e) => {
|
||||||
|
// stream.write_all(e);
|
||||||
|
error!("{}", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -34,27 +51,44 @@ async fn main() -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_stream(mut socket: TcpStream) -> Result<()> {
|
async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
|
||||||
// only accept data that can fit in 256 bytes
|
// only accept data that can fit in 256 bytes
|
||||||
|
// the assumption is that a single request should always be within n bytes,
|
||||||
|
// otherwise there's a good chance that it's mal{formed,icious}.
|
||||||
let mut buffer = BytesMut::with_capacity(256);
|
let mut buffer = BytesMut::with_capacity(256);
|
||||||
|
let mut state = AppState::new();
|
||||||
loop {
|
loop {
|
||||||
let bytes_read = socket.read_buf(&mut buffer).await?;
|
let bytes_read = socket.read_buf(&mut buffer).await?;
|
||||||
match bytes_read {
|
if bytes_read == 0 {
|
||||||
0 => return Err(anyhow!("Failed to read bytes, assuming socket is closed.")),
|
bail!("Failed to read bytes, assuming socket is closed.")
|
||||||
n => {
|
|
||||||
let data = buffer.split_to(n); // O(1)
|
|
||||||
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
|
|
||||||
match parsed {
|
|
||||||
ServerOperation::Query(op) => {
|
|
||||||
dbg!(op);
|
|
||||||
}
|
|
||||||
ServerOperation::Meta(op) => {
|
|
||||||
dbg!(op);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buffer.unsplit(data); // O(1)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
let data = buffer.split_to(bytes_read); // O(1)
|
||||||
|
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
|
||||||
|
debug!("Parsed operation: {:?}", parsed);
|
||||||
|
let response = match parsed {
|
||||||
|
ServerOperation::Query(op) => match op {
|
||||||
|
QueryOperation::StockInfo { stock } => todo!(),
|
||||||
|
QueryOperation::User(_) => todo!(),
|
||||||
|
},
|
||||||
|
ServerOperation::User(op) => match op {
|
||||||
|
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
|
||||||
|
UserOperation::Register { username, password } => {
|
||||||
|
state.register(username, password, &pool).await?
|
||||||
|
}
|
||||||
|
UserOperation::GetKey { username, password } => {
|
||||||
|
state.generate_api_key(username, password, &pool).await?
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ServerOperation::Market(op) => match op {
|
||||||
|
MarketOperation::Buy(stock_name) => state.buy(stock_name)?,
|
||||||
|
MarketOperation::Sell(stock_name) => state.sell(stock_name)?,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
socket
|
||||||
|
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
|
||||||
|
.await?;
|
||||||
|
buffer.unsplit(data); // O(1)
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
37
vtse-server/src/market/generator.rs
Normal file
37
vtse-server/src/market/generator.rs
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
use rand::rngs::ThreadRng;
|
||||||
|
use rand::thread_rng;
|
||||||
|
use rand::{distributions::Uniform, prelude::Distribution};
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, PartialEq, PartialOrd, Debug)]
|
||||||
|
pub(crate) struct MarketModifier(f64);
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub(crate) struct MarketGenerator {
|
||||||
|
rng: ThreadRng,
|
||||||
|
sample_range: Uniform<f64>,
|
||||||
|
volatility: f64,
|
||||||
|
old_price: MarketModifier,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MarketGenerator {
|
||||||
|
pub(crate) fn new(volatility: f64, initial_value: f64) -> Self {
|
||||||
|
Self {
|
||||||
|
rng: thread_rng(),
|
||||||
|
sample_range: Uniform::new_inclusive(-0.5, 0.5),
|
||||||
|
volatility,
|
||||||
|
old_price: MarketModifier(initial_value),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for MarketGenerator {
|
||||||
|
type Item = MarketModifier;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
// Algorithm from https://stackoverflow.com/a/8597889
|
||||||
|
let rng_multiplier = self.sample_range.sample(&mut self.rng);
|
||||||
|
let delta = 2f64 * self.volatility * rng_multiplier;
|
||||||
|
self.old_price.0 += self.old_price.0 * delta;
|
||||||
|
Some(self.old_price)
|
||||||
|
}
|
||||||
|
}
|
1
vtse-server/src/market/mod.rs
Normal file
1
vtse-server/src/market/mod.rs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
mod generator;
|
|
@ -1,23 +1,43 @@
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use vtse_common::stock::StockName;
|
use vtse_common::stock::StockName;
|
||||||
|
|
||||||
|
use crate::user::{ApiKey, Password, Username};
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, PartialEq)]
|
#[derive(Deserialize, Debug, PartialEq)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub(crate) enum ServerOperation {
|
pub(crate) enum ServerOperation {
|
||||||
Query(QueryOperation),
|
Query(QueryOperation),
|
||||||
Meta(MetaOperation),
|
User(UserOperation),
|
||||||
|
Market(MarketOperation),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, PartialEq)]
|
#[derive(Deserialize, Debug, PartialEq)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
pub(crate) enum QueryOperation {
|
pub(crate) enum QueryOperation {
|
||||||
StockInfo { stock: StockName },
|
StockInfo { stock: StockName },
|
||||||
|
User(Username),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, PartialEq)]
|
#[derive(Deserialize, Debug, PartialEq)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
pub(crate) enum MetaOperation {
|
pub(crate) enum UserOperation {
|
||||||
Register { username: String, password: String },
|
Login {
|
||||||
|
api_key: ApiKey,
|
||||||
|
},
|
||||||
|
Register {
|
||||||
|
username: Username,
|
||||||
|
password: Password,
|
||||||
|
},
|
||||||
|
GetKey {
|
||||||
|
username: Username,
|
||||||
|
password: Password,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, PartialEq)]
|
||||||
|
pub(crate) enum MarketOperation {
|
||||||
|
Buy(StockName),
|
||||||
|
Sell(StockName),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
167
vtse-server/src/state.rs
Normal file
167
vtse-server/src/state.rs
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
use crate::user::{ApiKey, Password, Username};
|
||||||
|
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
|
||||||
|
use sqlx::{query, PgPool};
|
||||||
|
use thiserror::Error;
|
||||||
|
use vtse_common::{
|
||||||
|
net::{ServerResponse, UserError},
|
||||||
|
stock::StockName,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub(crate) enum StateError {
|
||||||
|
#[error("Was not in the correct state")]
|
||||||
|
WrongState,
|
||||||
|
#[error("Got SQLx error: {0}`")]
|
||||||
|
Database(#[from] sqlx::Error),
|
||||||
|
#[error("Failed to hash password")]
|
||||||
|
PasswordHash,
|
||||||
|
}
|
||||||
|
|
||||||
|
type OperationResult = Result<ServerResponse, StateError>;
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
|
||||||
|
pub(crate) enum AppState {
|
||||||
|
Unauthorized,
|
||||||
|
Authorized { user_id: i32 },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
Self::Unauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assert_state(&self, expected_state: AppState) -> Result<(), StateError> {
|
||||||
|
if *self != expected_state {
|
||||||
|
Err(StateError::WrongState)
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// User operation implementation
|
||||||
|
impl AppState {
|
||||||
|
pub(crate) async fn login(&mut self, api_key: ApiKey, pool: &PgPool) -> OperationResult {
|
||||||
|
self.assert_state(AppState::Unauthorized)?;
|
||||||
|
let user_id = {
|
||||||
|
let query = query!(
|
||||||
|
"SELECT (user_id) FROM user_api_keys
|
||||||
|
JOIN api_keys ON user_api_keys.key_id = api_keys.key_id
|
||||||
|
WHERE key = $1",
|
||||||
|
api_key.0
|
||||||
|
)
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await?;
|
||||||
|
match query {
|
||||||
|
Some(id) => id.user_id,
|
||||||
|
None => return Ok(ServerResponse::UserError(UserError::InvalidApiKey)),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
*self = AppState::Authorized { user_id };
|
||||||
|
Ok(ServerResponse::Success)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn register(
|
||||||
|
&mut self,
|
||||||
|
username: Username,
|
||||||
|
password: Password,
|
||||||
|
pool: &PgPool,
|
||||||
|
) -> OperationResult {
|
||||||
|
self.assert_state(AppState::Unauthorized)?;
|
||||||
|
query!(
|
||||||
|
"INSERT INTO users
|
||||||
|
(username, pwhash_data)
|
||||||
|
VALUES ($1::varchar, $2::bytea)",
|
||||||
|
username as Username,
|
||||||
|
password.salt().map_err(|_| StateError::PasswordHash)?.0,
|
||||||
|
)
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
Ok(ServerResponse::Success)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn generate_api_key(
|
||||||
|
&self,
|
||||||
|
username: Username,
|
||||||
|
password: Password,
|
||||||
|
pool: &PgPool,
|
||||||
|
) -> OperationResult {
|
||||||
|
self.assert_state(AppState::Unauthorized)?;
|
||||||
|
match self.validate_password(&username, &password, pool).await {
|
||||||
|
Ok(None) => Ok(ServerResponse::UserError(UserError::InvalidPassword)),
|
||||||
|
Ok(Some(user_id)) => self.create_api_key(user_id, pool).await,
|
||||||
|
Err(_) => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_api_key(&self, user_id: i32, pool: &PgPool) -> OperationResult {
|
||||||
|
let mut transaction = pool.begin().await?;
|
||||||
|
let api_key = uuid::Uuid::new_v4();
|
||||||
|
let api_key_id = query!(
|
||||||
|
"INSERT INTO api_keys (key)
|
||||||
|
VALUES ($1)
|
||||||
|
RETURNING api_keys.key_id;",
|
||||||
|
&api_key,
|
||||||
|
)
|
||||||
|
.fetch_one(&mut transaction)
|
||||||
|
.await?
|
||||||
|
.key_id;
|
||||||
|
|
||||||
|
query!(
|
||||||
|
"INSERT INTO user_api_keys (user_id, key_id)
|
||||||
|
VALUES ($1, $2)",
|
||||||
|
user_id,
|
||||||
|
api_key_id
|
||||||
|
)
|
||||||
|
.execute(&mut transaction)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
transaction.commit().await?;
|
||||||
|
|
||||||
|
Ok(ServerResponse::NewApiKey(api_key))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn validate_password(
|
||||||
|
&self,
|
||||||
|
username: &Username,
|
||||||
|
password: &Password,
|
||||||
|
pool: &PgPool,
|
||||||
|
) -> Result<Option<i32>, StateError> {
|
||||||
|
let result = query!(
|
||||||
|
"SELECT user_id, pwhash_data FROM users
|
||||||
|
WHERE username = $1::varchar",
|
||||||
|
username as &Username
|
||||||
|
)
|
||||||
|
.fetch_one(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let login_attempt = HashedPassword::from_slice(&result.pwhash_data)
|
||||||
|
.map(|pw| pwhash_verify(&pw, password.as_bytes()))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
if login_attempt {
|
||||||
|
Ok(Some(result.user_id))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Market operation implementation
|
||||||
|
impl AppState {
|
||||||
|
pub(crate) fn buy(&self, stock_name: StockName) -> OperationResult {
|
||||||
|
// self.assert_state(AppState::Authorized)?;
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn sell(&self, stock_name: StockName) -> OperationResult {
|
||||||
|
// self.assert_state(AppState::Authorized)?;
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AppState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1 +1,36 @@
|
||||||
pub(crate) struct ApiKey(String);
|
use serde::Deserialize;
|
||||||
|
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, sqlx::Type)]
|
||||||
|
#[sqlx(transparent)]
|
||||||
|
pub(crate) struct ApiKey(pub(crate) Uuid);
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, sqlx::Type)]
|
||||||
|
#[sqlx(transparent)]
|
||||||
|
pub(crate) struct Username(String);
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
|
||||||
|
pub(crate) struct Password(String);
|
||||||
|
|
||||||
|
#[derive(sqlx::Type)]
|
||||||
|
#[sqlx(transparent)]
|
||||||
|
pub(crate) struct SaltedPasswordData(pub(crate) Vec<u8>);
|
||||||
|
|
||||||
|
impl Password {
|
||||||
|
pub(crate) fn salt(self) -> Result<SaltedPasswordData, ()> {
|
||||||
|
Ok(SaltedPasswordData(
|
||||||
|
pwhash(
|
||||||
|
self.0.as_bytes(),
|
||||||
|
OPSLIMIT_INTERACTIVE,
|
||||||
|
MEMLIMIT_INTERACTIVE,
|
||||||
|
)?
|
||||||
|
.as_ref()
|
||||||
|
.to_vec(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn as_bytes(&self) -> &[u8] {
|
||||||
|
self.0.as_bytes()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue