diff --git a/src/cache/fs.rs b/src/cache/fs.rs index e60de4d..87922fa 100644 --- a/src/cache/fs.rs +++ b/src/cache/fs.rs @@ -1,20 +1,20 @@ -use std::collections::HashMap; +use bytes::BytesMut; +use futures::{Future, Stream, StreamExt}; +use once_cell::sync::Lazy; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; - -use bytes::{Bytes, BytesMut}; -use futures::{Future, Stream, StreamExt}; -use once_cell::sync::Lazy; -use reqwest::Error; +use std::{collections::HashMap, fmt::Display}; use tokio::fs::{remove_file, File}; use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf}; use tokio::sync::RwLock; use tokio::time::Sleep; +use super::{BoxedImageStream, CacheStreamItem}; + /// Keeps track of files that are currently being written to. /// /// Why is this necessary? Consider the following situation: @@ -35,7 +35,7 @@ static WRITING_STATUS: Lazy>>> = Lazy::new(|| RwLock::new(HashMap::new())); /// Tries to read from the file, returning a byte stream if it exists -pub async fn read_file(path: &Path) -> Option> { +pub async fn read_file(path: &Path) -> Option> { if path.exists() { let status = WRITING_STATUS .read() @@ -43,7 +43,7 @@ pub async fn read_file(path: &Path) -> Option Option> + Unpin + Send + 'static, -) -> Result { + mut byte_stream: BoxedImageStream, +) -> Result { let done_writing_flag = Arc::new(CacheStatus::new()); let mut file = { @@ -102,16 +102,16 @@ pub async fn write_file( Ok::<_, std::io::Error>(()) }); - Ok(FromFsStream::new(path, done_writing_flag).await?) + Ok(FsStream::new(path, done_writing_flag).await?) } -pub struct FromFsStream { +pub struct FsStream { file: Pin>, sleep: Pin>, is_file_done_writing: Arc, } -impl FromFsStream { +impl FsStream { async fn new(path: &Path, is_done: Arc) -> Result { Ok(Self { file: Box::pin(File::open(path).await?), @@ -123,10 +123,19 @@ impl FromFsStream { } /// Represents some upstream error. +#[derive(Debug)] pub struct UpstreamError; -impl Stream for FromFsStream { - type Item = Result; +impl std::error::Error for UpstreamError {} + +impl Display for UpstreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "An upstream error occurred") + } +} + +impl Stream for FsStream { + type Item = CacheStreamItem; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let status = self.is_file_done_writing.load(); @@ -148,6 +157,12 @@ impl Stream for FromFsStream { } } +impl From for actix_web::Error { + fn from(_: UpstreamError) -> Self { + todo!() + } +} + struct CacheStatus(AtomicU8); impl CacheStatus { diff --git a/src/cache/generational.rs b/src/cache/generational.rs index aaf917f..5630fd8 100644 --- a/src/cache/generational.rs +++ b/src/cache/generational.rs @@ -2,12 +2,15 @@ use std::path::PathBuf; use async_trait::async_trait; use bytes::Bytes; +use futures::{stream::StreamExt, TryStreamExt}; use log::{debug, warn}; use lru::LruCache; use tokio::fs::{remove_file, File}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use super::{Cache, CacheKey, CachedImage, ImageMetadata}; +use super::{ + BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CachedImage, ImageMetadata, +}; pub struct GenerationalCache { in_memory: LruCache, @@ -132,9 +135,16 @@ impl GenerationalCache { #[async_trait] impl Cache for GenerationalCache { - async fn get(&mut self, key: &CacheKey) -> Option<&(CachedImage, ImageMetadata)> { + async fn get( + &mut self, + key: &CacheKey, + ) -> Option> { if self.in_memory.contains(key) { - return self.in_memory.get(key); + return self + .in_memory + .get(key) + // TODO: get rid of clone? + .map(|(image, metadata)| Ok((CacheStream::from(image.clone()), metadata))); } if let Some(metadata) = self.on_disk.pop(key) { @@ -149,7 +159,7 @@ impl Cache for GenerationalCache { let mut buffer = metadata .content_length - .map_or_else(Vec::new, Vec::with_capacity); + .map_or_else(Vec::new, |v| Vec::with_capacity(v as usize)); match file { Ok(mut file) => { @@ -173,20 +183,30 @@ impl Cache for GenerationalCache { buffer.shrink_to_fit(); self.disk_cur_size -= buffer.len() as u64; - let image = CachedImage(Bytes::from(buffer)); + let image = CacheStream::from(CachedImage(Bytes::from(buffer))).map_err(|e| e.into()); - // Since we just put it in the in-memory cache it should be there - // when we retrieve it - self.put(key.clone(), image, metadata).await; - return self.get(key).await; + return Some(self.put(key.clone(), Box::new(image), metadata).await); } None } - #[inline] - async fn put(&mut self, key: CacheKey, image: CachedImage, metadata: ImageMetadata) { + async fn put( + &mut self, + key: CacheKey, + mut image: BoxedImageStream, + metadata: ImageMetadata, + ) -> Result<(CacheStream, &ImageMetadata), CacheError> { let mut hot_evicted = vec![]; + + let image = { + let mut resolved = vec![]; + while let Some(bytes) = image.next().await { + resolved.extend(bytes?); + } + CachedImage(Bytes::from(resolved)) + }; + let new_img_size = image.0.len() as u64; if self.memory_max_size >= new_img_size { @@ -204,17 +224,19 @@ impl Cache for GenerationalCache { } } - self.in_memory.put(key, (image, metadata)); + self.in_memory.put(key.clone(), (image, metadata)); self.memory_cur_size += new_img_size; } else { // Image was larger than memory capacity, push directly into cold // storage. - self.push_into_cold(key, image, metadata).await; + self.push_into_cold(key.clone(), image, metadata).await; }; // Push evicted hot entires into cold storage. for (key, image, metadata) in hot_evicted { self.push_into_cold(key, image, metadata).await; } + + self.get(&key).await.unwrap() } } diff --git a/src/cache/low_mem.rs b/src/cache/low_mem.rs index d3ff8a4..59b9162 100644 --- a/src/cache/low_mem.rs +++ b/src/cache/low_mem.rs @@ -1,16 +1,14 @@ //! Low memory caching stuff -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use async_trait::async_trait; -use bytes::Bytes; -use futures::Stream; use lru::LruCache; -use super::{fs::FromFsStream, ByteStream, Cache, CacheKey}; +use super::{BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, ImageMetadata}; pub struct LowMemCache { - on_disk: LruCache, + on_disk: LruCache, disk_path: PathBuf, disk_max_size: u64, disk_cur_size: u64, @@ -27,21 +25,37 @@ impl LowMemCache { } } +// todo: schedule eviction + #[async_trait] impl Cache for LowMemCache { - async fn get_stream(&mut self, key: &CacheKey) -> Option> { - if self.on_disk.get(key).is_some() { - super::fs::read_file(Path::new(&key.to_string())).await + async fn get( + &mut self, + key: &CacheKey, + ) -> Option> { + if let Some(metadata) = self.on_disk.get(key) { + let path = self.disk_path.clone().join(PathBuf::from(key.clone())); + super::fs::read_file(&path).await.map(|res| { + res.map(|stream| (CacheStream::Fs(stream), metadata)) + .map_err(Into::into) + }) } else { None } } - async fn put_stream( + async fn put( &mut self, key: CacheKey, - image: ByteStream, - ) -> Result { - super::fs::write_file(&PathBuf::from(key), image).await + image: BoxedImageStream, + metadata: ImageMetadata, + ) -> Result<(CacheStream, &ImageMetadata), CacheError> { + let path = self.disk_path.clone().join(PathBuf::from(key.clone())); + self.on_disk.put(key.clone(), metadata); + super::fs::write_file(&path, image) + .await + .map(CacheStream::Fs) + .map(move |stream| (stream, self.on_disk.get(&key).unwrap())) + .map_err(Into::into) } } diff --git a/src/cache/mod.rs b/src/cache/mod.rs index 55819b0..5055d60 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -1,17 +1,21 @@ +use std::fmt::Display; use std::path::PathBuf; -use std::{fmt::Display, str::FromStr}; +use std::pin::Pin; +use std::str::FromStr; +use std::task::{Context, Poll}; use actix_web::http::HeaderValue; use async_trait::async_trait; use bytes::Bytes; use chrono::{DateTime, FixedOffset}; -use futures::Stream; +use fs::FsStream; +use futures::{Stream, StreamExt}; +use thiserror::Error; +pub use fs::UpstreamError; pub use generational::GenerationalCache; pub use low_mem::LowMemCache; -use self::fs::FromFsStream; - mod fs; mod generational; mod low_mem; @@ -36,23 +40,23 @@ impl From for PathBuf { } } +#[derive(Clone)] pub struct CachedImage(pub Bytes); #[derive(Copy, Clone)] pub struct ImageMetadata { pub content_type: Option, - pub content_length: Option, + // If we can guarantee a non-zero u32 here we can save 4 bytes + pub content_length: Option, pub last_modified: Option>, } -// Note to self: If these are wrong blame Triscuit 9 +// Confirmed by Ply to be these types: https://link.eddie.sh/ZXfk0 #[derive(Copy, Clone)] pub enum ImageContentType { Png, Jpeg, Gif, - Bmp, - Tif, } pub struct InvalidContentType; @@ -66,8 +70,6 @@ impl FromStr for ImageContentType { "image/png" => Ok(Self::Png), "image/jpeg" => Ok(Self::Jpeg), "image/gif" => Ok(Self::Gif), - "image/bmp" => Ok(Self::Bmp), - "image/tif" => Ok(Self::Tif), _ => Err(InvalidContentType), } } @@ -80,8 +82,6 @@ impl AsRef for ImageContentType { Self::Png => "image/png", Self::Jpeg => "image/jpeg", Self::Gif => "image/gif", - Self::Bmp => "image/bmp", - Self::Tif => "image/tif", } } } @@ -130,41 +130,78 @@ impl ImageMetadata { } } +type BoxedImageStream = Box> + Unpin + Send>; + +#[derive(Error, Debug)] +pub enum CacheError { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + #[error(transparent)] + Upstream(#[from] UpstreamError), +} + #[async_trait] pub trait Cache: Send + Sync { - async fn get(&mut self, _key: &CacheKey) -> Option<&(CachedImage, ImageMetadata)> { - unimplemented!() - } - - async fn put(&mut self, _key: CacheKey, _image: CachedImage, _metadata: ImageMetadata) { - unimplemented!() - } - - async fn get_stream( + async fn get( &mut self, - _key: &CacheKey, - ) -> Option> { - unimplemented!() - } - - async fn put_stream( + key: &CacheKey, + ) -> Option>; + async fn put( &mut self, - _key: CacheKey, - _image: ByteStream, - ) -> Result { - unimplemented!() + key: CacheKey, + image: BoxedImageStream, + metadata: ImageMetadata, + ) -> Result<(CacheStream, &ImageMetadata), CacheError>; +} + +pub enum CacheStream { + Fs(FsStream), + Memory(MemStream), +} + +impl From for CacheStream { + fn from(image: CachedImage) -> Self { + Self::Memory(MemStream(image.0)) } } -pub enum ByteStream {} +type CacheStreamItem = Result; -impl Stream for ByteStream { - type Item = Result; +impl Stream for CacheStream { + type Item = CacheStreamItem; - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - todo!() + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Fs(stream) => stream.poll_next_unpin(cx), + Self::Memory(stream) => stream.poll_next_unpin(cx), + } + } +} + +pub struct MemStream(Bytes); + +impl Stream for MemStream { + type Item = CacheStreamItem; + + fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let mut new_bytes = Bytes::new(); + std::mem::swap(&mut self.0, &mut new_bytes); + if new_bytes.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(new_bytes))) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn metadata_size() { + assert_eq!(std::mem::size_of::(), 32); } } diff --git a/src/routes.rs b/src/routes.rs index 5f46beb..37bb7ae 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,4 +1,3 @@ -use std::convert::Infallible; use std::sync::atomic::Ordering; use actix_web::dev::HttpResponseBuilder; @@ -11,14 +10,14 @@ use actix_web::{get, web::Data, HttpRequest, HttpResponse, Responder}; use base64::DecodeError; use bytes::Bytes; use chrono::{DateTime, Utc}; -use futures::stream; +use futures::{Stream, TryStreamExt}; use log::{error, info, warn}; use parking_lot::Mutex; use serde::Deserialize; use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES}; use thiserror::Error; -use crate::cache::{Cache, CacheKey, CachedImage, ImageMetadata}; +use crate::cache::{Cache, CacheKey, ImageMetadata, UpstreamError}; use crate::client_api_version; use crate::config::{SEND_SERVER_VERSION, VALIDATE_TOKENS}; use crate::state::RwLockServerState; @@ -182,8 +181,10 @@ async fn fetch_image( ) -> ServerResponse { let key = CacheKey(chapter_hash, file_name, is_data_saver); - if let Some((image, metadata)) = cache.lock().get(&key).await { - return construct_response(image, metadata); + match cache.lock().get(&key).await { + Some(Ok((image, metadata))) => return construct_response(image, metadata), + Some(Err(_)) => return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish()), + _ => (), } // It's important to not get a write lock before this request, else we're @@ -238,22 +239,17 @@ async fn fetch_image( headers.remove(LAST_MODIFIED), ) }; - let body = resp.bytes().await; - match body { - Ok(bytes) => { - let cached = ImageMetadata::new(content_type, length, last_mod).unwrap(); - let image = CachedImage(bytes); - let resp = construct_response(&image, &cached); - cache.lock().put(key, image, cached).await; - return resp; + + let body = resp.bytes_stream().map_err(|e| e.into()); + let metadata = ImageMetadata::new(content_type, length, last_mod).unwrap(); + let (stream, metadata) = { + match cache.lock().put(key, Box::new(body), metadata).await { + Ok((stream, metadata)) => (stream, *metadata), + Err(_) => todo!(), } - Err(e) => { - warn!("Got payload error from image server: {}", e); - ServerResponse::HttpResponse( - push_headers(&mut HttpResponse::ServiceUnavailable()).finish(), - ) - } - } + }; + + return construct_response(stream, &metadata); } Err(e) => { error!("Failed to fetch image from server: {}", e); @@ -264,23 +260,22 @@ async fn fetch_image( } } -fn construct_response(cached: &CachedImage, metadata: &ImageMetadata) -> ServerResponse { - let data: Vec> = cached - .0 - .to_vec() - .chunks(1460) // TCP MSS default size - .map(|v| Ok(Bytes::from(v.to_vec()))) - .collect(); +fn construct_response( + data: impl Stream> + Unpin + 'static, + metadata: &ImageMetadata, +) -> ServerResponse { let mut resp = HttpResponse::Ok(); - if let Some(content_type) = &metadata.content_type { + if let Some(content_type) = metadata.content_type { resp.append_header((CONTENT_TYPE, content_type.as_ref())); } - if let Some(content_length) = &metadata.content_length { - resp.append_header((CONTENT_LENGTH, content_length.to_string())); + + if let Some(content_length) = metadata.content_length { + resp.append_header((CONTENT_LENGTH, content_length)); } - if let Some(last_modified) = &metadata.last_modified { + + if let Some(last_modified) = metadata.last_modified { resp.append_header((LAST_MODIFIED, last_modified.to_rfc2822())); } - ServerResponse::HttpResponse(push_headers(&mut resp).streaming(stream::iter(data))) + ServerResponse::HttpResponse(push_headers(&mut resp).streaming(data)) }