// vtse/vtse-server/src/main.rs
//
// 125 lines
// 4.4 KiB
// Rust

use crate::operations::ServerOperation;
use anyhow::{bail, Result};
use args::Args;
use bytes::BytesMut;
use clap::Clap;
use log::{debug, error, info};
use market::Market;
use operations::{MarketOperation, QueryOperation, UserOperation};
use simple_logger::SimpleLogger;
use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool};
use state::AppState;
use std::{str::FromStr, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
mod args;
mod market;
mod operations;
mod state;
mod user;
#[tokio::main]
async fn main() -> Result<()> {
// Must init dotenv before initializing args, else args won't default from
// env properly
dotenv::dotenv().ok();
let args: Args = Args::parse();
// If we can't successfully initialize our crypto library, fail immediately.
sodiumoxide::init().expect("to initialize crypto library");
SimpleLogger::default().init()?;
let mut listener_stream = TcpListener::bind(&args.bind_address)
.await
.map(TcpListenerStream::new)?;
info!("Successfully bound to port.");
let db_pool = {
let mut connect_options = PgConnectOptions::from_str(&args.database_url)?;
connect_options.log_statements(args.db_log_level);
PgPool::connect_with(connect_options).await?
};
info!("Successfully established connection to database.");
let market = Arc::new(Market::new());
while let Some(Ok(stream)) = listener_stream.next().await {
// clone simply clones the arc reference, so this is cheap.
let db_pool = db_pool.clone();
let market = Arc::clone(&market);
tokio::task::spawn(async {
if let Err(e) = handle_stream(stream, db_pool, market).await {
error!("{}", e);
}
});
}
info!("Cleanly shut down. Goodbye!");
Ok(())
}
async fn handle_stream(mut socket: TcpStream, pool: PgPool, mut market: Arc<Market>) -> Result<()> {
// only accept data that can fit in 256 bytes
// the assumption is that a single request should always be within n bytes,
// otherwise there's a good chance that it's mal{formed,icious}.
let mut buffer = BytesMut::with_capacity(256);
let mut state = AppState::new();
loop {
let bytes_read = socket.read_buf(&mut buffer).await?;
if bytes_read == 0 {
bail!("Failed to read bytes, assuming socket is closed.")
}
let data = buffer.split_to(bytes_read); // O(1)
dbg!(&data);
let iter = serde_json::Deserializer::from_slice(&data).into_iter();
for messages in iter {
let message = match messages {
Ok(p) => p,
Err(e) if e.is_eof() || e.is_io() => return Ok(()),
Err(e) => return Err(e.into()),
};
debug!("Parsed operation: {:?}", message);
let response = match message {
ServerOperation::Query(op) => match op {
QueryOperation::StockInfo { stock } => {
AppState::stock_info(stock, &pool).await?
}
QueryOperation::User { username } => {
AppState::user_info(username, &pool).await?
}
},
ServerOperation::User(op) => match op {
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
UserOperation::Register { username, password } => {
state.register(username, password, &pool).await?
}
UserOperation::GetKey { username, password } => {
state.generate_api_key(username, password, &pool).await?
}
},
ServerOperation::Market(op) => match op {
MarketOperation::Buy { symbol, amount } => {
state.buy(symbol, amount, &pool, &mut market).await?
}
MarketOperation::Sell { symbol, amount } => {
state.sell(symbol, amount, &pool, &mut market).await?
}
},
};
socket
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
.await?;
}
buffer.unsplit(data); // O(1)
buffer.clear();
}
}