use crate::operations::ServerOperation; use anyhow::{bail, Result}; use args::Args; use bytes::BytesMut; use clap::Clap; use log::{debug, error, info}; use market::Market; use operations::{MarketOperation, QueryOperation, UserOperation}; use simple_logger::SimpleLogger; use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool}; use state::AppState; use std::{str::FromStr, sync::Arc}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; mod args; mod market; mod operations; mod state; mod user; #[tokio::main] async fn main() -> Result<()> { // Must init dotenv before initializing args, else args won't default from // env properly dotenv::dotenv().ok(); let args: Args = Args::parse(); // If we can't successfully initialize our crypto library, fail immediately. sodiumoxide::init().expect("to initialize crypto library"); SimpleLogger::default().init()?; let mut listener_stream = TcpListener::bind(&args.bind_address) .await .map(TcpListenerStream::new)?; info!("Successfully bound to port."); let db_pool = { let mut connect_options = PgConnectOptions::from_str(&args.database_url)?; connect_options.log_statements(args.db_log_level); PgPool::connect_with(connect_options).await? }; info!("Successfully established connection to database."); let market = Arc::new(Market::new()); while let Some(Ok(stream)) = listener_stream.next().await { // clone simply clones the arc reference, so this is cheap. let db_pool = db_pool.clone(); let market = Arc::clone(&market); tokio::task::spawn(async { if let Err(e) = handle_stream(stream, db_pool, market).await { error!("{}", e); } }); } info!("Cleanly shut down. Goodbye!"); Ok(()) } async fn handle_stream(mut socket: TcpStream, pool: PgPool, mut market: Arc) -> Result<()> { // only accept data that can fit in 256 bytes // the assumption is that a single request should always be within n bytes, // otherwise there's a good chance that it's mal{formed,icious}. let mut buffer = BytesMut::with_capacity(256); let mut state = AppState::new(); loop { let bytes_read = socket.read_buf(&mut buffer).await?; if bytes_read == 0 { bail!("Failed to read bytes, assuming socket is closed.") } let data = buffer.split_to(bytes_read); // O(1) dbg!(&data); let iter = serde_json::Deserializer::from_slice(&data).into_iter(); for messages in iter { let message = match messages { Ok(p) => p, Err(e) if e.is_eof() || e.is_io() => return Ok(()), Err(e) => return Err(e.into()), }; debug!("Parsed operation: {:?}", message); let response = match message { ServerOperation::Query(op) => match op { QueryOperation::StockInfo { stock } => { AppState::stock_info(stock, &pool).await? } QueryOperation::User { username } => { AppState::user_info(username, &pool).await? } }, ServerOperation::User(op) => match op { UserOperation::Login { api_key } => state.login(api_key, &pool).await?, UserOperation::Register { username, password } => { state.register(username, password, &pool).await? } UserOperation::GetKey { username, password } => { state.generate_api_key(username, password, &pool).await? } }, ServerOperation::Market(op) => match op { MarketOperation::Buy { symbol, amount } => { state.buy(symbol, amount, &pool, &mut market).await? } MarketOperation::Sell { symbol, amount } => { state.sell(symbol, amount, &pool, &mut market).await? } }, }; socket .write_all(serde_json::to_string(&response).unwrap().as_bytes()) .await?; } buffer.unsplit(data); // O(1) buffer.clear(); } }