diff --git a/src/cache/fs.rs b/src/cache/fs.rs
index 1335eb5..efc8038 100644
--- a/src/cache/fs.rs
+++ b/src/cache/fs.rs
@@ -1,5 +1,5 @@
 use actix_web::error::PayloadError;
-use bytes::{Buf, Bytes};
+use bytes::{Buf, Bytes, BytesMut};
 use futures::{Future, Stream, StreamExt};
 use log::debug;
 use once_cell::sync::Lazy;
@@ -13,12 +13,13 @@ use std::pin::Pin;
 use std::task::{Context, Poll};
 use tokio::fs::{create_dir_all, remove_file, File};
 use tokio::io::{AsyncRead, AsyncSeekExt, AsyncWriteExt, BufReader, ReadBuf};
+use tokio::sync::mpsc::Sender;
 use tokio::sync::watch::{channel, Receiver};
 use tokio::sync::RwLock;
 use tokio_stream::wrappers::WatchStream;
 use tokio_util::codec::{BytesCodec, FramedRead};

-use super::{BoxedImageStream, CacheStream, CacheStreamItem, ImageMetadata};
+use super::{BoxedImageStream, CacheKey, CacheStream, CacheStreamItem, ImageMetadata};

 /// Keeps track of files that are currently being written to.
 ///
@@ -70,12 +71,14 @@ pub async fn read_file(
 /// a stream that reads from disk instead.
 pub async fn write_file<
     Fut: 'static + Send + Sync + Future<Output = ()>,
-    F: 'static + Send + Sync + FnOnce(u32) -> Fut,
+    DbCallback: 'static + Send + Sync + FnOnce(u32) -> Fut,
 >(
     path: &Path,
+    cache_key: CacheKey,
     mut byte_stream: BoxedImageStream,
     metadata: ImageMetadata,
-    db_callback: F,
+    db_callback: DbCallback,
+    on_complete: Option<Sender<(CacheKey, Bytes, ImageMetadata, usize)>>,
 ) -> Result<CacheStream, std::io::Error> {
     let (tx, rx) = channel(WritingStatus::NotDone);

@@ -88,17 +91,24 @@
         file
     };

-    let metadata = serde_json::to_string(&metadata).unwrap();
-    let metadata_size = metadata.len();
+    let metadata_string = serde_json::to_string(&metadata).unwrap();
+    let metadata_size = metadata_string.len();
     // need owned variant because async lifetime
     let path_buf = path.to_path_buf();
     tokio::spawn(async move {
         let path_buf = path_buf; // moves path buf into async
         let mut errored = false;
         let mut bytes_written: u32 = 0;
-        file.write_all(&metadata.as_bytes()).await?;
+        let mut acc_bytes = BytesMut::new();
+        let accumulate = on_complete.is_some();
+        file.write_all(metadata_string.as_bytes()).await?;
+
         while let Some(bytes) = byte_stream.next().await {
             if let Ok(mut bytes) = bytes {
+                if accumulate {
+                    acc_bytes.extend(&bytes);
+                }
+
                 loop {
                     match file.write(&bytes).await? {
                         0 => break,
@@ -141,6 +151,19 @@
         }

         tokio::spawn(db_callback(bytes_written));
+        if accumulate {
+            tokio::spawn(async move {
+                let sender = on_complete.unwrap();
+                sender
+                    .send((
+                        cache_key,
+                        acc_bytes.freeze(),
+                        metadata,
+                        bytes_written as usize,
+                    ))
+                    .await
+            });
+        }

         // We don't ever check this, so the return value doesn't matter
         Ok::<_, std::io::Error>(())
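Note on the new `on_complete` parameter: once the final byte hits disk, `write_file` now reports the accumulated bytes over an mpsc channel instead of forcing the caller to re-read the file. Below is a minimal, standalone sketch of how a caller might wire that channel up; the `CacheKey`/`ImageMetadata` stand-ins, the simulated writer task, and the consumer loop are illustrative assumptions, not part of this diff.

```rust
use bytes::Bytes;
use tokio::sync::mpsc;

// Stand-ins for the real types in src/cache/mod.rs.
#[derive(Debug)]
struct CacheKey(String, String, bool);
#[derive(Debug)]
struct ImageMetadata;

#[tokio::main]
async fn main() {
    // Bounded channel carrying the same tuple `write_file` sends on completion.
    let (tx, mut rx) = mpsc::channel::<(CacheKey, Bytes, ImageMetadata, usize)>(32);

    // A real caller would pass `Some(tx.clone())` as `on_complete`; here a task
    // stands in for the disk write and sends the completion message itself.
    let writer = tokio::spawn(async move {
        let body = Bytes::from_static(b"fake image bytes");
        let len = body.len();
        let _ = tx
            .send((CacheKey("abcd".into(), "1.png".into(), false), body, ImageMetadata, len))
            .await;
    });

    // Consumer side: this is where a memory layer could promote the finished entry.
    while let Some((key, bytes, _metadata, size)) = rx.recv().await {
        println!("{:?}: wrote {} bytes, accumulated {}", key, size, bytes.len());
    }

    writer.await.unwrap();
}
```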
diff --git a/src/cache/low_mem.rs b/src/cache/low_mem.rs
deleted file mode 100644
index 7a7ebf6..0000000
--- a/src/cache/low_mem.rs
+++ /dev/null
@@ -1,185 +0,0 @@
-//! Low memory caching stuff
-
-use std::path::PathBuf;
-use std::str::FromStr;
-use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::Arc;
-
-use async_trait::async_trait;
-use futures::StreamExt;
-use log::{warn, LevelFilter};
-use sqlx::sqlite::SqliteConnectOptions;
-use sqlx::{ConnectOptions, SqlitePool};
-use tokio::sync::mpsc::{channel, Sender};
-use tokio::{fs::remove_file, sync::mpsc::Receiver};
-use tokio_stream::wrappers::ReceiverStream;
-
-use super::{BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, ImageMetadata};
-
-pub struct LowMemCache {
-    disk_path: PathBuf,
-    disk_cur_size: AtomicU64,
-    db_update_channel_sender: Sender<DbMessage>,
-}
-
-enum DbMessage {
-    Get(Arc<PathBuf>),
-    Put(Arc<PathBuf>, u32),
-}
-
-impl LowMemCache {
-    /// Constructs a new low memory cache at the provided path and capacity.
-    /// This internally spawns a task that will wait for filesystem
-    /// notifications when a file has been written.
-    #[allow(clippy::new_ret_no_self)]
-    pub async fn new(disk_max_size: u64, disk_path: PathBuf) -> Arc<Box<dyn Cache>> {
-        let (db_tx, db_rx) = channel(128);
-        let db_pool = {
-            let db_url = format!("sqlite:{}/metadata.sqlite", disk_path.to_str().unwrap());
-            let mut options = SqliteConnectOptions::from_str(&db_url)
-                .unwrap()
-                .create_if_missing(true);
-            options.log_statements(LevelFilter::Trace);
-            let db = SqlitePool::connect_with(options).await.unwrap();
-
-            // Run db init
-            sqlx::query_file!("./db_queries/init.sql")
-                .execute(&mut db.acquire().await.unwrap())
-                .await
-                .unwrap();
-
-            db
-        };
-
-        let new_self: Arc<Box<dyn Cache>> = Arc::new(Box::new(Self {
-            disk_path,
-            disk_cur_size: AtomicU64::new(0),
-            db_update_channel_sender: db_tx,
-        }));
-
-        tokio::spawn(db_listener(
-            Arc::clone(&new_self),
-            db_rx,
-            db_pool,
-            disk_max_size / 20 * 19,
-        ));
-
-        new_self
-    }
-}
-
-/// Spawn a new task that will listen for updates to the db, pruning if the size
-/// becomes too large.
-async fn db_listener(
-    cache: Arc<Box<dyn Cache>>,
-    db_rx: Receiver<DbMessage>,
-    db_pool: SqlitePool,
-    max_on_disk_size: u64,
-) {
-    let mut recv_stream = ReceiverStream::new(db_rx).ready_chunks(128);
-    while let Some(messages) = recv_stream.next().await {
-        let now = chrono::Utc::now();
-        let mut transaction = db_pool.begin().await.unwrap();
-        for message in messages {
-            match message {
-                DbMessage::Get(entry) => {
-                    let key = entry.as_os_str().to_str();
-                    let query =
-                        sqlx::query!("update Images set accessed = ? where id = ?", now, key)
-                            .execute(&mut transaction)
-                            .await;
-                    if let Err(e) = query {
-                        warn!("Failed to update timestamp in db for {:?}: {}", key, e);
-                    }
-                }
-                DbMessage::Put(entry, size) => {
-                    let key = entry.as_os_str().to_str();
-                    let query = sqlx::query!(
-                        "insert into Images (id, size, accessed) values (?, ?, ?)",
-                        key,
-                        size,
-                        now,
-                    )
-                    .execute(&mut transaction)
-                    .await;
-                    if let Err(e) = query {
-                        warn!("Failed to add {:?} to db: {}", key, e);
-                    }
-
-                    cache.increase_usage(size);
-                }
-            }
-        }
-        transaction.commit().await.unwrap();
-
-        if cache.on_disk_size() >= max_on_disk_size {
-            let mut conn = db_pool.acquire().await.unwrap();
-            let items =
-                sqlx::query!("select id, size from Images order by accessed asc limit 1000")
-                    .fetch_all(&mut conn)
-                    .await
-                    .unwrap();
-
-            let mut size_freed = 0;
-            for item in items {
-                size_freed += item.size as u64;
-                tokio::spawn(remove_file(item.id));
-            }
-
-            cache.decrease_usage(size_freed);
-        }
-    }
-}
-
-#[async_trait]
-impl Cache for LowMemCache {
-    async fn get(
-        &self,
-        key: &CacheKey,
-    ) -> Option<Result<(CacheStream, ImageMetadata), CacheError>> {
-        let channel = self.db_update_channel_sender.clone();
-
-        let path = Arc::new(self.disk_path.clone().join(PathBuf::from(key)));
-        let path_0 = Arc::clone(&path);
-
-        tokio::spawn(async move { channel.send(DbMessage::Get(path_0)).await });
-
-        super::fs::read_file(&path)
-            .await
-            .map(|res| res.map_err(Into::into))
-    }
-
-    async fn put(
-        &self,
-        key: CacheKey,
-        image: BoxedImageStream,
-        metadata: ImageMetadata,
-    ) -> Result<CacheStream, CacheError> {
-        let channel = self.db_update_channel_sender.clone();
-
-        let path = Arc::new(self.disk_path.clone().join(PathBuf::from(key)));
-        let path_0 = Arc::clone(&path);
-
-        let db_callback = |size: u32| async move {
-            let _ = channel.send(DbMessage::Put(path_0, size)).await;
-        };
-
-        super::fs::write_file(&path, image, metadata, db_callback)
-            .await
-            .map_err(Into::into)
-    }
-
-    #[inline]
-    fn increase_usage(&self, amt: u32) {
-        self.disk_cur_size.fetch_add(amt as u64, Ordering::Release);
-    }
-
-    #[inline]
-    fn on_disk_size(&self) -> u64 {
-        (self.disk_cur_size.load(Ordering::Acquire) + 4095) / 4096 * 4096
-    }
-
-    fn decrease_usage(&self, amt: u64) {
-        self.disk_cur_size.fetch_sub(amt, Ordering::Release);
-    }
-}
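For context, the deleted `db_listener` (which, per the new `mod disk_cache;` re-export below, presumably lives on under `disk_cache`) drains its mpsc receiver with `ready_chunks(128)` so that a burst of Get/Put notifications becomes a single SQLite transaction. A standalone sketch of that batching pattern follows; the `DbMessage` variants holding plain `String`s and the `println!` bodies are stand-ins for the real messages and queries.

```rust
use futures::StreamExt;
use tokio::sync::mpsc::channel;
use tokio_stream::wrappers::ReceiverStream;

// Illustrative stand-in for the deleted DbMessage enum.
#[derive(Debug)]
enum DbMessage {
    Get(String),
    Put(String, u32),
}

#[tokio::main]
async fn main() {
    let (tx, rx) = channel(128);

    // Simulate a burst of cache activity; the sender drops when this task ends,
    // which terminates the stream below.
    tokio::spawn(async move {
        for i in 0..5u32 {
            let _ = tx.send(DbMessage::Put(format!("image-{}", i), i * 1024)).await;
        }
        let _ = tx.send(DbMessage::Get("image-0".into())).await;
    });

    // `ready_chunks(128)` yields everything already queued, up to 128 messages,
    // instead of waking once per message.
    let mut recv_stream = ReceiverStream::new(rx).ready_chunks(128);
    while let Some(messages) = recv_stream.next().await {
        // In the real listener each chunk becomes one SQLite transaction.
        println!("flushing {} messages in one batch", messages.len());
        for message in messages {
            println!("  {:?}", message);
        }
    }
}
```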
diff --git a/src/cache/mod.rs b/src/cache/mod.rs
index cb65ab7..7fc5e09 100644
--- a/src/cache/mod.rs
+++ b/src/cache/mod.rs
@@ -12,14 +12,18 @@ use fs::ConcurrentFsStream;
 use futures::{Stream, StreamExt};
 use serde::{Deserialize, Serialize};
 use thiserror::Error;
-use tokio::{fs::File, io::BufReader};
+use tokio::fs::File;
+use tokio::io::BufReader;
+use tokio::sync::mpsc::Sender;
 use tokio_util::codec::{BytesCodec, FramedRead};

+pub use disk_cache::LowMemCache;
 pub use fs::UpstreamError;
-pub use low_mem::LowMemCache;
+pub use mem_cache::MemoryLruCache;

+mod disk_cache;
 mod fs;
-mod low_mem;
+mod mem_cache;

 #[derive(PartialEq, Eq, Hash, Clone)]
 pub struct CacheKey(pub String, pub String, pub bool);
@@ -61,7 +65,7 @@ pub struct ImageMetadata {
 // Confirmed by Ply to be these types: https://link.eddie.sh/ZXfk0
 #[derive(Copy, Clone, Serialize, Deserialize)]
 pub enum ImageContentType {
-    Png,
+    Png = 0,
     Jpeg,
     Gif,
 }
@@ -167,6 +171,20 @@ pub trait Cache: Send + Sync {
     fn decrease_usage(&self, amt: u64);

     fn on_disk_size(&self) -> u64;
+
+    fn mem_size(&self) -> u64;
+
+    async fn put_with_on_completed_callback(
+        &self,
+        key: CacheKey,
+        image: BoxedImageStream,
+        metadata: ImageMetadata,
+        on_complete: Sender<(CacheKey, Bytes, ImageMetadata, usize)>,
+    ) -> Result<CacheStream, CacheError>;
+
+    async fn put_internal(&self, key: CacheKey, image: Bytes, metadata: ImageMetadata, size: usize);
+
+    async fn pop_memory(&self) -> Option<(CacheKey, Bytes, ImageMetadata, usize)>;
 }

 pub enum CacheStream {
@@ -198,7 +216,7 @@ impl Stream for CacheStream {
     }
 }

-pub struct MemStream(Bytes);
+pub struct MemStream(pub Bytes);

 impl Stream for MemStream {
     type Item = CacheStreamItem;
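The three new trait methods outline a two-tier flow: `put_with_on_completed_callback` streams to disk and reports completion, `put_internal` copies the finished entry into memory, and `pop_memory` lets the caller evict when the memory layer is over budget. Below is a self-contained sketch of that promote-then-evict loop; the `VecDeque` memory layer, the 16 MiB budget, and the simulated writes are assumptions for illustration — only the `(CacheKey, Bytes, ImageMetadata, usize)` tuple shape comes from this diff.

```rust
use std::collections::VecDeque;

use bytes::Bytes;
use tokio::sync::mpsc::channel;

// Stand-ins for the real types in src/cache/mod.rs.
#[derive(Debug)]
struct CacheKey(String, String, bool);
struct ImageMetadata;

#[tokio::main]
async fn main() {
    const MEM_BUDGET: usize = 16 * 1024 * 1024; // hypothetical memory_max_size

    let (on_complete, mut completions) = channel::<(CacheKey, Bytes, ImageMetadata, usize)>(32);

    // Simulate a few finished disk writes (normally sent from fs::write_file).
    tokio::spawn(async move {
        for i in 0..3u32 {
            let bytes = Bytes::from(vec![0u8; 6 * 1024 * 1024]);
            let key = CacheKey("agr".into(), format!("{:04}", i), false);
            let size = bytes.len();
            let _ = on_complete.send((key, bytes, ImageMetadata, size)).await;
        }
    });

    // Stand-in memory layer: front = least recently used.
    let mut lru: VecDeque<(CacheKey, Bytes, ImageMetadata, usize)> = VecDeque::new();
    let mut mem_size = 0usize;

    while let Some((key, bytes, metadata, size)) = completions.recv().await {
        // put_internal: record the entry and grow the usage counter.
        mem_size += size;
        lru.push_back((key, bytes, metadata, size));

        // pop_memory: drop LRU entries until we are back under budget.
        while mem_size > MEM_BUDGET {
            if let Some((evicted, _, _, evicted_size)) = lru.pop_front() {
                mem_size -= evicted_size;
                println!("evicted {:?} ({} bytes)", evicted, evicted_size);
            } else {
                break;
            }
        }
    }
}
```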
diff --git a/src/main.rs b/src/main.rs
index b5a7515..9e3954a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -23,6 +23,8 @@ use state::{RwLockServerState, ServerState};
 use stop::send_stop;
 use thiserror::Error;

+use crate::cache::MemoryLruCache;
+
 mod cache;
 mod config;
 mod ping;
@@ -123,17 +125,12 @@ async fn main() -> Result<(), std::io::Error> {
         }
     });

-    // let cache: Arc<Box<dyn Cache>> = if low_mem_mode {
-    //     LowMemCache::new(disk_quota, cache_path.clone()).await
-    // } else {
-    //     Arc::new(Box::new(GenerationalCache::new(
-    //         memory_max_size,
-    //         disk_quota,
-    //         cache_path.clone(),
-    //     )))
-    // };
+    let cache: Arc<Box<dyn Cache>> = if low_mem_mode {
+        LowMemCache::new(disk_quota, cache_path.clone()).await
+    } else {
+        MemoryLruCache::new(disk_quota, cache_path.clone(), memory_max_size).await
+    };

-    let cache = LowMemCache::new(disk_quota, cache_path.clone()).await;
     let cache_0 = Arc::clone(&cache);

     // Start HTTPS server
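Both constructors return the same trait object, so the backend selection in `main` stays a single branch. A reduced sketch of that pattern follows, with a one-method `Cache` trait and a hypothetical `--low-memory` flag standing in for `low_mem_mode`; none of these names beyond the cache types are taken from this diff.

```rust
use std::sync::Arc;

// Minimal stand-in for the real Cache trait.
trait Cache: Send + Sync {
    fn name(&self) -> &'static str;
}

struct LowMemCache;
struct MemoryLruCache;

impl Cache for LowMemCache {
    fn name(&self) -> &'static str {
        "disk-backed"
    }
}

impl Cache for MemoryLruCache {
    fn name(&self) -> &'static str {
        "memory LRU over disk"
    }
}

fn main() {
    // Hypothetical flag; the real server reads low_mem_mode from its config.
    let low_mem_mode = std::env::args().any(|arg| arg == "--low-memory");

    // Both branches coerce to the same Arc<Box<dyn Cache>>, so the rest of the
    // server is agnostic to which backend was picked.
    let cache: Arc<Box<dyn Cache>> = if low_mem_mode {
        Arc::new(Box::new(LowMemCache))
    } else {
        Arc::new(Box::new(MemoryLruCache))
    };

    let cache_0 = Arc::clone(&cache);
    println!("selected cache backend: {}", cache_0.name());
}
```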