support burns
This commit is contained in:
parent
c29340c93b
commit
8a08e8e100
5 changed files with 116 additions and 46 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -947,6 +947,7 @@ dependencies = [
|
|||
"bytes",
|
||||
"chrono",
|
||||
"headers",
|
||||
"lazy_static",
|
||||
"omegaupload-common",
|
||||
"rand",
|
||||
"rocksdb",
|
||||
|
|
|
@ -7,7 +7,9 @@ use anyhow::{anyhow, bail, Context, Result};
|
|||
use atty::Stream;
|
||||
use clap::Parser;
|
||||
use omegaupload_common::crypto::{gen_key_nonce, open_in_place, seal_in_place, Key};
|
||||
use omegaupload_common::{base64, hash, Expiration, ParsedUrl, Url, API_ENDPOINT};
|
||||
use omegaupload_common::{
|
||||
base64, hash, Expiration, ParsedUrl, Url, API_ENDPOINT, EXPIRATION_HEADER_NAME,
|
||||
};
|
||||
use reqwest::blocking::Client;
|
||||
use reqwest::header::EXPIRES;
|
||||
use reqwest::StatusCode;
|
||||
|
@ -28,6 +30,8 @@ enum Action {
|
|||
/// public access.
|
||||
#[clap(short, long)]
|
||||
password: Option<SecretString>,
|
||||
#[clap(short, long)]
|
||||
duration: Option<Expiration>,
|
||||
},
|
||||
Download {
|
||||
/// The paste to download.
|
||||
|
@ -39,14 +43,22 @@ fn main() -> Result<()> {
|
|||
let opts = Opts::parse();
|
||||
|
||||
match opts.action {
|
||||
Action::Upload { url, password } => handle_upload(url, password),
|
||||
Action::Upload {
|
||||
url,
|
||||
password,
|
||||
duration,
|
||||
} => handle_upload(url, password, duration),
|
||||
Action::Download { url } => handle_download(url),
|
||||
}?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_upload(mut url: Url, password: Option<SecretString>) -> Result<()> {
|
||||
fn handle_upload(
|
||||
mut url: Url,
|
||||
password: Option<SecretString>,
|
||||
duration: Option<Expiration>,
|
||||
) -> Result<()> {
|
||||
url.set_fragment(None);
|
||||
|
||||
if atty::is(Stream::Stdin) {
|
||||
|
@ -76,11 +88,13 @@ fn handle_upload(mut url: Url, password: Option<SecretString>) -> Result<()> {
|
|||
(container, nonce, key, pw_used)
|
||||
};
|
||||
|
||||
let res = Client::new()
|
||||
.post(url.as_ref())
|
||||
.body(data)
|
||||
.send()
|
||||
.context("Request to server failed")?;
|
||||
let mut res = Client::new().post(url.as_ref());
|
||||
|
||||
if let Some(duration) = duration {
|
||||
res = res.header(&*EXPIRATION_HEADER_NAME, duration);
|
||||
}
|
||||
|
||||
let res = res.body(data).send().context("Request to server failed")?;
|
||||
|
||||
if res.status() != StatusCode::OK {
|
||||
bail!("Upload failed. Got HTTP error {}", res.status());
|
||||
|
@ -104,11 +118,8 @@ fn handle_upload(mut url: Url, password: Option<SecretString>) -> Result<()> {
|
|||
}
|
||||
|
||||
fn handle_download(mut url: ParsedUrl) -> Result<()> {
|
||||
url.sanitized_url.set_path(&dbg!(format!(
|
||||
"{}{}",
|
||||
API_ENDPOINT,
|
||||
url.sanitized_url.path()
|
||||
)));
|
||||
url.sanitized_url
|
||||
.set_path(&format!("{}{}", API_ENDPOINT, url.sanitized_url.path()));
|
||||
let res = Client::new()
|
||||
.get(url.sanitized_url)
|
||||
.send()
|
||||
|
|
|
@ -224,13 +224,31 @@ impl FromStr for ParsedUrl {
|
|||
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
|
||||
pub enum Expiration {
|
||||
BurnAfterReading,
|
||||
BurnAfterReadingWithDeadline(DateTime<Utc>),
|
||||
UnixTime(DateTime<Utc>),
|
||||
}
|
||||
|
||||
// This impl is used for the CLI
|
||||
impl FromStr for Expiration {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"read" => Ok(Self::BurnAfterReading),
|
||||
"5m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(5))),
|
||||
"10m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(10))),
|
||||
"1h" => Ok(Self::UnixTime(Utc::now() + Duration::hours(1))),
|
||||
"1d" => Ok(Self::UnixTime(Utc::now() + Duration::days(1))),
|
||||
// We disallow permanent pastes
|
||||
_ => Err(s.to_owned()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Expiration {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Expiration::BurnAfterReading => {
|
||||
Expiration::BurnAfterReading | Expiration::BurnAfterReadingWithDeadline(_) => {
|
||||
write!(f, "This item has been burned. You now have the only copy.")
|
||||
}
|
||||
Expiration::UnixTime(time) => write!(
|
||||
|
@ -256,19 +274,9 @@ impl Header for Expiration {
|
|||
Self: Sized,
|
||||
I: Iterator<Item = &'i HeaderValue>,
|
||||
{
|
||||
match values
|
||||
.next()
|
||||
.ok_or_else(headers::Error::invalid)?
|
||||
.as_bytes()
|
||||
{
|
||||
b"read" => Ok(Self::BurnAfterReading),
|
||||
b"5m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(5))),
|
||||
b"10m" => Ok(Self::UnixTime(Utc::now() + Duration::minutes(10))),
|
||||
b"1h" => Ok(Self::UnixTime(Utc::now() + Duration::hours(1))),
|
||||
b"1d" => Ok(Self::UnixTime(Utc::now() + Duration::days(1))),
|
||||
// We disallow permanent pastes
|
||||
_ => Err(headers::Error::invalid()),
|
||||
}
|
||||
let bytes = values.next().ok_or_else(headers::Error::invalid)?;
|
||||
|
||||
Self::try_from(bytes).map_err(|_| headers::Error::invalid())
|
||||
}
|
||||
|
||||
fn encode<E: Extend<HeaderValue>>(&self, container: &mut E) {
|
||||
|
@ -282,7 +290,9 @@ impl From<&Expiration> for HeaderValue {
|
|||
// so we don't need the extra check.
|
||||
unsafe {
|
||||
Self::from_maybe_shared_unchecked(match expiration {
|
||||
Expiration::BurnAfterReading => Bytes::from_static(b"0"),
|
||||
Expiration::BurnAfterReadingWithDeadline(_) | Expiration::BurnAfterReading => {
|
||||
Bytes::from_static(b"0")
|
||||
}
|
||||
Expiration::UnixTime(duration) => Bytes::from(duration.to_rfc3339()),
|
||||
})
|
||||
}
|
||||
|
@ -295,6 +305,8 @@ impl From<Expiration> for HeaderValue {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ParseHeaderValueError;
|
||||
|
||||
#[cfg(feature = "wasm")]
|
||||
impl TryFrom<web_sys::Headers> for Expiration {
|
||||
type Error = ParseHeaderValueError;
|
||||
|
@ -310,14 +322,19 @@ impl TryFrom<web_sys::Headers> for Expiration {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ParseHeaderValueError;
|
||||
impl TryFrom<HeaderValue> for Expiration {
|
||||
type Error = ParseHeaderValueError;
|
||||
|
||||
fn try_from(value: HeaderValue) -> Result<Self, Self::Error> {
|
||||
Self::try_from(&value)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&HeaderValue> for Expiration {
|
||||
type Error = ParseHeaderValueError;
|
||||
|
||||
fn try_from(value: &HeaderValue) -> Result<Self, Self::Error> {
|
||||
value
|
||||
.to_str()
|
||||
std::str::from_utf8(value.as_bytes())
|
||||
.map_err(|_| ParseHeaderValueError)
|
||||
.and_then(Self::try_from)
|
||||
}
|
||||
|
@ -327,6 +344,10 @@ impl TryFrom<&str> for Expiration {
|
|||
type Error = ParseHeaderValueError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
if value == "0" {
|
||||
return Ok(Self::BurnAfterReading);
|
||||
}
|
||||
|
||||
value
|
||||
.parse::<DateTime<Utc>>()
|
||||
.map_err(|_| ParseHeaderValueError)
|
||||
|
|
|
@ -16,6 +16,7 @@ bytes = { version = "*", features = ["serde"] }
|
|||
chrono = { version = "0.4", features = ["serde"] }
|
||||
# We just need to pull in whatever axum is pulling in
|
||||
headers = "*"
|
||||
lazy_static = "1"
|
||||
rand = "0.8"
|
||||
rocksdb = { version = "0.17", default_features = false, features = ["zstd"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
|
|
|
@ -14,6 +14,7 @@ use axum::response::Html;
|
|||
use axum::{service, AddExtensionLayer, Router};
|
||||
use chrono::Utc;
|
||||
use headers::HeaderMap;
|
||||
use lazy_static::lazy_static;
|
||||
use omegaupload_common::{Expiration, API_ENDPOINT};
|
||||
use rand::thread_rng;
|
||||
use rand::Rng;
|
||||
|
@ -31,6 +32,10 @@ mod short_code;
|
|||
const BLOB_CF_NAME: &str = "blob";
|
||||
const META_CF_NAME: &str = "meta";
|
||||
|
||||
lazy_static! {
|
||||
static ref MAX_PASTE_AGE: chrono::Duration = chrono::Duration::days(1);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
const PASTE_DB_PATH: &str = "database";
|
||||
|
@ -112,8 +117,10 @@ fn set_up_expirations(db: &Arc<DB>) {
|
|||
|
||||
let expiration_time = match expiration {
|
||||
Expiration::BurnAfterReading => {
|
||||
panic!("Got burn after reading expiration time? Invariant violated");
|
||||
warn!("Found unbounded burn after reading. Defaulting to max age");
|
||||
Utc::now() + *MAX_PASTE_AGE
|
||||
}
|
||||
Expiration::BurnAfterReadingWithDeadline(deadline) => deadline,
|
||||
Expiration::UnixTime(time) => time,
|
||||
};
|
||||
|
||||
|
@ -152,6 +159,15 @@ async fn upload<const N: usize>(
|
|||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
if let Some(header) = maybe_expires {
|
||||
if let Expiration::UnixTime(time) = header.0 {
|
||||
if (time - Utc::now()) > *MAX_PASTE_AGE {
|
||||
warn!("{} exceeds allowed paste lifetime", time);
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3GB max; this is a soft-limit of RocksDb
|
||||
if body.len() >= 3_221_225_472 {
|
||||
return Err(StatusCode::PAYLOAD_TOO_LARGE);
|
||||
|
@ -185,10 +201,6 @@ async fn upload<const N: usize>(
|
|||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
};
|
||||
|
||||
trace!("Serializing paste...");
|
||||
|
||||
trace!("Finished serializing paste.");
|
||||
|
||||
let db_ref = Arc::clone(&db);
|
||||
match task::spawn_blocking(move || {
|
||||
let blob_cf = db_ref.cf_handle(BLOB_CF_NAME).unwrap();
|
||||
|
@ -196,6 +208,11 @@ async fn upload<const N: usize>(
|
|||
let data = bincode::serialize(&body).expect("bincode to serialize");
|
||||
db_ref.put_cf(blob_cf, key, data)?;
|
||||
let expires = maybe_expires.map(|v| v.0).unwrap_or_default();
|
||||
let expires = if let Expiration::BurnAfterReading = expires {
|
||||
Expiration::BurnAfterReadingWithDeadline(Utc::now() + *MAX_PASTE_AGE)
|
||||
} else {
|
||||
expires
|
||||
};
|
||||
let meta = bincode::serialize(&expires).expect("bincode to serialize");
|
||||
if db_ref.put_cf(meta_cf, key, meta).is_err() {
|
||||
// try and roll back on metadata write failure
|
||||
|
@ -207,7 +224,9 @@ async fn upload<const N: usize>(
|
|||
{
|
||||
Ok(Ok(_)) => {
|
||||
if let Some(expires) = maybe_expires {
|
||||
if let Expiration::UnixTime(expiration_time) = expires.0 {
|
||||
if let Expiration::UnixTime(expiration_time)
|
||||
| Expiration::BurnAfterReadingWithDeadline(expiration_time) = expires.0
|
||||
{
|
||||
let sleep_duration =
|
||||
(expiration_time - Utc::now()).to_std().unwrap_or_default();
|
||||
|
||||
|
@ -302,16 +321,33 @@ async fn paste<const N: usize>(
|
|||
};
|
||||
|
||||
// Check if we need to burn after read
|
||||
if matches!(metadata, Expiration::BurnAfterReading) {
|
||||
let join_handle = task::spawn_blocking(move || db.delete(key))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to join handle: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
if matches!(
|
||||
metadata,
|
||||
Expiration::BurnAfterReading | Expiration::BurnAfterReadingWithDeadline(_)
|
||||
) {
|
||||
let join_handle = task::spawn_blocking(move || {
|
||||
let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
|
||||
let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
|
||||
if let Err(e) = db.delete_cf(blob_cf, url.as_bytes()) {
|
||||
warn!("{}", e);
|
||||
return Err(());
|
||||
}
|
||||
|
||||
join_handle.map_err(|e| {
|
||||
error!("Failed to burn paste after read: {}", e);
|
||||
if let Err(e) = db.delete_cf(meta_cf, url.as_bytes()) {
|
||||
warn!("{}", e);
|
||||
return Err(());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to join handle: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
join_handle.map_err(|_| {
|
||||
error!("Failed to burn paste after read");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue