diff --git a/Cargo.lock b/Cargo.lock index efa9c21..29d1db4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -947,6 +947,7 @@ dependencies = [ "bytes", "chrono", "headers", + "lazy_static", "omegaupload-common", "rand", "rocksdb", diff --git a/cli/src/main.rs b/cli/src/main.rs index 4df629c..c8b1092 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -7,7 +7,9 @@ use anyhow::{anyhow, bail, Context, Result}; use atty::Stream; use clap::Parser; use omegaupload_common::crypto::{gen_key_nonce, open_in_place, seal_in_place, Key}; -use omegaupload_common::{base64, hash, Expiration, ParsedUrl, Url, API_ENDPOINT}; +use omegaupload_common::{ + base64, hash, Expiration, ParsedUrl, Url, API_ENDPOINT, EXPIRATION_HEADER_NAME, +}; use reqwest::blocking::Client; use reqwest::header::EXPIRES; use reqwest::StatusCode; @@ -28,6 +30,8 @@ enum Action { /// public access. #[clap(short, long)] password: Option, + #[clap(short, long)] + duration: Option, }, Download { /// The paste to download. @@ -39,14 +43,22 @@ fn main() -> Result<()> { let opts = Opts::parse(); match opts.action { - Action::Upload { url, password } => handle_upload(url, password), + Action::Upload { + url, + password, + duration, + } => handle_upload(url, password, duration), Action::Download { url } => handle_download(url), }?; Ok(()) } -fn handle_upload(mut url: Url, password: Option) -> Result<()> { +fn handle_upload( + mut url: Url, + password: Option, + duration: Option, +) -> Result<()> { url.set_fragment(None); if atty::is(Stream::Stdin) { @@ -76,11 +88,13 @@ fn handle_upload(mut url: Url, password: Option) -> Result<()> { (container, nonce, key, pw_used) }; - let res = Client::new() - .post(url.as_ref()) - .body(data) - .send() - .context("Request to server failed")?; + let mut res = Client::new().post(url.as_ref()); + + if let Some(duration) = duration { + res = res.header(&*EXPIRATION_HEADER_NAME, duration); + } + + let res = res.body(data).send().context("Request to server failed")?; if res.status() != StatusCode::OK { bail!("Upload failed. Got HTTP error {}", res.status()); @@ -104,11 +118,8 @@ fn handle_upload(mut url: Url, password: Option) -> Result<()> { } fn handle_download(mut url: ParsedUrl) -> Result<()> { - url.sanitized_url.set_path(&dbg!(format!( - "{}{}", - API_ENDPOINT, - url.sanitized_url.path() - ))); + url.sanitized_url + .set_path(&format!("{}{}", API_ENDPOINT, url.sanitized_url.path())); let res = Client::new() .get(url.sanitized_url) .send() diff --git a/common/src/lib.rs b/common/src/lib.rs index 965b956..36ed24c 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -224,13 +224,31 @@ impl FromStr for ParsedUrl { #[derive(Serialize, Deserialize, Clone, Copy, Debug)] pub enum Expiration { BurnAfterReading, + BurnAfterReadingWithDeadline(DateTime), UnixTime(DateTime), } +// This impl is used for the CLI +impl FromStr for Expiration { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "read" => Ok(Self::BurnAfterReading), + "5m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(5))), + "10m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(10))), + "1h" => Ok(Self::UnixTime(Utc::now() + Duration::hours(1))), + "1d" => Ok(Self::UnixTime(Utc::now() + Duration::days(1))), + // We disallow permanent pastes + _ => Err(s.to_owned()), + } + } +} + impl Display for Expiration { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Expiration::BurnAfterReading => { + Expiration::BurnAfterReading | Expiration::BurnAfterReadingWithDeadline(_) => { write!(f, "This item has been burned. You now have the only copy.") } Expiration::UnixTime(time) => write!( @@ -256,19 +274,9 @@ impl Header for Expiration { Self: Sized, I: Iterator, { - match values - .next() - .ok_or_else(headers::Error::invalid)? - .as_bytes() - { - b"read" => Ok(Self::BurnAfterReading), - b"5m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(5))), - b"10m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(10))), - b"1h" => Ok(Self::UnixTime(Utc::now() + Duration::hours(1))), - b"1d" => Ok(Self::UnixTime(Utc::now() + Duration::days(1))), - // We disallow permanent pastes - _ => Err(headers::Error::invalid()), - } + let bytes = values.next().ok_or_else(headers::Error::invalid)?; + + Self::try_from(bytes).map_err(|_| headers::Error::invalid()) } fn encode>(&self, container: &mut E) { @@ -282,7 +290,9 @@ impl From<&Expiration> for HeaderValue { // so we don't need the extra check. unsafe { Self::from_maybe_shared_unchecked(match expiration { - Expiration::BurnAfterReading => Bytes::from_static(b"0"), + Expiration::BurnAfterReadingWithDeadline(_) | Expiration::BurnAfterReading => { + Bytes::from_static(b"0") + } Expiration::UnixTime(duration) => Bytes::from(duration.to_rfc3339()), }) } @@ -295,6 +305,8 @@ impl From for HeaderValue { } } +pub struct ParseHeaderValueError; + #[cfg(feature = "wasm")] impl TryFrom for Expiration { type Error = ParseHeaderValueError; @@ -310,14 +322,19 @@ impl TryFrom for Expiration { } } -pub struct ParseHeaderValueError; +impl TryFrom for Expiration { + type Error = ParseHeaderValueError; + + fn try_from(value: HeaderValue) -> Result { + Self::try_from(&value) + } +} impl TryFrom<&HeaderValue> for Expiration { type Error = ParseHeaderValueError; fn try_from(value: &HeaderValue) -> Result { - value - .to_str() + std::str::from_utf8(value.as_bytes()) .map_err(|_| ParseHeaderValueError) .and_then(Self::try_from) } @@ -327,6 +344,10 @@ impl TryFrom<&str> for Expiration { type Error = ParseHeaderValueError; fn try_from(value: &str) -> Result { + if value == "0" { + return Ok(Self::BurnAfterReading); + } + value .parse::>() .map_err(|_| ParseHeaderValueError) diff --git a/server/Cargo.toml b/server/Cargo.toml index 513b6e9..00ee5fb 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -16,6 +16,7 @@ bytes = { version = "*", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] } # We just need to pull in whatever axum is pulling in headers = "*" +lazy_static = "1" rand = "0.8" rocksdb = { version = "0.17", default_features = false, features = ["zstd"] } serde = { version = "1", features = ["derive"] } diff --git a/server/src/main.rs b/server/src/main.rs index c6552e9..4ab7189 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -14,6 +14,7 @@ use axum::response::Html; use axum::{service, AddExtensionLayer, Router}; use chrono::Utc; use headers::HeaderMap; +use lazy_static::lazy_static; use omegaupload_common::{Expiration, API_ENDPOINT}; use rand::thread_rng; use rand::Rng; @@ -31,6 +32,10 @@ mod short_code; const BLOB_CF_NAME: &str = "blob"; const META_CF_NAME: &str = "meta"; +lazy_static! { + static ref MAX_PASTE_AGE: chrono::Duration = chrono::Duration::days(1); +} + #[tokio::main] async fn main() -> Result<()> { const PASTE_DB_PATH: &str = "database"; @@ -112,8 +117,10 @@ fn set_up_expirations(db: &Arc) { let expiration_time = match expiration { Expiration::BurnAfterReading => { - panic!("Got burn after reading expiration time? Invariant violated"); + warn!("Found unbounded burn after reading. Defaulting to max age"); + Utc::now() + *MAX_PASTE_AGE } + Expiration::BurnAfterReadingWithDeadline(deadline) => deadline, Expiration::UnixTime(time) => time, }; @@ -152,6 +159,15 @@ async fn upload( return Err(StatusCode::BAD_REQUEST); } + if let Some(header) = maybe_expires { + if let Expiration::UnixTime(time) = header.0 { + if (time - Utc::now()) > *MAX_PASTE_AGE { + warn!("{} exceeds allowed paste lifetime", time); + return Err(StatusCode::BAD_REQUEST); + } + } + } + // 3GB max; this is a soft-limit of RocksDb if body.len() >= 3_221_225_472 { return Err(StatusCode::PAYLOAD_TOO_LARGE); @@ -185,10 +201,6 @@ async fn upload( return Err(StatusCode::INTERNAL_SERVER_ERROR); }; - trace!("Serializing paste..."); - - trace!("Finished serializing paste."); - let db_ref = Arc::clone(&db); match task::spawn_blocking(move || { let blob_cf = db_ref.cf_handle(BLOB_CF_NAME).unwrap(); @@ -196,6 +208,11 @@ async fn upload( let data = bincode::serialize(&body).expect("bincode to serialize"); db_ref.put_cf(blob_cf, key, data)?; let expires = maybe_expires.map(|v| v.0).unwrap_or_default(); + let expires = if let Expiration::BurnAfterReading = expires { + Expiration::BurnAfterReadingWithDeadline(Utc::now() + *MAX_PASTE_AGE) + } else { + expires + }; let meta = bincode::serialize(&expires).expect("bincode to serialize"); if db_ref.put_cf(meta_cf, key, meta).is_err() { // try and roll back on metadata write failure @@ -207,7 +224,9 @@ async fn upload( { Ok(Ok(_)) => { if let Some(expires) = maybe_expires { - if let Expiration::UnixTime(expiration_time) = expires.0 { + if let Expiration::UnixTime(expiration_time) + | Expiration::BurnAfterReadingWithDeadline(expiration_time) = expires.0 + { let sleep_duration = (expiration_time - Utc::now()).to_std().unwrap_or_default(); @@ -302,16 +321,33 @@ async fn paste( }; // Check if we need to burn after read - if matches!(metadata, Expiration::BurnAfterReading) { - let join_handle = task::spawn_blocking(move || db.delete(key)) - .await - .map_err(|e| { - error!("Failed to join handle: {}", e); - StatusCode::INTERNAL_SERVER_ERROR - })?; + if matches!( + metadata, + Expiration::BurnAfterReading | Expiration::BurnAfterReadingWithDeadline(_) + ) { + let join_handle = task::spawn_blocking(move || { + let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap(); + let meta_cf = db.cf_handle(META_CF_NAME).unwrap(); + if let Err(e) = db.delete_cf(blob_cf, url.as_bytes()) { + warn!("{}", e); + return Err(()); + } - join_handle.map_err(|e| { - error!("Failed to burn paste after read: {}", e); + if let Err(e) = db.delete_cf(meta_cf, url.as_bytes()) { + warn!("{}", e); + return Err(()); + } + + Ok(()) + }) + .await + .map_err(|e| { + error!("Failed to join handle: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + join_handle.map_err(|_| { + error!("Failed to burn paste after read"); StatusCode::INTERNAL_SERVER_ERROR })?; }