support burns

This commit is contained in:
Edward Shen 2021-10-27 19:16:43 -07:00
parent c29340c93b
commit 8a08e8e100
Signed by: edward
GPG key ID: 19182661E818369F
5 changed files with 116 additions and 46 deletions

1
Cargo.lock generated
View file

@ -947,6 +947,7 @@ dependencies = [
"bytes",
"chrono",
"headers",
"lazy_static",
"omegaupload-common",
"rand",
"rocksdb",

View file

@ -7,7 +7,9 @@ use anyhow::{anyhow, bail, Context, Result};
use atty::Stream;
use clap::Parser;
use omegaupload_common::crypto::{gen_key_nonce, open_in_place, seal_in_place, Key};
use omegaupload_common::{base64, hash, Expiration, ParsedUrl, Url, API_ENDPOINT};
use omegaupload_common::{
base64, hash, Expiration, ParsedUrl, Url, API_ENDPOINT, EXPIRATION_HEADER_NAME,
};
use reqwest::blocking::Client;
use reqwest::header::EXPIRES;
use reqwest::StatusCode;
@ -28,6 +30,8 @@ enum Action {
/// public access.
#[clap(short, long)]
password: Option<SecretString>,
#[clap(short, long)]
duration: Option<Expiration>,
},
Download {
/// The paste to download.
@ -39,14 +43,22 @@ fn main() -> Result<()> {
let opts = Opts::parse();
match opts.action {
Action::Upload { url, password } => handle_upload(url, password),
Action::Upload {
url,
password,
duration,
} => handle_upload(url, password, duration),
Action::Download { url } => handle_download(url),
}?;
Ok(())
}
fn handle_upload(mut url: Url, password: Option<SecretString>) -> Result<()> {
fn handle_upload(
mut url: Url,
password: Option<SecretString>,
duration: Option<Expiration>,
) -> Result<()> {
url.set_fragment(None);
if atty::is(Stream::Stdin) {
@ -76,11 +88,13 @@ fn handle_upload(mut url: Url, password: Option<SecretString>) -> Result<()> {
(container, nonce, key, pw_used)
};
let res = Client::new()
.post(url.as_ref())
.body(data)
.send()
.context("Request to server failed")?;
let mut res = Client::new().post(url.as_ref());
if let Some(duration) = duration {
res = res.header(&*EXPIRATION_HEADER_NAME, duration);
}
let res = res.body(data).send().context("Request to server failed")?;
if res.status() != StatusCode::OK {
bail!("Upload failed. Got HTTP error {}", res.status());
@ -104,11 +118,8 @@ fn handle_upload(mut url: Url, password: Option<SecretString>) -> Result<()> {
}
fn handle_download(mut url: ParsedUrl) -> Result<()> {
url.sanitized_url.set_path(&dbg!(format!(
"{}{}",
API_ENDPOINT,
url.sanitized_url.path()
)));
url.sanitized_url
.set_path(&format!("{}{}", API_ENDPOINT, url.sanitized_url.path()));
let res = Client::new()
.get(url.sanitized_url)
.send()

View file

@ -224,13 +224,31 @@ impl FromStr for ParsedUrl {
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum Expiration {
BurnAfterReading,
BurnAfterReadingWithDeadline(DateTime<Utc>),
UnixTime(DateTime<Utc>),
}
// This impl is used for the CLI
impl FromStr for Expiration {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"read" => Ok(Self::BurnAfterReading),
"5m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(5))),
"10m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(10))),
"1h" => Ok(Self::UnixTime(Utc::now() + Duration::hours(1))),
"1d" => Ok(Self::UnixTime(Utc::now() + Duration::days(1))),
// We disallow permanent pastes
_ => Err(s.to_owned()),
}
}
}
impl Display for Expiration {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expiration::BurnAfterReading => {
Expiration::BurnAfterReading | Expiration::BurnAfterReadingWithDeadline(_) => {
write!(f, "This item has been burned. You now have the only copy.")
}
Expiration::UnixTime(time) => write!(
@ -256,19 +274,9 @@ impl Header for Expiration {
Self: Sized,
I: Iterator<Item = &'i HeaderValue>,
{
match values
.next()
.ok_or_else(headers::Error::invalid)?
.as_bytes()
{
b"read" => Ok(Self::BurnAfterReading),
b"5m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(5))),
b"10m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(10))),
b"1h" => Ok(Self::UnixTime(Utc::now() + Duration::hours(1))),
b"1d" => Ok(Self::UnixTime(Utc::now() + Duration::days(1))),
// We disallow permanent pastes
_ => Err(headers::Error::invalid()),
}
let bytes = values.next().ok_or_else(headers::Error::invalid)?;
Self::try_from(bytes).map_err(|_| headers::Error::invalid())
}
fn encode<E: Extend<HeaderValue>>(&self, container: &mut E) {
@ -282,7 +290,9 @@ impl From<&Expiration> for HeaderValue {
// so we don't need the extra check.
unsafe {
Self::from_maybe_shared_unchecked(match expiration {
Expiration::BurnAfterReading => Bytes::from_static(b"0"),
Expiration::BurnAfterReadingWithDeadline(_) | Expiration::BurnAfterReading => {
Bytes::from_static(b"0")
}
Expiration::UnixTime(duration) => Bytes::from(duration.to_rfc3339()),
})
}
@ -295,6 +305,8 @@ impl From<Expiration> for HeaderValue {
}
}
pub struct ParseHeaderValueError;
#[cfg(feature = "wasm")]
impl TryFrom<web_sys::Headers> for Expiration {
type Error = ParseHeaderValueError;
@ -310,14 +322,19 @@ impl TryFrom<web_sys::Headers> for Expiration {
}
}
pub struct ParseHeaderValueError;
impl TryFrom<HeaderValue> for Expiration {
type Error = ParseHeaderValueError;
fn try_from(value: HeaderValue) -> Result<Self, Self::Error> {
Self::try_from(&value)
}
}
impl TryFrom<&HeaderValue> for Expiration {
type Error = ParseHeaderValueError;
fn try_from(value: &HeaderValue) -> Result<Self, Self::Error> {
value
.to_str()
std::str::from_utf8(value.as_bytes())
.map_err(|_| ParseHeaderValueError)
.and_then(Self::try_from)
}
@ -327,6 +344,10 @@ impl TryFrom<&str> for Expiration {
type Error = ParseHeaderValueError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
if value == "0" {
return Ok(Self::BurnAfterReading);
}
value
.parse::<DateTime<Utc>>()
.map_err(|_| ParseHeaderValueError)

View file

@ -16,6 +16,7 @@ bytes = { version = "*", features = ["serde"] }
chrono = { version = "0.4", features = ["serde"] }
# We just need to pull in whatever axum is pulling in
headers = "*"
lazy_static = "1"
rand = "0.8"
rocksdb = { version = "0.17", default_features = false, features = ["zstd"] }
serde = { version = "1", features = ["derive"] }

View file

@ -14,6 +14,7 @@ use axum::response::Html;
use axum::{service, AddExtensionLayer, Router};
use chrono::Utc;
use headers::HeaderMap;
use lazy_static::lazy_static;
use omegaupload_common::{Expiration, API_ENDPOINT};
use rand::thread_rng;
use rand::Rng;
@ -31,6 +32,10 @@ mod short_code;
const BLOB_CF_NAME: &str = "blob";
const META_CF_NAME: &str = "meta";
lazy_static! {
static ref MAX_PASTE_AGE: chrono::Duration = chrono::Duration::days(1);
}
#[tokio::main]
async fn main() -> Result<()> {
const PASTE_DB_PATH: &str = "database";
@ -112,8 +117,10 @@ fn set_up_expirations(db: &Arc<DB>) {
let expiration_time = match expiration {
Expiration::BurnAfterReading => {
panic!("Got burn after reading expiration time? Invariant violated");
warn!("Found unbounded burn after reading. Defaulting to max age");
Utc::now() + *MAX_PASTE_AGE
}
Expiration::BurnAfterReadingWithDeadline(deadline) => deadline,
Expiration::UnixTime(time) => time,
};
@ -152,6 +159,15 @@ async fn upload<const N: usize>(
return Err(StatusCode::BAD_REQUEST);
}
if let Some(header) = maybe_expires {
if let Expiration::UnixTime(time) = header.0 {
if (time - Utc::now()) > *MAX_PASTE_AGE {
warn!("{} exceeds allowed paste lifetime", time);
return Err(StatusCode::BAD_REQUEST);
}
}
}
// 3GB max; this is a soft-limit of RocksDb
if body.len() >= 3_221_225_472 {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
@ -185,10 +201,6 @@ async fn upload<const N: usize>(
return Err(StatusCode::INTERNAL_SERVER_ERROR);
};
trace!("Serializing paste...");
trace!("Finished serializing paste.");
let db_ref = Arc::clone(&db);
match task::spawn_blocking(move || {
let blob_cf = db_ref.cf_handle(BLOB_CF_NAME).unwrap();
@ -196,6 +208,11 @@ async fn upload<const N: usize>(
let data = bincode::serialize(&body).expect("bincode to serialize");
db_ref.put_cf(blob_cf, key, data)?;
let expires = maybe_expires.map(|v| v.0).unwrap_or_default();
let expires = if let Expiration::BurnAfterReading = expires {
Expiration::BurnAfterReadingWithDeadline(Utc::now() + *MAX_PASTE_AGE)
} else {
expires
};
let meta = bincode::serialize(&expires).expect("bincode to serialize");
if db_ref.put_cf(meta_cf, key, meta).is_err() {
// try and roll back on metadata write failure
@ -207,7 +224,9 @@ async fn upload<const N: usize>(
{
Ok(Ok(_)) => {
if let Some(expires) = maybe_expires {
if let Expiration::UnixTime(expiration_time) = expires.0 {
if let Expiration::UnixTime(expiration_time)
| Expiration::BurnAfterReadingWithDeadline(expiration_time) = expires.0
{
let sleep_duration =
(expiration_time - Utc::now()).to_std().unwrap_or_default();
@ -302,16 +321,33 @@ async fn paste<const N: usize>(
};
// Check if we need to burn after read
if matches!(metadata, Expiration::BurnAfterReading) {
let join_handle = task::spawn_blocking(move || db.delete(key))
.await
.map_err(|e| {
error!("Failed to join handle: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
if matches!(
metadata,
Expiration::BurnAfterReading | Expiration::BurnAfterReadingWithDeadline(_)
) {
let join_handle = task::spawn_blocking(move || {
let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
if let Err(e) = db.delete_cf(blob_cf, url.as_bytes()) {
warn!("{}", e);
return Err(());
}
join_handle.map_err(|e| {
error!("Failed to burn paste after read: {}", e);
if let Err(e) = db.delete_cf(meta_cf, url.as_bytes()) {
warn!("{}", e);
return Err(());
}
Ok(())
})
.await
.map_err(|e| {
error!("Failed to join handle: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
join_handle.map_err(|_| {
error!("Failed to burn paste after read");
StatusCode::INTERNAL_SERVER_ERROR
})?;
}