partition rocksdb
This commit is contained in:
parent
beda106f6a
commit
8ea15ef395
4 changed files with 144 additions and 103 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -1070,6 +1070,7 @@ dependencies = [
|
|||
"bincode",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"futures",
|
||||
"headers",
|
||||
"omegaupload-common",
|
||||
"rand",
|
||||
|
|
|
@ -14,6 +14,7 @@ bincode = "1"
|
|||
# to enable the feature
|
||||
bytes = { version = "*", features = ["serde"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
futures = "0.3"
|
||||
# We just need to pull in whatever axum is pulling in
|
||||
headers = "*"
|
||||
rand = "0.8"
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#![warn(clippy::nursery, clippy::pedantic)]
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::body::Bytes;
|
||||
|
@ -14,26 +15,38 @@ use headers::HeaderMap;
|
|||
use omegaupload_common::Expiration;
|
||||
use rand::thread_rng;
|
||||
use rand::Rng;
|
||||
use rocksdb::IteratorMode;
|
||||
use rocksdb::{ColumnFamilyDescriptor, IteratorMode};
|
||||
use rocksdb::{Options, DB};
|
||||
use tokio::task;
|
||||
use tracing::{error, instrument, trace};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::paste::Paste;
|
||||
use crate::short_code::ShortCode;
|
||||
|
||||
mod paste;
|
||||
mod short_code;
|
||||
|
||||
const BLOB_CF_NAME: &str = "blob";
|
||||
const META_CF_NAME: &str = "meta";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
const DB_PATH: &str = "database";
|
||||
const PASTE_DB_PATH: &str = "database";
|
||||
const SHORT_CODE_SIZE: usize = 12;
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let db = Arc::new(DB::open_default(DB_PATH)?);
|
||||
let mut db_options = Options::default();
|
||||
db_options.create_if_missing(true);
|
||||
db_options.create_missing_column_families(true);
|
||||
db_options.set_compression_type(rocksdb::DBCompressionType::Zstd);
|
||||
let db = Arc::new(DB::open_cf_descriptors(
|
||||
&db_options,
|
||||
PASTE_DB_PATH,
|
||||
[
|
||||
ColumnFamilyDescriptor::new(BLOB_CF_NAME, Options::default()),
|
||||
ColumnFamilyDescriptor::new(META_CF_NAME, Options::default()),
|
||||
],
|
||||
)?);
|
||||
|
||||
set_up_expirations(Arc::clone(&db));
|
||||
|
||||
|
@ -51,7 +64,7 @@ async fn main() -> Result<()> {
|
|||
.await?;
|
||||
|
||||
// Must be called for correct shutdown
|
||||
DB::destroy(&Options::default(), DB_PATH)?;
|
||||
DB::destroy(&Options::default(), PASTE_DB_PATH)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -59,42 +72,51 @@ fn set_up_expirations(db: Arc<DB>) {
|
|||
let mut corrupted = 0;
|
||||
let mut expired = 0;
|
||||
let mut pending = 0;
|
||||
let mut permanent = 0;
|
||||
|
||||
info!("Setting up cleanup timers, please wait...");
|
||||
|
||||
for (key, value) in db.iterator(IteratorMode::Start) {
|
||||
let paste = if let Ok(value) = bincode::deserialize::<Paste>(&value) {
|
||||
let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
|
||||
|
||||
let db_ref = Arc::clone(&db);
|
||||
|
||||
let delete_entry = move |key: &[u8]| {
|
||||
let blob_cf = db_ref.cf_handle(BLOB_CF_NAME).unwrap();
|
||||
let meta_cf = db_ref.cf_handle(META_CF_NAME).unwrap();
|
||||
if let Err(e) = db_ref.delete_cf(blob_cf, &key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
if let Err(e) = db_ref.delete_cf(meta_cf, &key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
};
|
||||
|
||||
for (key, value) in db.iterator_cf(meta_cf, IteratorMode::Start) {
|
||||
let expires = if let Ok(value) = bincode::deserialize::<Expiration>(&value) {
|
||||
value
|
||||
} else {
|
||||
corrupted += 1;
|
||||
if let Err(e) = db.delete(key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
delete_entry(&key);
|
||||
continue;
|
||||
};
|
||||
if let Some(Expiration::UnixTime(time)) = paste.expiration {
|
||||
let now = Utc::now();
|
||||
|
||||
if time < now {
|
||||
expired += 1;
|
||||
if let Err(e) = db.delete(key) {
|
||||
warn!("{}", e);
|
||||
let expiration_time = match expires {
|
||||
Expiration::BurnAfterReading => {
|
||||
panic!("Got burn after reading expiration time? Invariant violated");
|
||||
}
|
||||
} else {
|
||||
let sleep_duration = (time - now).to_std().unwrap();
|
||||
pending += 1;
|
||||
Expiration::UnixTime(time) => time,
|
||||
};
|
||||
|
||||
let db_ref = Arc::clone(&db);
|
||||
let sleep_duration = (expiration_time - Utc::now()).to_std().unwrap_or_default();
|
||||
if sleep_duration != Duration::default() {
|
||||
pending += 1;
|
||||
let delete_entry_ref = delete_entry.clone();
|
||||
task::spawn_blocking(move || async move {
|
||||
tokio::time::sleep(sleep_duration).await;
|
||||
if let Err(e) = db_ref.delete(key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
delete_entry_ref(&key);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
permanent += 1;
|
||||
expired += 1;
|
||||
delete_entry(&key);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,7 +127,6 @@ fn set_up_expirations(db: Arc<DB>) {
|
|||
}
|
||||
info!("Found {} expired pastes.", expired);
|
||||
info!("Found {} active pastes.", pending);
|
||||
info!("Found {} permanent pastes.", permanent);
|
||||
info!("Cleanup timers have been initialized.");
|
||||
}
|
||||
|
||||
|
@ -124,7 +145,6 @@ async fn upload<const N: usize>(
|
|||
return Err(StatusCode::PAYLOAD_TOO_LARGE);
|
||||
}
|
||||
|
||||
let paste = Paste::new(maybe_expires.map(|v| v.0).unwrap_or_default(), body);
|
||||
let mut new_key = None;
|
||||
|
||||
trace!("Generating short code...");
|
||||
|
@ -135,7 +155,10 @@ async fn upload<const N: usize>(
|
|||
let code: ShortCode<N> = thread_rng().sample(short_code::Generator);
|
||||
let db = Arc::clone(&db);
|
||||
let key = code.as_bytes();
|
||||
let query = task::spawn_blocking(move || db.key_may_exist(key)).await;
|
||||
let query = task::spawn_blocking(move || {
|
||||
db.key_may_exist_cf(db.cf_handle(META_CF_NAME).unwrap(), key)
|
||||
})
|
||||
.await;
|
||||
if matches!(query, Ok(false)) {
|
||||
new_key = Some(key);
|
||||
trace!("Found new key after {} attempts.", i);
|
||||
|
@ -151,39 +174,45 @@ async fn upload<const N: usize>(
|
|||
};
|
||||
|
||||
trace!("Serializing paste...");
|
||||
let value = if let Ok(v) = bincode::serialize(&paste) {
|
||||
v
|
||||
} else {
|
||||
error!("Failed to serialize paste?!");
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
};
|
||||
|
||||
trace!("Finished serializing paste.");
|
||||
|
||||
let db_ref = Arc::clone(&db);
|
||||
match task::spawn_blocking(move || db_ref.put(key, value)).await {
|
||||
match task::spawn_blocking(move || {
|
||||
let blob_cf = db_ref.cf_handle(BLOB_CF_NAME).unwrap();
|
||||
let meta_cf = db_ref.cf_handle(META_CF_NAME).unwrap();
|
||||
let data = bincode::serialize(&body).expect("bincode to serialize");
|
||||
db_ref.put_cf(blob_cf, key, data)?;
|
||||
let expires = maybe_expires.map(|v| v.0).unwrap_or_default();
|
||||
let meta = bincode::serialize(&expires).expect("bincode to serialize");
|
||||
if db_ref.put_cf(meta_cf, key, meta).is_err() {
|
||||
// try and roll back on metadata write failure
|
||||
db_ref.delete_cf(blob_cf, key)?;
|
||||
}
|
||||
Result::<_, anyhow::Error>::Ok(())
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(Ok(_)) => {
|
||||
if let Some(expires) = maybe_expires {
|
||||
if let Expiration::UnixTime(time) = expires.0 {
|
||||
let now = Utc::now();
|
||||
|
||||
if time < now {
|
||||
if let Err(e) = db.delete(key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
} else {
|
||||
let sleep_duration = (time - now).to_std().unwrap();
|
||||
if let Expiration::UnixTime(expiration_time) = expires.0 {
|
||||
let sleep_duration =
|
||||
(expiration_time - Utc::now()).to_std().unwrap_or_default();
|
||||
|
||||
task::spawn_blocking(move || async move {
|
||||
tokio::time::sleep(sleep_duration).await;
|
||||
if let Err(e) = db.delete(key) {
|
||||
let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
|
||||
let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
|
||||
if let Err(e) = db.delete_cf(blob_cf, key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
if let Err(e) = db.delete_cf(meta_cf, key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
e => {
|
||||
error!("Failed to insert paste into db: {:?}", e);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
@ -200,9 +229,9 @@ async fn paste<const N: usize>(
|
|||
) -> Result<(HeaderMap, Bytes), StatusCode> {
|
||||
let key = url.as_bytes();
|
||||
|
||||
let parsed: Paste = {
|
||||
// not sure if perf of get_pinned is better than spawn_blocking
|
||||
let query_result = db.get_pinned(key).map_err(|e| {
|
||||
let metadata: Expiration = {
|
||||
let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
|
||||
let query_result = db.get_cf(meta_cf, key).map_err(|e| {
|
||||
error!("Failed to fetch initial query: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
@ -218,23 +247,50 @@ async fn paste<const N: usize>(
|
|||
})?
|
||||
};
|
||||
|
||||
if parsed.expired() {
|
||||
let join_handle = task::spawn_blocking(move || db.delete(key))
|
||||
// Check if paste has expired.
|
||||
if let Expiration::UnixTime(expires) = metadata {
|
||||
if expires < Utc::now() {
|
||||
task::spawn_blocking(move || {
|
||||
let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
|
||||
let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
|
||||
if let Err(e) = db.delete_cf(blob_cf, &key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
if let Err(e) = db.delete_cf(meta_cf, &key) {
|
||||
warn!("{}", e);
|
||||
}
|
||||
})
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to join handle: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
join_handle.map_err(|e| {
|
||||
error!("Failed to delete expired paste: {}", e);
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
}
|
||||
|
||||
let paste: Bytes = {
|
||||
// not sure if perf of get_pinned is better than spawn_blocking
|
||||
let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
|
||||
let query_result = db.get_pinned_cf(blob_cf, key).map_err(|e| {
|
||||
error!("Failed to fetch initial query: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
let data = match query_result {
|
||||
Some(data) => data,
|
||||
None => return Err(StatusCode::NOT_FOUND),
|
||||
};
|
||||
|
||||
if parsed.is_burn_after_read() {
|
||||
bincode::deserialize(&data).map_err(|_| {
|
||||
error!("Failed to deserialize data?!");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
};
|
||||
|
||||
// Check if we need to burn after read
|
||||
if matches!(metadata, Expiration::BurnAfterReading) {
|
||||
let join_handle = task::spawn_blocking(move || db.delete(key))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
|
@ -249,10 +305,9 @@ async fn paste<const N: usize>(
|
|||
}
|
||||
|
||||
let mut map = HeaderMap::new();
|
||||
if let Some(expiration) = parsed.expiration {
|
||||
map.insert(EXPIRES, expiration.into());
|
||||
}
|
||||
Ok((map, parsed.bytes))
|
||||
map.insert(EXPIRES, metadata.into());
|
||||
|
||||
Ok((map, paste))
|
||||
}
|
||||
|
||||
#[instrument(skip(db))]
|
||||
|
@ -260,8 +315,24 @@ async fn delete<const N: usize>(
|
|||
Extension(db): Extension<Arc<DB>>,
|
||||
Path(url): Path<ShortCode<N>>,
|
||||
) -> StatusCode {
|
||||
match task::spawn_blocking(move || db.delete(url.as_bytes())).await {
|
||||
Ok(Ok(_)) => StatusCode::OK,
|
||||
match task::spawn_blocking(move || {
|
||||
let blob_cf = db.cf_handle(BLOB_CF_NAME).unwrap();
|
||||
let meta_cf = db.cf_handle(META_CF_NAME).unwrap();
|
||||
if let Err(e) = db.delete_cf(blob_cf, url.as_bytes()) {
|
||||
warn!("{}", e);
|
||||
return Err(());
|
||||
}
|
||||
|
||||
if let Err(e) = db.delete_cf(meta_cf, url.as_bytes()) {
|
||||
warn!("{}", e);
|
||||
return Err(());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(_) => StatusCode::OK,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
use axum::body::Bytes;
|
||||
use chrono::Utc;
|
||||
use omegaupload_common::Expiration;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Paste {
|
||||
pub expiration: Option<Expiration>,
|
||||
pub bytes: Bytes,
|
||||
}
|
||||
|
||||
impl Paste {
|
||||
pub fn new(expiration: impl Into<Option<Expiration>>, bytes: Bytes) -> Self {
|
||||
Self {
|
||||
expiration: expiration.into(),
|
||||
bytes,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn expired(&self) -> bool {
|
||||
self.expiration
|
||||
.map(|expires| match expires {
|
||||
Expiration::BurnAfterReading => false,
|
||||
Expiration::UnixTime(expiration) => expiration < Utc::now(),
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub const fn is_burn_after_read(&self) -> bool {
|
||||
matches!(self.expiration, Some(Expiration::BurnAfterReading))
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue