#![warn(clippy::nursery, clippy::pedantic)] use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; use axum::http::StatusCode; use paste::Expiration; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; use rocksdb::IteratorMode; use rocksdb::WriteBatch; use rocksdb::{Options, DB}; use short_code::ShortCode; use anyhow::Result; use axum::body::Bytes; use axum::extract::{Extension, Path, TypedHeader}; use axum::handler::{get, post}; use axum::{AddExtensionLayer, Router}; use tokio::task; use tracing::warn; use tracing::{error, instrument}; use crate::paste::Paste; use crate::time::FIVE_MINUTES; mod paste; mod short_code; mod time; #[tokio::main] async fn main() -> Result<()> { const DB_PATH: &str = "database"; const SHORT_CODE_SIZE: usize = 12; tracing_subscriber::fmt::init(); let db = Arc::new(DB::open_default(DB_PATH)?); let stop_signal = Arc::new(AtomicBool::new(false)); task::spawn(cleanup(Arc::clone(&stop_signal), Arc::clone(&db))); axum::Server::bind(&"0.0.0.0:8080".parse()?) .serve( Router::new() .route("/", post(upload::)) .route( "/:code", get(paste::).delete(delete::), ) .layer(AddExtensionLayer::new(db)) .layer(AddExtensionLayer::new(StdRng::from_entropy())) .into_make_service(), ) .await?; stop_signal.store(true, Ordering::Release); // Must be called for correct shutdown DB::destroy(&Options::default(), DB_PATH)?; Ok(()) } #[instrument(skip(db), err)] async fn upload( Extension(db): Extension>, Extension(mut rng): Extension, maybe_expires: Option>, body: Bytes, ) -> Result, StatusCode> { if body.is_empty() { return Err(StatusCode::BAD_REQUEST); } // 3GB max; this is a soft-limit of RocksDb if body.len() >= 3_221_225_472 { return Err(StatusCode::PAYLOAD_TOO_LARGE); } let paste = Paste::new(maybe_expires.map(|v| v.0).unwrap_or_default(), body); let mut new_key = None; // Try finding a code; give up after 1000 attempts // Statistics show that this is very unlikely to happen for _ in 0..1000 { let code: ShortCode = rng.sample(short_code::Generator); let db = Arc::clone(&db); let key = code.as_bytes(); let query = task::spawn_blocking(move || db.key_may_exist(key)).await; if matches!(query, Ok(false)) { new_key = Some(key); } } let key = if let Some(key) = new_key { key } else { error!("Failed to generate a valid shortcode"); return Err(StatusCode::INTERNAL_SERVER_ERROR); }; let value = if let Ok(v) = bincode::serialize(&paste) { v } else { error!("Failed to serialize paste?!"); return Err(StatusCode::INTERNAL_SERVER_ERROR); }; match task::spawn_blocking(move || db.put(key, value)).await { Ok(Ok(_)) => (), e => { error!("Failed to insert paste into db: {:?}", e); return Err(StatusCode::INTERNAL_SERVER_ERROR); } } Ok(Vec::from(key)) } #[instrument(skip(db), err)] async fn paste( Extension(db): Extension>, Path(url): Path>, ) -> Result { let key = url.as_bytes(); let parsed: Paste = { // not sure if perf of get_pinned is better than spawn_blocking let query_result = db.get_pinned(key).map_err(|e| { error!("Failed to fetch initial query: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; let data = match query_result { Some(data) => data, None => return Err(StatusCode::NOT_FOUND), }; bincode::deserialize(&data).map_err(|_| { error!("Failed to deserialize data?!"); StatusCode::INTERNAL_SERVER_ERROR })? }; if parsed.expired() { let join_handle = task::spawn_blocking(move || db.delete(key)) .await .map_err(|e| { error!("Failed to join handle: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; join_handle.map_err(|e| { error!("Failed to delete expired paste: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; return Err(StatusCode::NOT_FOUND); } if parsed.is_burn_after_read() { let join_handle = task::spawn_blocking(move || db.delete(key)) .await .map_err(|e| { error!("Failed to join handle: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; join_handle.map_err(|e| { error!("Failed to burn paste after read: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; } Ok(parsed.bytes) } #[instrument(skip(db))] async fn delete( Extension(db): Extension>, Path(url): Path>, ) -> StatusCode { match task::spawn_blocking(move || db.delete(url.as_bytes())).await { Ok(Ok(_)) => StatusCode::OK, _ => StatusCode::INTERNAL_SERVER_ERROR, } } /// Periodic clean-up task that deletes expired entries. async fn cleanup(stop_signal: Arc, db: Arc) { while !stop_signal.load(Ordering::Acquire) { tokio::time::sleep(*FIVE_MINUTES).await; let mut batch = WriteBatch::default(); for (key, value) in db.snapshot().iterator(IteratorMode::Start) { // TODO: only partially decode struct for max perf let join_handle = task::spawn_blocking(move || { bincode::deserialize::(&value) .as_ref() .map(Paste::expired) .unwrap_or_default() }) .await; let should_delete = match join_handle { Ok(should_delete) => should_delete, Err(e) => { error!("Failed to join thread?! {}", e); false } }; if should_delete { batch.delete(key); } } let db = Arc::clone(&db); let join_handle = task::spawn_blocking(move || db.write(batch)).await; let db_op_res = match join_handle { Ok(res) => res, Err(e) => { error!("Failed to join handle?! {}", e); continue; } }; if let Err(e) = db_op_res { warn!("Failed to cleanup db: {}", e); } } }