partition rocksdb

Edward Shen 2021-10-24 18:07:48 -07:00
parent beda106f6a
commit 8ea15ef395
Signed by: edward
GPG key ID: 19182661E818369F
4 changed files with 144 additions and 103 deletions
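This commit splits the single default-column-family RocksDB store into two column families, "blob" for paste contents and "meta" for expiration metadata, keyed by the same short code. Below is a minimal standalone sketch of that layout for reference; it is not part of the commit, assumes the rust-rocksdb crate, and uses a throwaway "example-db" path (the server itself uses "database").

// Hypothetical sketch, not from the commit: the two-column-family layout
// that the main.rs diff below moves to, using the rust-rocksdb API.
use rocksdb::{ColumnFamilyDescriptor, Options, DB};

const BLOB_CF_NAME: &str = "blob";
const META_CF_NAME: &str = "meta";

fn main() -> Result<(), rocksdb::Error> {
    let mut db_options = Options::default();
    db_options.create_if_missing(true);
    db_options.create_missing_column_families(true);

    // "example-db" is a placeholder path; the server uses "database".
    let db = DB::open_cf_descriptors(
        &db_options,
        "example-db",
        [
            ColumnFamilyDescriptor::new(BLOB_CF_NAME, Options::default()),
            ColumnFamilyDescriptor::new(META_CF_NAME, Options::default()),
        ],
    )?;

    // One short-code key, two values: the paste body in "blob" and its
    // expiration metadata in "meta".
    let key = b"abc123";
    db.put_cf(db.cf_handle(BLOB_CF_NAME).unwrap(), key, b"paste bytes")?;
    db.put_cf(db.cf_handle(META_CF_NAME).unwrap(), key, b"expiration bytes")?;

    // An expiration sweep only has to touch the small "meta" family.
    assert!(db.get_cf(db.cf_handle(META_CF_NAME).unwrap(), key)?.is_some());
    Ok(())
}

Keeping the small expiration records in their own column family lets the startup expiration sweep iterate "meta" alone instead of deserializing every paste blob.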

Cargo.lock (generated)
View file

@@ -1070,6 +1070,7 @@ dependencies = [
  "bincode",
  "bytes",
  "chrono",
+ "futures",
  "headers",
  "omegaupload-common",
  "rand",

View file

@@ -14,6 +14,7 @@ bincode = "1"
 # to enable the feature
 bytes = { version = "*", features = ["serde"] }
 chrono = { version = "0.4", features = ["serde"] }
+futures = "0.3"
 # We just need to pull in whatever axum is pulling in
 headers = "*"
 rand = "0.8"

View file

@@ -1,6 +1,7 @@
 #![warn(clippy::nursery, clippy::pedantic)]
 
 use std::sync::Arc;
+use std::time::Duration;
 
 use anyhow::Result;
 use axum::body::Bytes;
@@ -14,26 +15,38 @@ use headers::HeaderMap;
 use omegaupload_common::Expiration;
 use rand::thread_rng;
 use rand::Rng;
-use rocksdb::IteratorMode;
+use rocksdb::{ColumnFamilyDescriptor, IteratorMode};
 use rocksdb::{Options, DB};
 use tokio::task;
 use tracing::{error, instrument, trace};
 use tracing::{info, warn};
 
-use crate::paste::Paste;
 use crate::short_code::ShortCode;
 
-mod paste;
 mod short_code;
 
+const BLOB_CF_NAME: &str = "blob";
+const META_CF_NAME: &str = "meta";
+
 #[tokio::main]
 async fn main() -> Result<()> {
-    const DB_PATH: &str = "database";
+    const PASTE_DB_PATH: &str = "database";
     const SHORT_CODE_SIZE: usize = 12;
 
     tracing_subscriber::fmt::init();
 
-    let db = Arc::new(DB::open_default(DB_PATH)?);
+    let mut db_options = Options::default();
+    db_options.create_if_missing(true);
+    db_options.create_missing_column_families(true);
+    db_options.set_compression_type(rocksdb::DBCompressionType::Zstd);
+
+    let db = Arc::new(DB::open_cf_descriptors(
+        &db_options,
+        PASTE_DB_PATH,
+        [
+            ColumnFamilyDescriptor::new(BLOB_CF_NAME, Options::default()),
+            ColumnFamilyDescriptor::new(META_CF_NAME, Options::default()),
+        ],
+    )?);
 
     set_up_expirations(Arc::clone(&db));
@@ -51,7 +64,7 @@ async fn main() -> Result<()> {
         .await?;
 
     // Must be called for correct shutdown
-    DB::destroy(&Options::default(), DB_PATH)?;
+    DB::destroy(&Options::default(), PASTE_DB_PATH)?;
     Ok(())
 }
@@ -59,42 +72,51 @@ fn set_up_expirations(db: Arc<DB>) {
     let mut corrupted = 0;
     let mut expired = 0;
     let mut pending = 0;
-    let mut permanent = 0;
 
     info!("Setting up cleanup timers, please wait...");
 
-    for (key, value) in db.iterator(IteratorMode::Start) {
-        let paste = if let Ok(value) = bincode::deserialize::<Paste>(&value) {
+    let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
+
+    let db_ref = Arc::clone(&db);
+    let delete_entry = move |key: &[u8]| {
+        let blob_cf = db_ref.cf_handle(BLOB_CF_NAME).unwrap();
+        let meta_cf = db_ref.cf_handle(META_CF_NAME).unwrap();
+        if let Err(e) = db_ref.delete_cf(blob_cf, &key) {
+            warn!("{}", e);
+        }
+        if let Err(e) = db_ref.delete_cf(meta_cf, &key) {
+            warn!("{}", e);
+        }
+    };
+
+    for (key, value) in db.iterator_cf(meta_cf, IteratorMode::Start) {
+        let expires = if let Ok(value) = bincode::deserialize::<Expiration>(&value) {
             value
         } else {
             corrupted += 1;
-            if let Err(e) = db.delete(key) {
-                warn!("{}", e);
-            }
+            delete_entry(&key);
             continue;
         };
 
-        if let Some(Expiration::UnixTime(time)) = paste.expiration {
-            let now = Utc::now();
-
-            if time < now {
-                expired += 1;
-                if let Err(e) = db.delete(key) {
-                    warn!("{}", e);
-                }
-            } else {
-                let sleep_duration = (time - now).to_std().unwrap();
-                pending += 1;
-
-                let db_ref = Arc::clone(&db);
-                task::spawn_blocking(move || async move {
-                    tokio::time::sleep(sleep_duration).await;
-                    if let Err(e) = db_ref.delete(key) {
-                        warn!("{}", e);
-                    }
-                });
-            }
+        let expiration_time = match expires {
+            Expiration::BurnAfterReading => {
+                panic!("Got burn after reading expiration time? Invariant violated");
+            }
+            Expiration::UnixTime(time) => time,
+        };
+
+        let sleep_duration = (expiration_time - Utc::now()).to_std().unwrap_or_default();
+        if sleep_duration != Duration::default() {
+            pending += 1;
+            let delete_entry_ref = delete_entry.clone();
+            task::spawn_blocking(move || async move {
+                tokio::time::sleep(sleep_duration).await;
+                delete_entry_ref(&key);
+            });
         } else {
-            permanent += 1;
+            expired += 1;
+            delete_entry(&key);
         }
     }
@@ -105,7 +127,6 @@ fn set_up_expirations(db: Arc<DB>) {
     }
 
     info!("Found {} expired pastes.", expired);
     info!("Found {} active pastes.", pending);
-    info!("Found {} permanent pastes.", permanent);
     info!("Cleanup timers have been initialized.");
 }
@@ -124,7 +145,6 @@ async fn upload<const N: usize>(
         return Err(StatusCode::PAYLOAD_TOO_LARGE);
     }
 
-    let paste = Paste::new(maybe_expires.map(|v| v.0).unwrap_or_default(), body);
     let mut new_key = None;
 
     trace!("Generating short code...");
@@ -135,7 +155,10 @@ async fn upload<const N: usize>(
         let code: ShortCode<N> = thread_rng().sample(short_code::Generator);
         let db = Arc::clone(&db);
         let key = code.as_bytes();
-        let query = task::spawn_blocking(move || db.key_may_exist(key)).await;
+        let query = task::spawn_blocking(move || {
+            db.key_may_exist_cf(db.cf_handle(META_CF_NAME).unwrap(), key)
+        })
+        .await;
         if matches!(query, Ok(false)) {
             new_key = Some(key);
             trace!("Found new key after {} attempts.", i);
@@ -151,36 +174,42 @@ async fn upload<const N: usize>(
     };
 
     trace!("Serializing paste...");
-    let value = if let Ok(v) = bincode::serialize(&paste) {
-        v
-    } else {
-        error!("Failed to serialize paste?!");
-        return Err(StatusCode::INTERNAL_SERVER_ERROR);
-    };
     trace!("Finished serializing paste.");
 
     let db_ref = Arc::clone(&db);
-    match task::spawn_blocking(move || db_ref.put(key, value)).await {
+    match task::spawn_blocking(move || {
+        let blob_cf = db_ref.cf_handle(BLOB_CF_NAME).unwrap();
+        let meta_cf = db_ref.cf_handle(META_CF_NAME).unwrap();
+        let data = bincode::serialize(&body).expect("bincode to serialize");
+        db_ref.put_cf(blob_cf, key, data)?;
+        let expires = maybe_expires.map(|v| v.0).unwrap_or_default();
+        let meta = bincode::serialize(&expires).expect("bincode to serialize");
+        if db_ref.put_cf(meta_cf, key, meta).is_err() {
+            // try and roll back on metadata write failure
+            db_ref.delete_cf(blob_cf, key)?;
+        }
+        Result::<_, anyhow::Error>::Ok(())
+    })
+    .await
+    {
         Ok(Ok(_)) => {
             if let Some(expires) = maybe_expires {
-                if let Expiration::UnixTime(time) = expires.0 {
-                    let now = Utc::now();
-
-                    if time < now {
-                        if let Err(e) = db.delete(key) {
-                            warn!("{}", e);
-                        }
-                    } else {
-                        let sleep_duration = (time - now).to_std().unwrap();
-
-                        task::spawn_blocking(move || async move {
-                            tokio::time::sleep(sleep_duration).await;
-                            if let Err(e) = db.delete(key) {
-                                warn!("{}", e);
-                            }
-                        });
-                    }
+                if let Expiration::UnixTime(expiration_time) = expires.0 {
+                    let sleep_duration =
+                        (expiration_time - Utc::now()).to_std().unwrap_or_default();
+
+                    task::spawn_blocking(move || async move {
+                        tokio::time::sleep(sleep_duration).await;
+                        let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
+                        let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
+                        if let Err(e) = db.delete_cf(blob_cf, key) {
+                            warn!("{}", e);
+                        }
+                        if let Err(e) = db.delete_cf(meta_cf, key) {
+                            warn!("{}", e);
+                        }
+                    });
                 }
             }
         }
@@ -200,9 +229,9 @@ async fn paste<const N: usize>(
 ) -> Result<(HeaderMap, Bytes), StatusCode> {
     let key = url.as_bytes();
 
-    let parsed: Paste = {
-        // not sure if perf of get_pinned is better than spawn_blocking
-        let query_result = db.get_pinned(key).map_err(|e| {
+    let metadata: Expiration = {
+        let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
+        let query_result = db.get_cf(meta_cf, key).map_err(|e| {
             error!("Failed to fetch initial query: {}", e);
             StatusCode::INTERNAL_SERVER_ERROR
         })?;
@@ -218,23 +247,50 @@ async fn paste<const N: usize>(
         })?
     };
 
-    if parsed.expired() {
-        let join_handle = task::spawn_blocking(move || db.delete(key))
-            .await
-            .map_err(|e| {
-                error!("Failed to join handle: {}", e);
-                StatusCode::INTERNAL_SERVER_ERROR
-            })?;
-
-        join_handle.map_err(|e| {
-            error!("Failed to delete expired paste: {}", e);
-            StatusCode::INTERNAL_SERVER_ERROR
-        })?;
-
-        return Err(StatusCode::NOT_FOUND);
-    }
+    // Check if paste has expired.
+    if let Expiration::UnixTime(expires) = metadata {
+        if expires < Utc::now() {
+            task::spawn_blocking(move || {
+                let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
+                let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
+                if let Err(e) = db.delete_cf(blob_cf, &key) {
+                    warn!("{}", e);
+                }
+                if let Err(e) = db.delete_cf(meta_cf, &key) {
+                    warn!("{}", e);
+                }
+            })
+            .await
+            .map_err(|e| {
+                error!("Failed to join handle: {}", e);
+                StatusCode::INTERNAL_SERVER_ERROR
+            })?;
+            return Err(StatusCode::NOT_FOUND);
+        }
+    }
+
+    let paste: Bytes = {
+        // not sure if perf of get_pinned is better than spawn_blocking
+        let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
+        let query_result = db.get_pinned_cf(blob_cf, key).map_err(|e| {
+            error!("Failed to fetch initial query: {}", e);
+            StatusCode::INTERNAL_SERVER_ERROR
+        })?;
+
+        let data = match query_result {
+            Some(data) => data,
+            None => return Err(StatusCode::NOT_FOUND),
+        };
+
+        bincode::deserialize(&data).map_err(|_| {
+            error!("Failed to deserialize data?!");
+            StatusCode::INTERNAL_SERVER_ERROR
+        })?
+    };
 
-    if parsed.is_burn_after_read() {
+    // Check if we need to burn after read
+    if matches!(metadata, Expiration::BurnAfterReading) {
         let join_handle = task::spawn_blocking(move || db.delete(key))
             .await
             .map_err(|e| {
@@ -249,10 +305,9 @@ async fn paste<const N: usize>(
     }
 
     let mut map = HeaderMap::new();
-    if let Some(expiration) = parsed.expiration {
-        map.insert(EXPIRES, expiration.into());
-    }
+    map.insert(EXPIRES, metadata.into());
 
-    Ok((map, parsed.bytes))
+    Ok((map, paste))
 }
 #[instrument(skip(db))]
@@ -260,8 +315,24 @@ async fn delete<const N: usize>(
     Extension(db): Extension<Arc<DB>>,
     Path(url): Path<ShortCode<N>>,
 ) -> StatusCode {
-    match task::spawn_blocking(move || db.delete(url.as_bytes())).await {
-        Ok(Ok(_)) => StatusCode::OK,
+    match task::spawn_blocking(move || {
+        let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
+        let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
+        if let Err(e) = db.delete_cf(blob_cf, url.as_bytes()) {
+            warn!("{}", e);
+            return Err(());
+        }
+        if let Err(e) = db.delete_cf(meta_cf, url.as_bytes()) {
+            warn!("{}", e);
+            return Err(());
+        }
+        Ok(())
+    })
+    .await
+    {
+        Ok(_) => StatusCode::OK,
         _ => StatusCode::INTERNAL_SERVER_ERROR,
     }
 }

View file

@@ -1,32 +0,0 @@
-use axum::body::Bytes;
-use chrono::Utc;
-use omegaupload_common::Expiration;
-use serde::{Deserialize, Serialize};
-
-#[derive(Serialize, Deserialize)]
-pub struct Paste {
-    pub expiration: Option<Expiration>,
-    pub bytes: Bytes,
-}
-
-impl Paste {
-    pub fn new(expiration: impl Into<Option<Expiration>>, bytes: Bytes) -> Self {
-        Self {
-            expiration: expiration.into(),
-            bytes,
-        }
-    }
-
-    pub fn expired(&self) -> bool {
-        self.expiration
-            .map(|expires| match expires {
-                Expiration::BurnAfterReading => false,
-                Expiration::UnixTime(expiration) => expiration < Utc::now(),
-            })
-            .unwrap_or_default()
-    }
-
-    pub const fn is_burn_after_read(&self) -> bool {
-        matches!(self.expiration, Some(Expiration::BurnAfterReading))
-    }
-}