125 lines
4.4 KiB
Rust
125 lines
4.4 KiB
Rust
use crate::operations::ServerOperation;
|
|
use anyhow::{bail, Result};
|
|
use args::Args;
|
|
use bytes::BytesMut;
|
|
use clap::Clap;
|
|
use log::{debug, error, info};
|
|
use market::Market;
|
|
use operations::{MarketOperation, QueryOperation, UserOperation};
|
|
use simple_logger::SimpleLogger;
|
|
use sqlx::{postgres::PgConnectOptions, ConnectOptions, PgPool};
|
|
use state::AppState;
|
|
use std::{str::FromStr, sync::Arc};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
|
|
|
|
mod args;
|
|
mod market;
|
|
mod operations;
|
|
mod state;
|
|
mod user;
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
// Must init dotenv before initializing args, else args won't default from
|
|
// env properly
|
|
dotenv::dotenv().ok();
|
|
let args: Args = Args::parse();
|
|
|
|
// If we can't successfully initialize our crypto library, fail immediately.
|
|
sodiumoxide::init().expect("to initialize crypto library");
|
|
SimpleLogger::default().init()?;
|
|
|
|
let mut listener_stream = TcpListener::bind(&args.bind_address)
|
|
.await
|
|
.map(TcpListenerStream::new)?;
|
|
info!("Successfully bound to port.");
|
|
|
|
let db_pool = {
|
|
let mut connect_options = PgConnectOptions::from_str(&args.database_url)?;
|
|
connect_options.log_statements(args.db_log_level);
|
|
PgPool::connect_with(connect_options).await?
|
|
};
|
|
info!("Successfully established connection to database.");
|
|
|
|
let market = Arc::new(Market::new());
|
|
|
|
while let Some(Ok(stream)) = listener_stream.next().await {
|
|
// clone simply clones the arc reference, so this is cheap.
|
|
let db_pool = db_pool.clone();
|
|
let market = Arc::clone(&market);
|
|
tokio::task::spawn(async {
|
|
if let Err(e) = handle_stream(stream, db_pool, market).await {
|
|
error!("{}", e);
|
|
}
|
|
});
|
|
}
|
|
|
|
info!("Cleanly shut down. Goodbye!");
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_stream(mut socket: TcpStream, pool: PgPool, mut market: Arc<Market>) -> Result<()> {
|
|
// only accept data that can fit in 256 bytes
|
|
// the assumption is that a single request should always be within n bytes,
|
|
// otherwise there's a good chance that it's mal{formed,icious}.
|
|
let mut buffer = BytesMut::with_capacity(256);
|
|
let mut state = AppState::new();
|
|
loop {
|
|
let bytes_read = socket.read_buf(&mut buffer).await?;
|
|
if bytes_read == 0 {
|
|
bail!("Failed to read bytes, assuming socket is closed.")
|
|
}
|
|
|
|
let data = buffer.split_to(bytes_read); // O(1)
|
|
dbg!(&data);
|
|
let iter = serde_json::Deserializer::from_slice(&data).into_iter();
|
|
|
|
for messages in iter {
|
|
let message = match messages {
|
|
Ok(p) => p,
|
|
Err(e) if e.is_eof() || e.is_io() => return Ok(()),
|
|
Err(e) => return Err(e.into()),
|
|
};
|
|
|
|
debug!("Parsed operation: {:?}", message);
|
|
|
|
let response = match message {
|
|
ServerOperation::Query(op) => match op {
|
|
QueryOperation::StockInfo { stock } => {
|
|
AppState::stock_info(stock, &pool).await?
|
|
}
|
|
QueryOperation::User { username } => {
|
|
AppState::user_info(username, &pool).await?
|
|
}
|
|
},
|
|
ServerOperation::User(op) => match op {
|
|
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
|
|
UserOperation::Register { username, password } => {
|
|
state.register(username, password, &pool).await?
|
|
}
|
|
UserOperation::GetKey { username, password } => {
|
|
state.generate_api_key(username, password, &pool).await?
|
|
}
|
|
},
|
|
ServerOperation::Market(op) => match op {
|
|
MarketOperation::Buy { symbol, amount } => {
|
|
state.buy(symbol, amount, &pool, &mut market).await?
|
|
}
|
|
MarketOperation::Sell { symbol, amount } => {
|
|
state.sell(symbol, amount, &pool, &mut market).await?
|
|
}
|
|
},
|
|
};
|
|
socket
|
|
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
|
|
.await?;
|
|
}
|
|
|
|
buffer.unsplit(data); // O(1)
|
|
buffer.clear();
|
|
}
|
|
}
|