diff --git a/common/src/crypto.rs b/common/src/crypto.rs index e6fcdda..ba82812 100644 --- a/common/src/crypto.rs +++ b/common/src/crypto.rs @@ -67,6 +67,7 @@ impl Deref for Key { &self.0 } } + impl DerefMut for Key { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 diff --git a/server/src/main.rs b/server/src/main.rs index b9bc45b..4a66919 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -39,7 +39,7 @@ use rocksdb::{ColumnFamilyDescriptor, IteratorMode}; use rocksdb::{Options, DB}; use signal_hook::consts::SIGUSR1; use signal_hook_tokio::Signals; -use tokio::task; +use tokio::task::{self, JoinHandle}; use tower_http::services::ServeDir; use tracing::{error, instrument, trace}; use tracing::{info, warn}; @@ -76,7 +76,7 @@ async fn main() -> Result<()> { ], )?); - set_up_expirations(&db); + set_up_expirations::(&db); let signals = Signals::new(&[SIGUSR1])?; let signals_handle = signals.handle(); @@ -113,7 +113,7 @@ async fn main() -> Result<()> { // See https://link.eddie.sh/5JHlD #[allow(clippy::cognitive_complexity)] -fn set_up_expirations(db: &Arc) { +fn set_up_expirations(db: &Arc) { let mut corrupted = 0; let mut expired = 0; let mut pending = 0; @@ -124,23 +124,14 @@ fn set_up_expirations(db: &Arc) { let db_ref = Arc::clone(db); - let delete_entry = move |key: &[u8]| { - let blob_cf = db_ref.cf_handle(BLOB_CF_NAME).unwrap(); - let meta_cf = db_ref.cf_handle(META_CF_NAME).unwrap(); - if let Err(e) = db_ref.delete_cf(blob_cf, &key) { - warn!("{}", e); - } - if let Err(e) = db_ref.delete_cf(meta_cf, &key) { - warn!("{}", e); - } - }; - for (key, value) in db.iterator_cf(meta_cf, IteratorMode::Start) { + let key: [u8; N] = (*key).try_into().unwrap(); + let expiration = if let Ok(value) = bincode::deserialize::(&value) { value } else { corrupted += 1; - delete_entry(&key); + delete_entry(Arc::clone(&db_ref), key); continue; }; @@ -156,13 +147,13 @@ fn set_up_expirations(db: &Arc) { let sleep_duration = (expiration_time - Utc::now()).to_std().unwrap_or_default(); if sleep_duration == Duration::default() { expired += 1; - delete_entry(&key); + delete_entry(Arc::clone(&db_ref), key); } else { pending += 1; - let delete_entry_ref = delete_entry.clone(); - task::spawn_blocking(move || async move { + let db = Arc::clone(&db_ref); + task::spawn(async move { tokio::time::sleep(sleep_duration).await; - delete_entry_ref(&key); + delete_entry(db, key); }); } } @@ -270,17 +261,9 @@ async fn upload( { let sleep_duration = (expiration_time - Utc::now()).to_std().unwrap_or_default(); - - task::spawn_blocking(move || async move { + task::spawn(async move { tokio::time::sleep(sleep_duration).await; - let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap(); - let meta_cf = db.cf_handle(META_CF_NAME).unwrap(); - if let Err(e) = db.delete_cf(blob_cf, key) { - warn!("{}", e); - } - if let Err(e) = db.delete_cf(meta_cf, key) { - warn!("{}", e); - } + delete_entry(db, key); }); } } @@ -322,22 +305,10 @@ async fn paste( // Check if paste has expired. if let Expiration::UnixTime(expires) = metadata { if expires < Utc::now() { - task::spawn_blocking(move || { - let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap(); - let meta_cf = db.cf_handle(META_CF_NAME).unwrap(); - if let Err(e) = db.delete_cf(blob_cf, &key) { - warn!("{}", e); - } - if let Err(e) = db.delete_cf(meta_cf, &key) { - warn!("{}", e); - } - }) - .await - .map_err(|e| { + delete_entry(db, url.as_bytes()).await.map_err(|e| { error!("Failed to join handle: {}", e); StatusCode::INTERNAL_SERVER_ERROR - })?; - + })??; return Err(StatusCode::NOT_FOUND); } } @@ -366,31 +337,10 @@ async fn paste( metadata, Expiration::BurnAfterReading | Expiration::BurnAfterReadingWithDeadline(_) ) { - let join_handle = task::spawn_blocking(move || { - let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap(); - let meta_cf = db.cf_handle(META_CF_NAME).unwrap(); - if let Err(e) = db.delete_cf(blob_cf, url.as_bytes()) { - warn!("{}", e); - return Err(()); - } - - if let Err(e) = db.delete_cf(meta_cf, url.as_bytes()) { - warn!("{}", e); - return Err(()); - } - - Ok(()) - }) - .await - .map_err(|e| { + delete_entry(db, key).await.map_err(|e| { error!("Failed to join handle: {}", e); StatusCode::INTERNAL_SERVER_ERROR - })?; - - join_handle.map_err(|_| { - error!("Failed to burn paste after read"); - StatusCode::INTERNAL_SERVER_ERROR - })?; + })??; } let mut map = HeaderMap::new(); @@ -404,24 +354,24 @@ async fn delete( Extension(db): Extension>, Path(url): Path>, ) -> StatusCode { - match task::spawn_blocking(move || { - let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap(); - let meta_cf = db.cf_handle(META_CF_NAME).unwrap(); - if let Err(e) = db.delete_cf(blob_cf, url.as_bytes()) { - warn!("{}", e); - return Err(()); - } - - if let Err(e) = db.delete_cf(meta_cf, url.as_bytes()) { - warn!("{}", e); - return Err(()); - } - - Ok(()) - }) - .await - { + match delete_entry(db, url.as_bytes()).await { Ok(_) => StatusCode::OK, _ => StatusCode::INTERNAL_SERVER_ERROR, } } + +fn delete_entry(db: Arc, key: [u8; N]) -> JoinHandle> { + task::spawn_blocking(move || { + let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap(); + let meta_cf = db.cf_handle(META_CF_NAME).unwrap(); + if let Err(e) = db.delete_cf(blob_cf, &key) { + warn!("{}", e); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + if let Err(e) = db.delete_cf(meta_cf, &key) { + warn!("{}", e); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + Ok(()) + }) +}