94 lines
3.3 KiB
Rust
94 lines
3.3 KiB
Rust
use std::env;

use anyhow::{bail, Context, Result};
use bytes::BytesMut;
use log::{debug, error, info};
use sqlx::PgPool;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};

use crate::operations::ServerOperation;
use operations::{MarketOperation, QueryOperation, UserOperation};
use state::AppState;

mod market;
mod operations;
mod state;
mod user;
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
dotenv::dotenv().ok();
|
|
// If we can't successfully initialize our crypto library, fail immediately.
|
|
sodiumoxide::init().unwrap();
|
|
simple_logger::SimpleLogger::default().init()?;
|
|
|
|
let mut listener_stream = TcpListener::bind(&env::var("BIND_ADDRESS")?)
|
|
.await
|
|
.map(TcpListenerStream::new)?;
|
|
|
|
info!("Successfully bound to port.");
|
|
|
|
let db_pool = PgPool::connect(&env::var("DATABASE_URL")?).await?;
|
|
|
|
info!("Successfully established connection to database.");
|
|
|
|
while let Some(Ok(stream)) = listener_stream.next().await {
|
|
// clone simply clones the arc reference, so this is cheap.
|
|
let db_pool = db_pool.clone();
|
|
tokio::task::spawn(async {
|
|
match handle_stream(stream, db_pool).await {
|
|
Ok(_) => (),
|
|
Err(e) => {
|
|
// stream.write_all(e);
|
|
error!("{}", e);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
info!("Cleanly shut down. Goodbye!");
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_stream(mut socket: TcpStream, pool: PgPool) -> Result<()> {
|
|
// only accept data that can fit in 256 bytes
|
|
// the assumption is that a single request should always be within n bytes,
|
|
// otherwise there's a good chance that it's mal{formed,icious}.
|
|
let mut buffer = BytesMut::with_capacity(256);
|
|
let mut state = AppState::new();
|
|
loop {
|
|
let bytes_read = socket.read_buf(&mut buffer).await?;
|
|
if bytes_read == 0 {
|
|
bail!("Failed to read bytes, assuming socket is closed.")
|
|
}
|
|
let data = buffer.split_to(bytes_read); // O(1)
|
|
let parsed = serde_json::from_slice::<ServerOperation>(&data)?;
|
|
debug!("Parsed operation: {:?}", parsed);
|
|
let response = match parsed {
|
|
ServerOperation::Query(op) => match op {
|
|
QueryOperation::StockInfo { stock } => state.stock_info(stock, &pool).await?,
|
|
QueryOperation::User(username) => state.user_info(username, &pool).await?,
|
|
},
|
|
ServerOperation::User(op) => match op {
|
|
UserOperation::Login { api_key } => state.login(api_key, &pool).await?,
|
|
UserOperation::Register { username, password } => {
|
|
state.register(username, password, &pool).await?
|
|
}
|
|
UserOperation::GetKey { username, password } => {
|
|
state.generate_api_key(username, password, &pool).await?
|
|
}
|
|
},
|
|
ServerOperation::Market(op) => match op {
|
|
MarketOperation::Buy(stock_name) => state.buy(stock_name, &pool)?,
|
|
MarketOperation::Sell(stock_name) => state.sell(stock_name, &pool)?,
|
|
},
|
|
};
|
|
|
|
socket
|
|
.write_all(serde_json::to_string(&response).unwrap().as_bytes())
|
|
.await?;
|
|
buffer.unsplit(data); // O(1)
|
|
buffer.clear();
|
|
}
|
|
}
|