From 93249397f178311399b2e302e196dfff59241d26 Mon Sep 17 00:00:00 2001
From: Edward Shen
Date: Sat, 10 Jul 2021 18:53:28 -0400
Subject: [PATCH] Simplify codebase

Buffer upstream responses fully instead of streaming them through the
cache: `Cache::put` now takes `bytes::Bytes` and returns `()`, which lets
us delete the `WRITING_STATUS` relay, `ConcurrentFsStream`, and the watch
channel plumbing in `cache/fs.rs`. Deduplication of concurrent requests
for the same upstream URL moves into a new `CachingClient` in
`src/client.rs`.
---
 src/cache/disk.rs |  18 +--
 src/cache/fs.rs   | 276 +++++++---------------------------------------
 src/cache/mem.rs  |   7 +-
 src/cache/mod.rs  |  27 +++--
 src/client.rs     | 214 +++++++++++++++++++++++++++++++++++
 src/config.rs     |  10 +-
 src/main.rs       |   1 +
 src/routes.rs     | 131 +++++-----------------
 8 files changed, 311 insertions(+), 373 deletions(-)
 create mode 100644 src/client.rs

diff --git a/src/cache/disk.rs b/src/cache/disk.rs
index 643b1a6..75b0dc0 100644
--- a/src/cache/disk.rs
+++ b/src/cache/disk.rs
@@ -16,9 +16,7 @@ use tokio_stream::wrappers::ReceiverStream;
 
 use crate::units::Bytes;
 
-use super::{
-    BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
-};
+use super::{Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata};
 
 pub struct DiskCache {
     disk_path: PathBuf,
@@ -210,9 +208,9 @@ impl Cache for DiskCache {
     async fn put(
         &self,
         key: CacheKey,
-        image: BoxedImageStream,
+        image: bytes::Bytes,
         metadata: ImageMetadata,
-    ) -> Result<CacheStream, CacheError> {
+    ) -> Result<(), CacheError> {
         let channel = self.db_update_channel_sender.clone();
 
         let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@@ -225,9 +223,6 @@ impl Cache for DiskCache {
         super::fs::write_file(&path, key, image, metadata, db_callback, None)
             .await
             .map_err(CacheError::from)
-            .and_then(|(inner, maybe_header)| {
-                CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
-            })
     }
 }
 
@@ -236,10 +231,10 @@ impl CallbackCache for DiskCache {
     async fn put_with_on_completed_callback(
         &self,
         key: CacheKey,
-        image: BoxedImageStream,
+        image: bytes::Bytes,
         metadata: ImageMetadata,
         on_complete: Sender<(CacheKey, bytes::Bytes, ImageMetadata, u64)>,
-    ) -> Result<CacheStream, CacheError> {
+    ) -> Result<(), CacheError> {
         let channel = self.db_update_channel_sender.clone();
 
         let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@@ -253,8 +248,5 @@ impl CallbackCache for DiskCache {
         super::fs::write_file(&path, key, image, metadata, db_callback, Some(on_complete))
             .await
             .map_err(CacheError::from)
-            .and_then(|(inner, maybe_header)| {
-                CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
-            })
     }
 }
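Editor's note: the disk.rs hunks above change `Cache::put` from consuming a boxed byte stream to consuming an already-buffered `bytes::Bytes`, returning nothing. A minimal, self-contained sketch of that calling convention follows; the `Key` alias, `NoopCache`, and the use of `std::io::Error` are illustrative stand-ins, not items from this patch:

```rust
use async_trait::async_trait;
use bytes::Bytes;

// Stand-in for the patch's CacheKey.
type Key = String;

#[async_trait]
trait Cache: Send + Sync {
    // After the patch, `put` consumes the full image as `Bytes` and returns
    // no stream; callers already hold the bytes they need.
    async fn put(&self, key: Key, image: Bytes) -> Result<(), std::io::Error>;
}

struct NoopCache;

#[async_trait]
impl Cache for NoopCache {
    async fn put(&self, key: Key, image: Bytes) -> Result<(), std::io::Error> {
        println!("caching {} ({} bytes)", key, image.len());
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    let cache = NoopCache;
    // The whole body is buffered up front, then handed to the cache.
    cache.put("image.png".into(), Bytes::from_static(b"...")).await
}
```

The trade-off is memory for simplicity: the old streaming `put` never held a full image in RAM, while the buffered form lets the fs layer drop all of its concurrent-reader machinery, as the fs.rs diff below shows.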
diff --git a/src/cache/fs.rs b/src/cache/fs.rs
index 13bacce..c06cfd4 100644
--- a/src/cache/fs.rs
+++ b/src/cache/fs.rs
@@ -14,35 +14,26 @@
 //! upstream no longer needs to process duplicate requests and sequential cache
 //! misses are treated as closer to a cache hit.
 
-use std::collections::HashMap;
 use std::error::Error;
 use std::fmt::Display;
-use std::io::SeekFrom;
-use std::num::NonZeroU64;
-use std::path::{Path, PathBuf};
+use std::path::Path;
 use std::pin::Pin;
 use std::task::{Context, Poll};
 
 use actix_web::error::PayloadError;
-use bytes::{Buf, Bytes, BytesMut};
-use futures::{Future, Stream, StreamExt};
+use bytes::Bytes;
+use futures::Future;
 use log::{debug, warn};
-use once_cell::sync::Lazy;
 use serde::{Deserialize, Serialize};
 use sodiumoxide::crypto::secretstream::{
     Header, Pull, Push, Stream as SecretStream, Tag, HEADERBYTES,
 };
 use tokio::fs::{create_dir_all, remove_file, File};
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt, ReadBuf};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
 use tokio::sync::mpsc::Sender;
-use tokio::sync::watch::{channel, Receiver};
-use tokio::sync::RwLock;
-use tokio_stream::wrappers::WatchStream;
 use tokio_util::codec::{BytesCodec, FramedRead};
 
-use super::{
-    BoxedImageStream, CacheKey, CacheStreamItem, ImageMetadata, InnerStream, ENCRYPTION_KEY,
-};
+use super::{CacheKey, ImageMetadata, InnerStream, ENCRYPTION_KEY};
 
 #[derive(Serialize, Deserialize)]
 pub enum OnDiskMetadata {
@@ -50,25 +41,6 @@ pub enum OnDiskMetadata {
     Plaintext(ImageMetadata),
 }
 
-/// Keeps track of files that are currently being written to.
-///
-/// Why is this necessary? Consider the following situation:
-///
-/// Client A requests file `foo.png`. We construct a transparent file stream,
-/// and now the file is being streamed into and from.
-///
-/// Client B requests the same file `foo.png`. A naive implementation would
-/// attempt to either read directly the file as it sees the file existing. This
-/// is problematic as the file could still be written to. If Client B catches
-/// up to Client A's request, then Client B could receive a broken image, as it
-/// thinks it's done reading the file.
-///
-/// We effectively use `WRITING_STATUS` as a status relay to ensure concurrent
-/// reads to the file while it's being written to will wait for writing to be
-/// completed.
-static WRITING_STATUS: Lazy<RwLock<HashMap<PathBuf, Receiver<WritingStatus>>>> =
-    Lazy::new(|| RwLock::new(HashMap::new()));
-
 /// Attempts to lookup the file on disk, returning a byte stream if it exists.
 /// Note that this could return two types of streams, depending on if the file
 /// is in progress of being written to.
@@ -100,7 +72,7 @@ pub(super) async fn read_file(
     let mut file = File::from_std(file_0);
     let file_0 = file.try_clone().await.unwrap();
 
-    // image is decrypted or corrupt
+    // image is encrypted or corrupt
 
     // If the encryption key was set, use the encrypted disk reader instead;
     // else, just directly read from file.
@@ -142,20 +114,7 @@ pub(super) async fn read_file(
     // successfully decoded the data; otherwise the file is garbage.
 
     if let Some(reader) = reader {
-        // False positive lint, `file` is used in both cases, which means that it's
-        // not possible to move this into a map_or_else without cloning `file`.
-        #[allow(clippy::option_if_let_else)]
-        let stream = if let Some(status) = WRITING_STATUS.read().await.get(path).map(Clone::clone) {
-            debug!("Got an in-progress stream");
-            InnerStream::Concurrent(ConcurrentFsStream::from_reader(
-                reader,
-                WatchStream::new(status),
-            ))
-        } else {
-            debug!("Got a completed stream");
-            InnerStream::Completed(FramedRead::new(reader, BytesCodec::new()))
-        };
-
+        let stream = InnerStream::Completed(FramedRead::new(reader, BytesCodec::new()));
         parsed_metadata.map(|metadata| Ok((stream, maybe_header, metadata)))
     } else {
         debug!("Reader was invalid, file is corrupt");
@@ -229,23 +188,19 @@ impl AsyncRead for EncryptedDiskReader {
 pub(super) async fn write_file<Fut, DbCallback>(
     path: &Path,
     cache_key: CacheKey,
-    mut byte_stream: BoxedImageStream,
+    bytes: Bytes,
     metadata: ImageMetadata,
     db_callback: DbCallback,
     on_complete: Option<Sender<(CacheKey, Bytes, ImageMetadata, u64)>>,
-) -> Result<(InnerStream, Option<Header>), std::io::Error>
+) -> Result<(), std::io::Error>
 where
     Fut: 'static + Send + Sync + Future<Output = ()>,
     DbCallback: 'static + Send + Sync + FnOnce(u64) -> Fut,
 {
-    let (tx, rx) = channel(WritingStatus::Writing(0));
-
     let file = {
-        let mut write_lock = WRITING_STATUS.write().await;
         let parent = path.parent().expect("The path to have a parent");
         create_dir_all(parent).await?;
         let file = File::create(path).await?; // we need to make sure the file exists and is truncated.
-        write_lock.insert(path.to_path_buf(), rx.clone());
         file
     };
 
@@ -262,88 +217,43 @@ where
         (Box::pin(file), None)
     };
 
-    // need owned variant because async lifetime
-    let path_buf = path.to_path_buf();
-    tokio::spawn(async move {
-        let path_buf = path_buf; // moves path buf into async
-        let mut errored = false;
-        let mut bytes_written: u64 = 0;
-        let mut acc_bytes = BytesMut::new();
-        let accumulate = on_complete.is_some();
-        writer.write_all(metadata_string.as_bytes()).await?;
+    let mut error = if let Some(header) = maybe_header {
+        writer.write_all(header.as_ref()).await.err()
+    } else {
+        None
+    };
 
-        while let Some(bytes) = byte_stream.next().await {
-            if let Ok(mut bytes) = bytes {
-                if accumulate {
-                    acc_bytes.extend(&bytes);
-                }
+    if error.is_none() {
+        error = writer.write_all(metadata_string.as_bytes()).await.err();
+    }
+    if error.is_none() {
+        error = writer.write_all(&bytes).await.err();
+    }
 
-                loop {
-                    match writer.write(&bytes).await? {
-                        0 => break,
-                        n => {
-                            bytes.advance(n);
-                            bytes_written += n as u64;
+    if let Some(e) = error {
+        // It's ok if deleting the file fails, since we truncate on create
+        // anyway, but it should be best effort.
+        //
+        // We don't care about the result of the call.
+        std::mem::drop(remove_file(path).await);
+        return Err(e);
+    }
 
-                            // We don't really care if we have no receivers
-                            let _ = tx.send(WritingStatus::Writing(n as u64));
-                        }
-                    }
-                }
-            } else {
-                errored = true;
-                break;
-            }
-        }
+    writer.flush().await?;
+    debug!("writing to file done");
 
-        if errored {
-            // It's ok if the deleting the file fails, since we truncate on
-            // create anyways, but it should be best effort.
-            //
-            // We don't care about the result of the call.
-            std::mem::drop(remove_file(&path_buf).await);
-        } else {
-            writer.flush().await?;
-            debug!("writing to file done");
-        }
+    let bytes_written = (metadata_size + bytes.len()) as u64;
+    tokio::spawn(db_callback(bytes_written));
 
-        {
-            let mut write_lock = WRITING_STATUS.write().await;
-            // This needs to be written atomically with the write lock, else
-            // it's possible we have an inconsistent state
-            //
-            // We don't really care if we have no receivers
-            if errored {
-                let _ = tx.send(WritingStatus::Error);
-            }
-            // Explicitly drop it here since we're done with sending values.
-            // This is ok since we have a stream adapter on the other end.
-            // We must drop it here in the critical section, hence the explicit
-            // drop.
-            std::mem::drop(tx);
-            write_lock.remove(&path_buf);
-        }
+    if let Some(sender) = on_complete {
+        tokio::spawn(async move {
+            sender
+                .send((cache_key, bytes, metadata, bytes_written))
+                .await
+        });
+    }
 
-        tokio::spawn(db_callback(bytes_written));
-
-        if let Some(sender) = on_complete {
-            tokio::spawn(async move {
-                sender
-                    .send((cache_key, acc_bytes.freeze(), metadata, bytes_written))
-                    .await
-            });
-        }
-
-        // We don't ever check this, so the return value doesn't matter
-        Ok::<_, std::io::Error>(())
-    });
-
-    Ok((
-        InnerStream::Concurrent(
-            ConcurrentFsStream::new(path, metadata_size, WatchStream::new(rx)).await?,
-        ),
-        maybe_header,
-    ))
+    Ok(())
 }
 
 struct EncryptedDiskWriter {
@@ -432,43 +342,6 @@ impl AsyncWrite for EncryptedDiskWriter {
     }
 }
 
-pub struct ConcurrentFsStream {
-    /// The File to read from
-    reader: Pin<Box<dyn AsyncRead + Send>>,
-    /// The channel to get updates from. The writer must send its status, else
-    /// this reader will never complete.
-    receiver: Option<Pin<Box<WatchStream<WritingStatus>>>>,
-    /// The number of bytes the reader has read
-    bytes_read: u64,
-    /// The number of bytes that the writer has reported it has written. If the
-    /// writer has not reported yet, this value is None.
-    bytes_total: Option<NonZeroU64>,
-}
-
-impl ConcurrentFsStream {
-    async fn new(
-        path: &Path,
-        seek: usize,
-        receiver: WatchStream<WritingStatus>,
-    ) -> Result<Self, std::io::Error> {
-        let mut file = File::open(path).await?;
-        file.seek(SeekFrom::Start(seek as u64)).await?;
-        Ok(Self::from_reader(Box::pin(file), receiver))
-    }
-
-    fn from_reader(
-        reader: Pin<Box<dyn AsyncRead + Send>>,
-        receiver: WatchStream<WritingStatus>,
-    ) -> Self {
-        Self {
-            reader: Box::pin(reader),
-            receiver: Some(Box::pin(receiver)),
-            bytes_read: 0,
-            bytes_total: None,
-        }
-    }
-}
-
 /// Represents some upstream error.
 #[derive(Debug)]
 pub struct UpstreamError;
@@ -481,78 +354,9 @@ impl Display for UpstreamError {
     }
 }
 
-impl Stream for ConcurrentFsStream {
-    type Item = CacheStreamItem;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.receiver.as_mut().map(|v| v.poll_next_unpin(cx)) {
-            Some(Poll::Ready(Some(WritingStatus::Writing(n)))) => match self.bytes_total.as_mut() {
-                Some(v) => *v = unsafe { NonZeroU64::new_unchecked(v.get() + n) },
-                None => self.bytes_total = unsafe { Some(NonZeroU64::new_unchecked(n)) },
-            },
-            Some(Poll::Ready(Some(WritingStatus::Error))) => {
-                // Upstream errored, abort reading
-                return Poll::Ready(Some(Err(UpstreamError)));
-            }
-            Some(Poll::Ready(None)) => {
-                // Take the receiver so that we can't poll it again
-                self.receiver.take();
-            }
-            Some(Poll::Pending) | None => (),
-        }
-
-        // We are entirely done if the bytes total equals the bytes read
-        if Some(self.bytes_read) == self.bytes_total.map(NonZeroU64::get) {
-            return Poll::Ready(None);
-        }
-
-        // We're not done, so try reading from the file.
-
-        // TODO: Might be more efficient to have a larger buffer
-        let mut bytes = [0; 4 * 1024].to_vec();
-        let mut buffer = ReadBuf::new(&mut bytes);
-        match self.reader.as_mut().poll_read(cx, &mut buffer) {
-            Poll::Ready(Ok(_)) => (),
-            Poll::Ready(Err(_)) => return Poll::Ready(Some(Err(UpstreamError))),
-            Poll::Pending => return Poll::Pending,
-        }
-
-        // At this point, we know that we "successfully" read some amount of
-        // data. Let's see if there's actual data in there...
-
-        let filled = buffer.filled().len();
-        if filled == 0 {
-            // We haven't read enough bytes, but we know there's more to read,
-            // so just return an empty bytes and have the executor request some
-            // bytes some time in the future.
-            //
-            // This case might be solved by io_uring, but for now this is
-            // the best we can do.
-            Poll::Ready(Some(Ok(Bytes::new())))
-        } else {
-            // We have data! Give it to the reader!
-            self.bytes_read += filled as u64;
-            bytes.truncate(filled);
-            Poll::Ready(Some(Ok(bytes.into())))
-        }
-    }
-}
-
 impl From<UpstreamError> for actix_web::Error {
     #[inline]
     fn from(_: UpstreamError) -> Self {
         PayloadError::Incomplete(None).into()
     }
 }
-
-#[derive(Debug, Clone, Copy)]
-enum WritingStatus {
-    Writing(u64),
-    Error,
-}
-
-#[cfg(test)]
-mod storage {
-    #[test]
-    fn wut() {}
-}
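Editor's note: `write_file` above persists the secretstream `Header` at the very front of the file, and `read_file` re-parses it before decrypting. For context on that header handling, here is an independent round-trip sketch of the sodiumoxide `secretstream` API that the `EncryptedDiskWriter`/`EncryptedDiskReader` pair relies on; this is illustrative, not code from the patch:

```rust
use sodiumoxide::crypto::secretstream::{gen_key, Stream, Tag};

fn main() {
    sodiumoxide::init().expect("sodiumoxide init");
    let key = gen_key();

    // Writer side: init_push yields the stream state plus a one-time header.
    // write_file stores this header at the start of the cache file.
    let (mut push_stream, header) = Stream::init_push(&key).expect("init_push");
    let ciphertext = push_stream
        .push(b"image bytes", None, Tag::Message)
        .expect("encrypt");

    // Reader side: the header read back from disk re-initializes the pull
    // state; without it (or with the wrong key) decryption fails.
    let mut pull_stream = Stream::init_pull(&header, &key).expect("init_pull");
    let (plaintext, _tag) = pull_stream.pull(&ciphertext, None).expect("decrypt");
    assert_eq!(plaintext, b"image bytes");
}
```

This is why the old `write_file` returned `Option<Header>` alongside the stream: the caller needed it to build a decrypting `CacheStream`. With writes fully buffered, only readers ever need the header, so the return value disappears.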
diff --git a/src/cache/mem.rs b/src/cache/mem.rs
index f9f5701..1a7060a 100644
--- a/src/cache/mem.rs
+++ b/src/cache/mem.rs
@@ -2,8 +2,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
 
 use super::{
-    BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
-    InnerStream, MemStream,
+    Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata, InnerStream, MemStream,
 };
 use async_trait::async_trait;
 use bytes::Bytes;
@@ -164,9 +163,9 @@ where
     async fn put(
         &self,
         key: CacheKey,
-        image: BoxedImageStream,
+        image: Bytes,
         metadata: ImageMetadata,
-    ) -> Result<CacheStream, super::CacheError> {
+    ) -> Result<(), super::CacheError> {
         self.inner
             .put_with_on_completed_callback(key, image, metadata, self.master_sender.clone())
             .await
diff --git a/src/cache/mod.rs b/src/cache/mod.rs
index e52298e..baf1604 100644
--- a/src/cache/mod.rs
+++ b/src/cache/mod.rs
@@ -9,7 +9,6 @@ use actix_web::http::HeaderValue;
 use async_trait::async_trait;
 use bytes::{Bytes, BytesMut};
 use chrono::{DateTime, FixedOffset};
-use fs::ConcurrentFsStream;
 use futures::{Stream, StreamExt};
 use once_cell::sync::OnceCell;
 use serde::{Deserialize, Serialize};
@@ -148,8 +147,6 @@ impl ImageMetadata {
     }
 }
 
-type BoxedImageStream = Box<dyn Stream<Item = Result<Bytes, CacheError>> + Unpin + Send>;
-
 #[derive(Error, Debug)]
 pub enum CacheError {
     #[error(transparent)]
@@ -170,9 +167,9 @@ pub trait Cache: Send + Sync {
     async fn put(
         &self,
         key: CacheKey,
-        image: BoxedImageStream,
+        image: Bytes,
         metadata: ImageMetadata,
-    ) -> Result<CacheStream, CacheError>;
+    ) -> Result<(), CacheError>;
 }
 
 #[async_trait]
@@ -189,9 +186,9 @@ impl<T: Cache> Cache for Arc<T> {
     async fn put(
         &self,
         key: CacheKey,
-        image: BoxedImageStream,
+        image: Bytes,
         metadata: ImageMetadata,
-    ) -> Result<CacheStream, CacheError> {
+    ) -> Result<(), CacheError> {
         self.as_ref().put(key, image, metadata).await
     }
 }
@@ -201,10 +198,10 @@ pub trait CallbackCache: Cache {
     async fn put_with_on_completed_callback(
         &self,
         key: CacheKey,
-        image: BoxedImageStream,
+        image: Bytes,
         metadata: ImageMetadata,
         on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
-    ) -> Result<CacheStream, CacheError>;
+    ) -> Result<(), CacheError>;
 }
 
 #[async_trait]
@@ -213,10 +210,10 @@ impl<T: CallbackCache> CallbackCache for Arc<T> {
     async fn put_with_on_completed_callback(
         &self,
         key: CacheKey,
-        image: BoxedImageStream,
+        image: Bytes,
         metadata: ImageMetadata,
         on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
-    ) -> Result<CacheStream, CacheError> {
+    ) -> Result<(), CacheError> {
         self.as_ref()
             .put_with_on_completed_callback(key, image, metadata, on_complete)
             .await
@@ -233,7 +230,11 @@ impl CacheStream {
         Ok(Self {
             inner,
             decrypt: header
-                .and_then(|header| ENCRYPTION_KEY.get().map(|key| SecretStream::init_pull(&header, key)))
+                .and_then(|header| {
+                    ENCRYPTION_KEY
+                        .get()
+                        .map(|key| SecretStream::init_pull(&header, key))
+                })
                 .transpose()?,
         })
     }
@@ -263,7 +264,6 @@ impl Stream for CacheStream {
 }
 
 pub(self) enum InnerStream {
-    Concurrent(ConcurrentFsStream),
     Memory(MemStream),
     Completed(FramedRead<Pin<Box<dyn AsyncRead + Send>>, BytesCodec>),
 }
@@ -281,7 +281,6 @@ impl Stream for InnerStream {
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         match self.get_mut() {
-            Self::Concurrent(stream) => stream.poll_next_unpin(cx),
             Self::Memory(stream) => stream.poll_next_unpin(cx),
             Self::Completed(stream) => stream
                 .poll_next_unpin(cx)
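Editor's note: the new `src/client.rs` below coalesces concurrent fetches of the same URL: the first caller registers a `watch` receiver in a shared map and performs the request, while later callers find the receiver and await the broadcast result. A minimal, self-contained sketch of that pattern, assuming the tokio and parking_lot crates; `Coalescer`, `Slot`, `coalesce`, and `wait_for` are illustrative names, not items from the patch:

```rust
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use parking_lot::RwLock;
use tokio::sync::watch;

#[derive(Clone, Debug)]
enum Slot {
    Processing,
    Done(String),
}

#[derive(Default)]
struct Coalescer {
    in_flight: RwLock<HashMap<String, watch::Receiver<Slot>>>,
}

impl Coalescer {
    async fn coalesce(self: Arc<Self>, key: String) -> String {
        // Clone the receiver out so the read guard drops before any await.
        let existing = self.in_flight.read().get(&key).cloned();
        if let Some(rx) = existing {
            return wait_for(rx).await;
        }

        // Benign race: two callers can both miss the map and both fetch.
        let (tx, rx) = watch::channel(Slot::Processing);
        self.in_flight.write().insert(key.clone(), rx);

        tokio::time::sleep(Duration::from_millis(50)).await; // stand-in fetch
        let value = format!("value for {}", key);

        // Publish to any waiters, then retire the in-flight entry.
        let _ = tx.send(Slot::Done(value.clone()));
        self.in_flight.write().remove(&key);
        value
    }
}

async fn wait_for(mut rx: watch::Receiver<Slot>) -> String {
    loop {
        if let Slot::Done(v) = &*rx.borrow() {
            return v.clone();
        }
        if rx.changed().await.is_err() {
            return String::new(); // sender dropped without publishing
        }
    }
}

#[tokio::main]
async fn main() {
    let c = Arc::new(Coalescer::default());
    let a = tokio::spawn(Arc::clone(&c).coalesce("k".into()));
    let b = tokio::spawn(c.coalesce("k".into()));
    println!("{} / {}", a.await.unwrap(), b.await.unwrap());
}
```

The patch has the same check-then-insert window; its `Notify` only guarantees that the spawned fetch task has registered the receiver before its own caller looks it up, and the sketch simply tolerates the occasional duplicate fetch instead.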
diff --git a/src/client.rs b/src/client.rs
new file mode 100644
index 0000000..9e35d7d
--- /dev/null
+++ b/src/client.rs
@@ -0,0 +1,214 @@
+use std::{collections::HashMap, sync::Arc, time::Duration};
+
+use actix_web::{
+    http::{HeaderMap, HeaderName, HeaderValue},
+    web::Data,
+};
+use bytes::Bytes;
+use log::{debug, error, warn};
+use once_cell::sync::Lazy;
+use parking_lot::RwLock;
+use reqwest::{
+    header::{
+        ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH,
+        CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS,
+    },
+    Client, StatusCode,
+};
+use tokio::sync::{
+    watch::{channel, Receiver},
+    Notify,
+};
+
+use crate::cache::{Cache, CacheKey, ImageMetadata};
+
+pub static HTTP_CLIENT: Lazy<CachingClient> = Lazy::new(|| CachingClient {
+    inner: Client::builder()
+        .pool_idle_timeout(Duration::from_secs(180))
+        .https_only(true)
+        .http2_prior_knowledge()
+        .build()
+        .expect("Client initialization to work"),
+    locks: RwLock::new(HashMap::new()),
+});
+
+pub struct CachingClient {
+    inner: Client,
+    locks: RwLock<HashMap<String, Receiver<FetchResult>>>,
+}
+
+#[derive(Clone, Debug)]
+pub enum FetchResult {
+    ServiceUnavailable,
+    InternalServerError,
+    Data(StatusCode, HeaderMap, Bytes),
+    Processing,
+}
+
+impl CachingClient {
+    pub async fn fetch_and_cache(
+        &'static self,
+        url: String,
+        key: CacheKey,
+        cache: Data<dyn Cache>,
+    ) -> FetchResult {
+        if let Some(recv) = self.locks.read().get(&url) {
+            let mut recv = recv.clone();
+            loop {
+                if !matches!(*recv.borrow(), FetchResult::Processing) {
+                    break;
+                }
+                if recv.changed().await.is_err() {
+                    break;
+                }
+            }
+
+            return recv.borrow().clone();
+        }
+        let url_0 = url.clone();
+
+        let notify = Arc::new(Notify::new());
+        let notify2 = Arc::clone(&notify);
+
+        tokio::spawn(async move {
+            let (tx, rx) = channel(FetchResult::Processing);
+
+            self.locks.write().insert(url.clone(), rx);
+            notify.notify_one();
+            let resp = self.inner.get(&url).send().await;
+
+            let resp = match resp {
+                Ok(mut resp) => {
+                    let content_type = resp.headers().get(CONTENT_TYPE);
+
+                    let is_image = content_type
+                        .map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
+                        .unwrap_or_default();
+
+                    if resp.status() != StatusCode::OK || !is_image {
+                        warn!("Got non-OK or non-image response code from upstream, proxying and not caching result.");
+
+                        let mut headers = HeaderMap::new();
+
+                        if let Some(content_type) = content_type {
+                            headers.insert(CONTENT_TYPE, content_type.clone());
+                        }
+
+                        headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
+                        headers.insert(
+                            ACCESS_CONTROL_ALLOW_ORIGIN,
+                            HeaderValue::from_static("https://mangadex.org"),
+                        );
+                        headers
+                            .insert(ACCESS_CONTROL_EXPOSE_HEADERS, HeaderValue::from_static("*"));
+                        headers.insert(
+                            CACHE_CONTROL,
+                            HeaderValue::from_static("public, max-age=1209600"),
+                        );
+                        headers.insert(
+                            HeaderName::from_static("timing-allow-origin"),
+                            HeaderValue::from_static("https://mangadex.org"),
+                        );
+
+                        FetchResult::Data(
+                            resp.status(),
+                            headers,
+                            resp.bytes().await.unwrap_or_default(),
+                        )
+                    } else {
+                        let (content_type, length, last_mod) = {
+                            let headers = resp.headers_mut();
+                            (
+                                headers.remove(CONTENT_TYPE),
+                                headers.remove(CONTENT_LENGTH),
+                                headers.remove(LAST_MODIFIED),
+                            )
+                        };
+
+                        let body = resp.bytes().await.unwrap();
+
+                        debug!("Inserting into cache");
+
+                        let metadata = ImageMetadata::new(
+                            content_type.clone(),
+                            length.clone(),
+                            last_mod.clone(),
+                        )
+                        .unwrap();
+
+                        match cache.put(key, body.clone(), metadata).await {
+                            Ok(()) => {
+                                debug!("Done putting into cache");
+
+                                let mut headers = HeaderMap::new();
+                                if let Some(content_type) = content_type {
+                                    headers.insert(CONTENT_TYPE, content_type);
+                                }
+
+                                if let Some(content_length) = length {
+                                    headers.insert(CONTENT_LENGTH, content_length);
+                                }
+
+                                if let Some(last_modified) = last_mod {
+                                    headers.insert(LAST_MODIFIED, last_modified);
+                                }
+
+                                headers.insert(
+                                    X_CONTENT_TYPE_OPTIONS,
+                                    HeaderValue::from_static("nosniff"),
+                                );
+                                headers.insert(
+                                    ACCESS_CONTROL_ALLOW_ORIGIN,
+                                    HeaderValue::from_static("https://mangadex.org"),
+                                );
+                                headers.insert(
+                                    ACCESS_CONTROL_EXPOSE_HEADERS,
+                                    HeaderValue::from_static("*"),
+                                );
+                                headers.insert(
+                                    CACHE_CONTROL,
+                                    HeaderValue::from_static("public, max-age=1209600"),
+                                );
+                                headers.insert(
+                                    HeaderName::from_static("timing-allow-origin"),
+                                    HeaderValue::from_static("https://mangadex.org"),
+                                );
+                                FetchResult::Data(StatusCode::OK, headers, body)
+                            }
+                            Err(e) => {
+                                warn!("Failed to insert into cache: {}", e);
+                                FetchResult::InternalServerError
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    error!("Failed to fetch image from server: {}", e);
+                    FetchResult::ServiceUnavailable
+                }
+            };
+            // This send can't fail: the map still holds a receiver until after
+            // this point.
+            tx.send(resp).unwrap();
+            self.locks.write().remove(&url);
+        });
+
+        notify2.notified().await;
+
+        let mut recv = self.locks.read().get(&url_0).unwrap().clone();
+        loop {
+            if !matches!(*recv.borrow(), FetchResult::Processing) {
+                break;
+            }
+            if recv.changed().await.is_err() {
+                break;
+            }
+        }
+        let resp = recv.borrow().clone();
+        resp
+    }
+
+    #[inline]
+    pub const fn inner(&self) -> &Client {
+        &self.inner
+    }
+}
diff --git a/src/config.rs b/src/config.rs
index 0ec0962..313db08 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -155,10 +155,10 @@ impl Config {
                 .server_settings
                 .external_ip
                 .map(|ip_addr| SocketAddr::new(ip_addr, external_port)),
-            ephemeral_disk_encryption: cli_args
-                .ephemeral_disk_encryption
-                .or(file_extended_options.ephemeral_disk_encryption)
-                .unwrap_or_default(),
+            ephemeral_disk_encryption: cli_args.ephemeral_disk_encryption
+                || file_extended_options
+                    .ephemeral_disk_encryption
+                    .unwrap_or_default(),
             network_speed: cli_args
                 .network_speed
                 .unwrap_or(file_args.server_settings.external_max_kilobits_per_second),
@@ -284,7 +284,7 @@ struct CliArgs {
     /// encrypted with a key generated at runtime. There are implications to
     /// performance, privacy, and usability with this flag enabled.
     #[clap(short, long)]
-    pub ephemeral_disk_encryption: Option<bool>,
+    pub ephemeral_disk_encryption: bool,
     #[clap(short, long)]
     pub config_path: Option<PathBuf>,
     #[clap(short = 't', long)]
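Editor's note: the config.rs hunks above turn `ephemeral_disk_encryption` from a tri-state `Option<bool>` CLI flag into a plain switch merged with the optional file setting. A tiny sketch of the resulting precedence; `merge_flag` is an illustrative helper, not a function in the patch:

```rust
// The CLI switch wins when set; otherwise the optional config-file value
// applies, defaulting to off.
fn merge_flag(cli: bool, file: Option<bool>) -> bool {
    cli || file.unwrap_or_default()
}

fn main() {
    assert!(merge_flag(true, None)); // enabled on the command line
    assert!(merge_flag(false, Some(true))); // enabled by the config file
    assert!(!merge_flag(false, Some(false))); // disabled everywhere
}
```

One consequence of the bare `bool`: the command line can enable the feature but can no longer explicitly disable one the file turned on, which the old `Option<bool>` form could express.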
diff --git a/src/main.rs b/src/main.rs
index 02e3320..c54f0f4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -30,6 +30,7 @@ use crate::config::{CacheType, UnstableOptions, OFFLINE_MODE};
 use crate::state::DynamicServerCert;
 
 mod cache;
+mod client;
 mod config;
 mod metrics;
 mod ping;
diff --git a/src/routes.rs b/src/routes.rs
index ee00a5b..52d07e1 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -1,5 +1,4 @@
 use std::sync::atomic::Ordering;
-use std::time::Duration;
 
 use actix_web::error::ErrorNotFound;
 use actix_web::http::header::{
@@ -12,16 +11,15 @@ use actix_web::{get, web::Data, HttpRequest, HttpResponse, Responder};
 use base64::DecodeError;
 use bytes::Bytes;
 use chrono::{DateTime, Utc};
-use futures::{Stream, TryStreamExt};
-use log::{debug, error, info, trace, warn};
-use once_cell::sync::Lazy;
+use futures::Stream;
+use log::{debug, error, info, trace};
 use prometheus::{Encoder, TextEncoder};
-use reqwest::{Client, StatusCode};
 use serde::Deserialize;
 use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES};
 use thiserror::Error;
 
 use crate::cache::{Cache, CacheKey, ImageMetadata, UpstreamError};
+use crate::client::{FetchResult, HTTP_CLIENT};
 use crate::config::{OFFLINE_MODE, VALIDATE_TOKENS};
 use crate::metrics::{
     CACHE_HIT_COUNTER, CACHE_MISS_COUNTER, REQUESTS_DATA_COUNTER, REQUESTS_DATA_SAVER_COUNTER,
@@ -31,16 +29,7 @@ use crate::state::RwLockServerState;
 
 const BASE64_CONFIG: base64::Config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
 
-static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
-    Client::builder()
-        .pool_idle_timeout(Duration::from_secs(180))
-        .https_only(true)
-        .http2_prior_knowledge()
-        .build()
-        .expect("Client initialization to work")
-});
-
-enum ServerResponse {
+pub enum ServerResponse {
     TokenValidationError(TokenValidationError),
     HttpResponse(HttpResponse),
 }
@@ -116,7 +105,7 @@ pub async fn default(state: Data<RwLockServerState>, req: HttpRequest) -> impl R
 
     info!("Got unknown path, just proxying: {}", path);
 
-    let resp = match HTTP_CLIENT.get(path).send().await {
+    let resp = match HTTP_CLIENT.inner().get(path).send().await {
         Ok(resp) => resp,
         Err(e) => {
             error!("{}", e);
@@ -145,7 +134,7 @@ pub async fn metrics() -> impl Responder {
 }
 
 #[derive(Error, Debug)]
-enum TokenValidationError {
+pub enum TokenValidationError {
     #[error("Failed to decode base64 token.")]
     DecodeError(#[from] DecodeError),
     #[error("Nonce was too short.")]
@@ -208,7 +197,7 @@ fn validate_token(
 }
 
 #[inline]
-fn push_headers(builder: &mut HttpResponseBuilder) -> &mut HttpResponseBuilder {
+pub fn push_headers(builder: &mut HttpResponseBuilder) -> &mut HttpResponseBuilder {
     builder
         .insert_header((X_CONTENT_TYPE_OPTIONS, "nosniff"))
         .insert_header((ACCESS_CONTROL_ALLOW_ORIGIN, "https://mangadex.org"))
@@ -237,7 +226,7 @@ async fn fetch_image(
         Some(Err(_)) => {
             return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish());
         }
-        _ => (),
+        None => (),
     }
 
     CACHE_MISS_COUNTER.inc();
 
@@ -249,95 +238,35 @@ async fn fetch_image(
     if OFFLINE_MODE.load(Ordering::Acquire) {
         return ServerResponse::HttpResponse(
             ErrorNotFound("Offline mode enabled and image not in cache").into(),
         );
     }
 
-    // It's important to not get a write lock before this request, else we're
-    // holding the read lock until the await resolves.
-
-    let resp = if is_data_saver {
-        HTTP_CLIENT
-            .get(format!(
-                "{}/data-saver/{}/{}",
-                state.0.read().image_server,
-                &key.0,
-                &key.1
-            ))
-            .send()
+    let url = if is_data_saver {
+        format!(
+            "{}/data-saver/{}/{}",
+            state.0.read().image_server,
+            &key.0,
+            &key.1,
+        )
     } else {
-        HTTP_CLIENT
-            .get(format!(
-                "{}/data/{}/{}",
-                state.0.read().image_server,
-                &key.0,
-                &key.1
-            ))
-            .send()
-    }
-    .await;
+        format!("{}/data/{}/{}", state.0.read().image_server, &key.0, &key.1)
+    };
 
-    match resp {
-        Ok(mut resp) => {
-            let content_type = resp.headers().get(CONTENT_TYPE);
-
-            let is_image = content_type
-                .map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
-                .unwrap_or_default();
-
-            if resp.status() != StatusCode::OK || !is_image {
-                warn!(
-                    "Got non-OK or non-image response code from upstream, proxying and not caching result.",
-                );
-
-                let mut resp_builder = HttpResponseBuilder::new(resp.status());
-                if let Some(content_type) = content_type {
-                    resp_builder.insert_header((CONTENT_TYPE, content_type));
-                }
-
-                push_headers(&mut resp_builder);
-
-                return ServerResponse::HttpResponse(
-                    resp_builder.body(resp.bytes().await.unwrap_or_default()),
-                );
-            }
-
-            let (content_type, length, last_mod) = {
-                let headers = resp.headers_mut();
-                (
-                    headers.remove(CONTENT_TYPE),
-                    headers.remove(CONTENT_LENGTH),
-                    headers.remove(LAST_MODIFIED),
-                )
-            };
-
-            let body = resp.bytes_stream().map_err(|e| e.into());
-
-            debug!("Inserting into cache");
-
-            let metadata = ImageMetadata::new(content_type, length, last_mod).unwrap();
-            let stream = {
-                match cache.put(key, Box::new(body), metadata).await {
-                    Ok(stream) => stream,
-                    Err(e) => {
-                        warn!("Failed to insert into cache: {}", e);
-                        return ServerResponse::HttpResponse(
-                            HttpResponse::InternalServerError().finish(),
-                        );
-                    }
-                }
-            };
-
-            debug!("Done putting into cache");
-
-            construct_response(stream, &metadata)
+    match HTTP_CLIENT.fetch_and_cache(url, key, cache).await {
+        FetchResult::ServiceUnavailable => {
+            ServerResponse::HttpResponse(HttpResponse::ServiceUnavailable().finish())
         }
-        Err(e) => {
-            error!("Failed to fetch image from server: {}", e);
-            ServerResponse::HttpResponse(
-                push_headers(&mut HttpResponse::ServiceUnavailable()).finish(),
-            )
+        FetchResult::InternalServerError => {
+            ServerResponse::HttpResponse(HttpResponse::InternalServerError().finish())
         }
+        FetchResult::Data(status, headers, data) => {
+            let mut resp = HttpResponseBuilder::new(status);
+            let mut resp = resp.body(data);
+            *resp.headers_mut() = headers;
+            ServerResponse::HttpResponse(resp)
+        }
+        FetchResult::Processing => panic!("Race condition found with fetch result"),
     }
 }
 
-fn construct_response(
+pub fn construct_response(
     data: impl Stream<Item = Result<Bytes, UpstreamError>> + Unpin + 'static,
     metadata: &ImageMetadata,
 ) -> ServerResponse {