use crate::operations::ServerOperation; use anyhow::{bail, Result}; use bytes::BytesMut; use log::{debug, error, info}; use operations::{MarketOperation, QueryOperation, UserOperation}; use sqlx::PgPool; use state::AppState; use std::env; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; mod market; mod operations; mod state; mod user; #[tokio::main] async fn main() -> Result<()> { dotenv::dotenv().ok(); // If we can't successfully initialize our crypto library, fail immediately. sodiumoxide::init().unwrap(); simple_logger::SimpleLogger::default().init()?; let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?) .await .map(TcpListenerStream::new)?; info!("Successfully bound to port."); let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?; info!("Successfully established connection to database."); while let Some(Ok(stream)) = listener_stream.next().await { // clone simply clones the arc reference, so this is cheap. let db_pool = db_pool.clone(); tokio::task::spawn(async { match handle_stream(stream, db_pool).await { Ok(_) => (), Err(e) => { // stream.write_all(e); error!("{}", e); } } }); } info!("Cleanly shut down. Goodbye!"); Ok(()) } async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> { // only accept data that can fit in 256 bytes // the assumption is that a single request should always be within n bytes, // otherwise there's a good chance that it's mal{formed,icious}. let mut buffer = BytesMut::with_capacity(256); let mut state = AppState::new(); loop { let bytes_read = socket.read_buf(&mut buffer).await?; if bytes_read == 0 { bail!("Failed to read bytes, assuming socket is closed.") } let data = buffer.split_to(bytes_read); // O(1) let parsed = serde_json::from_slice::(&data)?; debug!("Parsed operation: {:?}", parsed); let response = match parsed { ServerOperation::Query(op) => match op { QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?, QueryOperation::User(username) => state.user_info(username, &pool).await?, }, ServerOperation::User(op) => match op { UserOperation::Login { api_key } => state.login(api_key, &pool).await?, UserOperation::Register { username, password } => { state.register(username, password, &pool).await? } UserOperation::GetKey { username, password } => { state.generate_api_key(username, password, &pool).await? } }, ServerOperation::Market(op) => match op { MarketOperation::Buy(stock_name) => state.buy(stock_name, &pool)?, MarketOperation::Sell(stock_name) => state.sell(stock_name, &pool)?, }, }; socket .write_all(serde_json::to_string(&response).unwrap().as_bytes()) .await?; buffer.unsplit(data); // O(1) buffer.clear(); } }