get user login working
This commit is contained in:
parent
c446508829
commit
0bb86c79e3
12 changed files with 1365 additions and 44 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1 +1,2 @@
|
|||
**/target
|
||||
**/.env
|
1015
Cargo.lock
generated
1015
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -7,3 +7,4 @@ edition = "2018"
|
|||
[dependencies]
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
thiserror = "1"
|
||||
uuid = { version = "0.8", features = ["serde", "v4"] }
|
|
@ -1,3 +1,4 @@
|
|||
pub mod net;
|
||||
pub mod operations;
|
||||
pub mod stock;
|
||||
|
||||
|
|
18
vtse-common/src/net.rs
Normal file
18
vtse-common/src/net.rs
Normal file
|
@ -0,0 +1,18 @@
|
|||
use serde::Serialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ServerResponse {
|
||||
/// Generic success
|
||||
Success,
|
||||
NewApiKey(Uuid),
|
||||
UserError(UserError),
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub enum UserError {
|
||||
InvalidPassword,
|
||||
InvalidApiKey,
|
||||
NotAuthorized,
|
||||
}
|
|
@ -5,13 +5,18 @@ authors = ["Edward Shen <code@eddie.sh>"]
|
|||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
vtse-common = { path = "../vtse-common" }
|
||||
log = "0.4"
|
||||
simple_logger = "1.11"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-stream = { version = "0.1", features = ["net"] }
|
||||
anyhow = "1"
|
||||
uuid = { version = "0.8", features = ["v4"] }
|
||||
bytes = "1"
|
||||
dotenv = "0.15"
|
||||
log = "0.4"
|
||||
rand = "0.8"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
bytes = "1"
|
||||
simple_logger = "1.11"
|
||||
sodiumoxide = "0.2"
|
||||
sqlx = { version = "0.5", features = [ "runtime-tokio-rustls", "macros", "uuid", "postgres" ] }
|
||||
thiserror = "1"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-stream = { version = "0.1", features = ["net"] }
|
||||
uuid = { version = "0.8", features = ["v4"] }
|
||||
vtse-common = { path = "../vtse-common" }
|
|
@ -1,30 +1,47 @@
|
|||
use crate::operations::ServerOperation;
|
||||
use anyhow::anyhow;
|
||||
use anyhow::{bail, Result};
|
||||
use bytes::BytesMut;
|
||||
use log::{error, info};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use log::{debug, error, info};
|
||||
use operations::{MarketOperation, QueryOperation, UserOperation};
|
||||
use sqlx::PgPool;
|
||||
use state::AppState;
|
||||
use std::env;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
|
||||
|
||||
mod market;
|
||||
mod operations;
|
||||
|
||||
pub(crate) type Result<T> = std::result::Result<T, anyhow::Error>;
|
||||
mod state;
|
||||
mod user;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
dotenv::dotenv().ok();
|
||||
// If we can't successfully initialize our crypto library, fail immediately.
|
||||
sodiumoxide::init().unwrap();
|
||||
simple_logger::SimpleLogger::default().init()?;
|
||||
|
||||
let mut listener_stream = TcpListener::bind("localhost:8080")
|
||||
let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?)
|
||||
.await
|
||||
.map(TcpListenerStream::new)?;
|
||||
|
||||
info!("Successfully bound to port");
|
||||
info!("Successfully bound to port.");
|
||||
|
||||
let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?;
|
||||
|
||||
info!("Successfully established connection to database.");
|
||||
|
||||
while let Some(Ok(stream)) = listener_stream.next().await {
|
||||
// clone simply clones the arc reference, so this is cheap.
|
||||
let db_pool = db_pool.clone();
|
||||
tokio::task::spawn(async {
|
||||
match handle_stream(stream).await {
|
||||
match handle_stream(stream, db_pool).await {
|
||||
Ok(_) => (),
|
||||
Err(e) => error!("{}", e),
|
||||
Err(e) => {
|
||||
// stream.write_all(e);
|
||||
error!("{}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -34,27 +51,44 @@ async fn main() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_stream(mut socket: TcpStream) -> Result<()> {
|
||||
async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
|
||||
// only accept data that can fit in 256 bytes
|
||||
// the assumption is that a single request should always be within n bytes,
|
||||
// otherwise there's a good chance that it's mal{formed,icious}.
|
||||
let mut buffer = BytesMut::with_capacity(256);
|
||||
let mut state = AppState::new();
|
||||
loop {
|
||||
let bytes_read = socket.read_buf(&mut buffer).await?;
|
||||
match bytes_read {
|
||||
0 => return Err(anyhow!("Failed to read bytes, assuming socket is closed.")),
|
||||
n => {
|
||||
let data = buffer.split_to(n); // O(1)
|
||||
if bytes_read == 0 {
|
||||
bail!("Failed to read bytes, assuming socket is closed.")
|
||||
}
|
||||
let data = buffer.split_to(bytes_read); // O(1)
|
||||
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
|
||||
match parsed {
|
||||
ServerOperation::Query(op) => {
|
||||
dbg!(op);
|
||||
}
|
||||
ServerOperation::Meta(op) => {
|
||||
dbg!(op);
|
||||
debug!("Parsed operation: {:?}", parsed);
|
||||
let response = match parsed {
|
||||
ServerOperation::Query(op) => match op {
|
||||
QueryOperation::StockInfo { stock } => todo!(),
|
||||
QueryOperation::User(_) => todo!(),
|
||||
},
|
||||
ServerOperation::User(op) => match op {
|
||||
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
|
||||
UserOperation::Register { username, password } => {
|
||||
state.register(username, password, &pool).await?
|
||||
}
|
||||
UserOperation::GetKey { username, password } => {
|
||||
state.generate_api_key(username, password, &pool).await?
|
||||
}
|
||||
},
|
||||
ServerOperation::Market(op) => match op {
|
||||
MarketOperation::Buy(stock_name) => state.buy(stock_name)?,
|
||||
MarketOperation::Sell(stock_name) => state.sell(stock_name)?,
|
||||
},
|
||||
};
|
||||
|
||||
socket
|
||||
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
|
||||
.await?;
|
||||
buffer.unsplit(data); // O(1)
|
||||
}
|
||||
}
|
||||
buffer.clear();
|
||||
}
|
||||
}
|
||||
|
|
37
vtse-server/src/market/generator.rs
Normal file
37
vtse-server/src/market/generator.rs
Normal file
|
@ -0,0 +1,37 @@
|
|||
use rand::rngs::ThreadRng;
|
||||
use rand::thread_rng;
|
||||
use rand::{distributions::Uniform, prelude::Distribution};
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, PartialOrd, Debug)]
|
||||
pub(crate) struct MarketModifier(f64);
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct MarketGenerator {
|
||||
rng: ThreadRng,
|
||||
sample_range: Uniform<f64>,
|
||||
volatility: f64,
|
||||
old_price: MarketModifier,
|
||||
}
|
||||
|
||||
impl MarketGenerator {
|
||||
pub(crate) fn new(volatility: f64, initial_value: f64) -> Self {
|
||||
Self {
|
||||
rng: thread_rng(),
|
||||
sample_range: Uniform::new_inclusive(-0.5, 0.5),
|
||||
volatility,
|
||||
old_price: MarketModifier(initial_value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for MarketGenerator {
|
||||
type Item = MarketModifier;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// Algorithm from https://stackoverflow.com/a/8597889
|
||||
let rng_multiplier = self.sample_range.sample(&mut self.rng);
|
||||
let delta = 2f64 * self.volatility * rng_multiplier;
|
||||
self.old_price.0 += self.old_price.0 * delta;
|
||||
Some(self.old_price)
|
||||
}
|
||||
}
|
1
vtse-server/src/market/mod.rs
Normal file
1
vtse-server/src/market/mod.rs
Normal file
|
@ -0,0 +1 @@
|
|||
mod generator;
|
|
@ -1,23 +1,43 @@
|
|||
use serde::Deserialize;
|
||||
use vtse_common::stock::StockName;
|
||||
|
||||
use crate::user::{ApiKey, Password, Username};
|
||||
|
||||
#[derive(Deserialize, Debug, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub(crate) enum ServerOperation {
|
||||
Query(QueryOperation),
|
||||
Meta(MetaOperation),
|
||||
User(UserOperation),
|
||||
Market(MarketOperation),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub(crate) enum QueryOperation {
|
||||
StockInfo { stock: StockName },
|
||||
User(Username),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub(crate) enum MetaOperation {
|
||||
Register { username: String, password: String },
|
||||
pub(crate) enum UserOperation {
|
||||
Login {
|
||||
api_key: ApiKey,
|
||||
},
|
||||
Register {
|
||||
username: Username,
|
||||
password: Password,
|
||||
},
|
||||
GetKey {
|
||||
username: Username,
|
||||
password: Password,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, PartialEq)]
|
||||
pub(crate) enum MarketOperation {
|
||||
Buy(StockName),
|
||||
Sell(StockName),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
167
vtse-server/src/state.rs
Normal file
167
vtse-server/src/state.rs
Normal file
|
@ -0,0 +1,167 @@
|
|||
use crate::user::{ApiKey, Password, Username};
|
||||
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash_verify, HashedPassword};
|
||||
use sqlx::{query, PgPool};
|
||||
use thiserror::Error;
|
||||
use vtse_common::{
|
||||
net::{ServerResponse, UserError},
|
||||
stock::StockName,
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub(crate) enum StateError {
|
||||
#[error("Was not in the correct state")]
|
||||
WrongState,
|
||||
#[error("Got SQLx error: {0}`")]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error("Failed to hash password")]
|
||||
PasswordHash,
|
||||
}
|
||||
|
||||
type OperationResult = Result<ServerResponse, StateError>;
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
|
||||
pub(crate) enum AppState {
|
||||
Unauthorized,
|
||||
Authorized { user_id: i32 },
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self::Unauthorized
|
||||
}
|
||||
|
||||
fn assert_state(&self, expected_state: AppState) -> Result<(), StateError> {
|
||||
if *self != expected_state {
|
||||
Err(StateError::WrongState)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// User operation implementation
|
||||
impl AppState {
|
||||
pub(crate) async fn login(&mut self, api_key: ApiKey, pool: &PgPool) -> OperationResult {
|
||||
self.assert_state(AppState::Unauthorized)?;
|
||||
let user_id = {
|
||||
let query = query!(
|
||||
"SELECT (user_id) FROM user_api_keys
|
||||
JOIN api_keys ON user_api_keys.key_id = api_keys.key_id
|
||||
WHERE key = $1",
|
||||
api_key.0
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
match query {
|
||||
Some(id) => id.user_id,
|
||||
None => return Ok(ServerResponse::UserError(UserError::InvalidApiKey)),
|
||||
}
|
||||
};
|
||||
*self = AppState::Authorized { user_id };
|
||||
Ok(ServerResponse::Success)
|
||||
}
|
||||
|
||||
pub(crate) async fn register(
|
||||
&mut self,
|
||||
username: Username,
|
||||
password: Password,
|
||||
pool: &PgPool,
|
||||
) -> OperationResult {
|
||||
self.assert_state(AppState::Unauthorized)?;
|
||||
query!(
|
||||
"INSERT INTO users
|
||||
(username, pwhash_data)
|
||||
VALUES ($1::varchar, $2::bytea)",
|
||||
username as Username,
|
||||
password.salt().map_err(|_| StateError::PasswordHash)?.0,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(ServerResponse::Success)
|
||||
}
|
||||
|
||||
pub(crate) async fn generate_api_key(
|
||||
&self,
|
||||
username: Username,
|
||||
password: Password,
|
||||
pool: &PgPool,
|
||||
) -> OperationResult {
|
||||
self.assert_state(AppState::Unauthorized)?;
|
||||
match self.validate_password(&username, &password, pool).await {
|
||||
Ok(None) => Ok(ServerResponse::UserError(UserError::InvalidPassword)),
|
||||
Ok(Some(user_id)) => self.create_api_key(user_id, pool).await,
|
||||
Err(_) => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_api_key(&self, user_id: i32, pool: &PgPool) -> OperationResult {
|
||||
let mut transaction = pool.begin().await?;
|
||||
let api_key = uuid::Uuid::new_v4();
|
||||
let api_key_id = query!(
|
||||
"INSERT INTO api_keys (key)
|
||||
VALUES ($1)
|
||||
RETURNING api_keys.key_id;",
|
||||
&api_key,
|
||||
)
|
||||
.fetch_one(&mut transaction)
|
||||
.await?
|
||||
.key_id;
|
||||
|
||||
query!(
|
||||
"INSERT INTO user_api_keys (user_id, key_id)
|
||||
VALUES ($1, $2)",
|
||||
user_id,
|
||||
api_key_id
|
||||
)
|
||||
.execute(&mut transaction)
|
||||
.await?;
|
||||
|
||||
transaction.commit().await?;
|
||||
|
||||
Ok(ServerResponse::NewApiKey(api_key))
|
||||
}
|
||||
|
||||
async fn validate_password(
|
||||
&self,
|
||||
username: &Username,
|
||||
password: &Password,
|
||||
pool: &PgPool,
|
||||
) -> Result<Option<i32>, StateError> {
|
||||
let result = query!(
|
||||
"SELECT user_id, pwhash_data FROM users
|
||||
WHERE username = $1::varchar",
|
||||
username as &Username
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let login_attempt = HashedPassword::from_slice(&result.pwhash_data)
|
||||
.map(|pw| pwhash_verify(&pw, password.as_bytes()))
|
||||
.unwrap();
|
||||
|
||||
if login_attempt {
|
||||
Ok(Some(result.user_id))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Market operation implementation
|
||||
impl AppState {
|
||||
pub(crate) fn buy(&self, stock_name: StockName) -> OperationResult {
|
||||
// self.assert_state(AppState::Authorized)?;
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub(crate) fn sell(&self, stock_name: StockName) -> OperationResult {
|
||||
// self.assert_state(AppState::Authorized)?;
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
|
@ -1 +1,36 @@
|
|||
pub(crate) struct ApiKey(String);
|
||||
use serde::Deserialize;
|
||||
use sodiumoxide::crypto::pwhash::argon2id13::{pwhash, MEMLIMIT_INTERACTIVE, OPSLIMIT_INTERACTIVE};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
pub(crate) struct ApiKey(pub(crate) Uuid);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize, sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
pub(crate) struct Username(String);
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, Deserialize)]
|
||||
pub(crate) struct Password(String);
|
||||
|
||||
#[derive(sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
pub(crate) struct SaltedPasswordData(pub(crate) Vec<u8>);
|
||||
|
||||
impl Password {
|
||||
pub(crate) fn salt(self) -> Result<SaltedPasswordData, ()> {
|
||||
Ok(SaltedPasswordData(
|
||||
pwhash(
|
||||
self.0.as_bytes(),
|
||||
OPSLIMIT_INTERACTIVE,
|
||||
MEMLIMIT_INTERACTIVE,
|
||||
)?
|
||||
.as_ref()
|
||||
.to_vec(),
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn as_bytes(&self) -> &[u8] {
|
||||
self.0.as_bytes()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue