Simply codebase

This commit is contained in:
Edward Shen 2021-07-10 18:53:28 -04:00
parent 154679967b
commit 93249397f1
Signed by: edward
GPG key ID: 19182661E818369F
8 changed files with 311 additions and 373 deletions

18
src/cache/disk.rs vendored
View file

@ -16,9 +16,7 @@ use tokio_stream::wrappers::ReceiverStream;
use crate::units::Bytes; use crate::units::Bytes;
use super::{ use super::{Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata};
BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
};
pub struct DiskCache { pub struct DiskCache {
disk_path: PathBuf, disk_path: PathBuf,
@ -210,9 +208,9 @@ impl Cache for DiskCache {
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: bytes::Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<CacheStream, CacheError> { ) -> Result<(), CacheError> {
let channel = self.db_update_channel_sender.clone(); let channel = self.db_update_channel_sender.clone();
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key))); let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@ -225,9 +223,6 @@ impl Cache for DiskCache {
super::fs::write_file(&path, key, image, metadata, db_callback, None) super::fs::write_file(&path, key, image, metadata, db_callback, None)
.await .await
.map_err(CacheError::from) .map_err(CacheError::from)
.and_then(|(inner, maybe_header)| {
CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
})
} }
} }
@ -236,10 +231,10 @@ impl CallbackCache for DiskCache {
async fn put_with_on_completed_callback( async fn put_with_on_completed_callback(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: bytes::Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
on_complete: Sender<(CacheKey, bytes::Bytes, ImageMetadata, u64)>, on_complete: Sender<(CacheKey, bytes::Bytes, ImageMetadata, u64)>,
) -> Result<CacheStream, CacheError> { ) -> Result<(), CacheError> {
let channel = self.db_update_channel_sender.clone(); let channel = self.db_update_channel_sender.clone();
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key))); let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@ -253,8 +248,5 @@ impl CallbackCache for DiskCache {
super::fs::write_file(&path, key, image, metadata, db_callback, Some(on_complete)) super::fs::write_file(&path, key, image, metadata, db_callback, Some(on_complete))
.await .await
.map_err(CacheError::from) .map_err(CacheError::from)
.and_then(|(inner, maybe_header)| {
CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
})
} }
} }

276
src/cache/fs.rs vendored
View file

@ -14,35 +14,26 @@
//! upstream no longer needs to process duplicate requests and sequential cache //! upstream no longer needs to process duplicate requests and sequential cache
//! misses are treated as closer as a cache hit. //! misses are treated as closer as a cache hit.
use std::collections::HashMap;
use std::error::Error; use std::error::Error;
use std::fmt::Display; use std::fmt::Display;
use std::io::SeekFrom; use std::path::Path;
use std::num::NonZeroU64;
use std::path::{Path, PathBuf};
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_web::error::PayloadError; use actix_web::error::PayloadError;
use bytes::{Buf, Bytes, BytesMut}; use bytes::Bytes;
use futures::{Future, Stream, StreamExt}; use futures::Future;
use log::{debug, warn}; use log::{debug, warn};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::secretstream::{ use sodiumoxide::crypto::secretstream::{
Header, Pull, Push, Stream as SecretStream, Tag, HEADERBYTES, Header, Pull, Push, Stream as SecretStream, Tag, HEADERBYTES,
}; };
use tokio::fs::{create_dir_all, remove_file, File}; use tokio::fs::{create_dir_all, remove_file, File};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tokio::sync::watch::{channel, Receiver};
use tokio::sync::RwLock;
use tokio_stream::wrappers::WatchStream;
use tokio_util::codec::{BytesCodec, FramedRead}; use tokio_util::codec::{BytesCodec, FramedRead};
use super::{ use super::{CacheKey, ImageMetadata, InnerStream, ENCRYPTION_KEY};
BoxedImageStream, CacheKey, CacheStreamItem, ImageMetadata, InnerStream, ENCRYPTION_KEY,
};
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub enum OnDiskMetadata { pub enum OnDiskMetadata {
@ -50,25 +41,6 @@ pub enum OnDiskMetadata {
Plaintext(ImageMetadata), Plaintext(ImageMetadata),
} }
/// Keeps track of files that are currently being written to.
///
/// Why is this necessary? Consider the following situation:
///
/// Client A requests file `foo.png`. We construct a transparent file stream,
/// and now the file is being streamed into and from.
///
/// Client B requests the same file `foo.png`. A naive implementation would
/// attempt to either read directly the file as it sees the file existing. This
/// is problematic as the file could still be written to. If Client B catches
/// up to Client A's request, then Client B could receive a broken image, as it
/// thinks it's done reading the file.
///
/// We effectively use `WRITING_STATUS` as a status relay to ensure concurrent
/// reads to the file while it's being written to will wait for writing to be
/// completed.
static WRITING_STATUS: Lazy<RwLock<HashMap<PathBuf, Receiver<WritingStatus>>>> =
Lazy::new(|| RwLock::new(HashMap::new()));
/// Attempts to lookup the file on disk, returning a byte stream if it exists. /// Attempts to lookup the file on disk, returning a byte stream if it exists.
/// Note that this could return two types of streams, depending on if the file /// Note that this could return two types of streams, depending on if the file
/// is in progress of being written to. /// is in progress of being written to.
@ -100,7 +72,7 @@ pub(super) async fn read_file(
let mut file = File::from_std(file_0); let mut file = File::from_std(file_0);
let file_0 = file.try_clone().await.unwrap(); let file_0 = file.try_clone().await.unwrap();
// image is decrypted or corrupt // image is encrypted or corrupt
// If the encryption key was set, use the encrypted disk reader instead; // If the encryption key was set, use the encrypted disk reader instead;
// else, just directly read from file. // else, just directly read from file.
@ -142,20 +114,7 @@ pub(super) async fn read_file(
// successfully decoded the data; otherwise the file is garbage. // successfully decoded the data; otherwise the file is garbage.
if let Some(reader) = reader { if let Some(reader) = reader {
// False positive lint, `file` is used in both cases, which means that it's let stream = InnerStream::Completed(FramedRead::new(reader, BytesCodec::new()));
// not possible to move this into a map_or_else without cloning `file`.
#[allow(clippy::option_if_let_else)]
let stream = if let Some(status) = WRITING_STATUS.read().await.get(path).map(Clone::clone) {
debug!("Got an in-progress stream");
InnerStream::Concurrent(ConcurrentFsStream::from_reader(
reader,
WatchStream::new(status),
))
} else {
debug!("Got a completed stream");
InnerStream::Completed(FramedRead::new(reader, BytesCodec::new()))
};
parsed_metadata.map(|metadata| Ok((stream, maybe_header, metadata))) parsed_metadata.map(|metadata| Ok((stream, maybe_header, metadata)))
} else { } else {
debug!("Reader was invalid, file is corrupt"); debug!("Reader was invalid, file is corrupt");
@ -229,23 +188,19 @@ impl AsyncRead for EncryptedDiskReader {
pub(super) async fn write_file<Fut, DbCallback>( pub(super) async fn write_file<Fut, DbCallback>(
path: &Path, path: &Path,
cache_key: CacheKey, cache_key: CacheKey,
mut byte_stream: BoxedImageStream, bytes: Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
db_callback: DbCallback, db_callback: DbCallback,
on_complete: Option<Sender<(CacheKey, Bytes, ImageMetadata, u64)>>, on_complete: Option<Sender<(CacheKey, Bytes, ImageMetadata, u64)>>,
) -> Result<(InnerStream, Option<Header>), std::io::Error> ) -> Result<(), std::io::Error>
where where
Fut: 'static + Send + Sync + Future<Output = ()>, Fut: 'static + Send + Sync + Future<Output = ()>,
DbCallback: 'static + Send + Sync + FnOnce(u64) -> Fut, DbCallback: 'static + Send + Sync + FnOnce(u64) -> Fut,
{ {
let (tx, rx) = channel(WritingStatus::Writing(0));
let file = { let file = {
let mut write_lock = WRITING_STATUS.write().await;
let parent = path.parent().expect("The path to have a parent"); let parent = path.parent().expect("The path to have a parent");
create_dir_all(parent).await?; create_dir_all(parent).await?;
let file = File::create(path).await?; // we need to make sure the file exists and is truncated. let file = File::create(path).await?; // we need to make sure the file exists and is truncated.
write_lock.insert(path.to_path_buf(), rx.clone());
file file
}; };
@ -262,88 +217,43 @@ where
(Box::pin(file), None) (Box::pin(file), None)
}; };
// need owned variant because async lifetime let mut error = if let Some(header) = maybe_header {
let path_buf = path.to_path_buf(); writer.write_all(header.as_ref()).await.err()
tokio::spawn(async move { } else {
let path_buf = path_buf; // moves path buf into async None
let mut errored = false; };
let mut bytes_written: u64 = 0;
let mut acc_bytes = BytesMut::new();
let accumulate = on_complete.is_some();
writer.write_all(metadata_string.as_bytes()).await?;
while let Some(bytes) = byte_stream.next().await { if error.is_none() {
if let Ok(mut bytes) = bytes { error = writer.write_all(metadata_string.as_bytes()).await.err();
if accumulate { }
acc_bytes.extend(&bytes); if error.is_none() {
} error = error.or(writer.write_all(&bytes).await.err());
}
loop { if let Some(e) = error {
match writer.write(&bytes).await? { // It's ok if the deleting the file fails, since we truncate on
0 => break, // create anyways, but it should be best effort.
n => { //
bytes.advance(n); // We don't care about the result of the call.
bytes_written += n as u64; std::mem::drop(remove_file(path).await);
return Err(e);
}
// We don't really care if we have no receivers writer.flush().await?;
let _ = tx.send(WritingStatus::Writing(n as u64)); debug!("writing to file done");
}
}
}
} else {
errored = true;
break;
}
}
if errored { let bytes_written = (metadata_size + bytes.len()) as u64;
// It's ok if the deleting the file fails, since we truncate on tokio::spawn(db_callback(bytes_written));
// create anyways, but it should be best effort.
//
// We don't care about the result of the call.
std::mem::drop(remove_file(&path_buf).await);
} else {
writer.flush().await?;
debug!("writing to file done");
}
{ if let Some(sender) = on_complete {
let mut write_lock = WRITING_STATUS.write().await; tokio::spawn(async move {
// This needs to be written atomically with the write lock, else sender
// it's possible we have an inconsistent state .send((cache_key, bytes, metadata, bytes_written))
// .await
// We don't really care if we have no receivers });
if errored { }
let _ = tx.send(WritingStatus::Error);
}
// Explicitly drop it here since we're done with sending values.
// This is ok since we have a stream adapter on the other end.
// We must drop it here in the critical section, hence the explicit
// drop.
std::mem::drop(tx);
write_lock.remove(&path_buf);
}
tokio::spawn(db_callback(bytes_written)); Ok(())
if let Some(sender) = on_complete {
tokio::spawn(async move {
sender
.send((cache_key, acc_bytes.freeze(), metadata, bytes_written))
.await
});
}
// We don't ever check this, so the return value doesn't matter
Ok::<_, std::io::Error>(())
});
Ok((
InnerStream::Concurrent(
ConcurrentFsStream::new(path, metadata_size, WatchStream::new(rx)).await?,
),
maybe_header,
))
} }
struct EncryptedDiskWriter { struct EncryptedDiskWriter {
@ -432,43 +342,6 @@ impl AsyncWrite for EncryptedDiskWriter {
} }
} }
pub struct ConcurrentFsStream {
/// The File to read from
reader: Pin<Box<dyn AsyncRead + Send>>,
/// The channel to get updates from. The writer must send its status, else
/// this reader will never complete.
receiver: Option<Pin<Box<WatchStream<WritingStatus>>>>,
/// The number of bytes the reader has read
bytes_read: u64,
/// The number of bytes that the writer has reported it has written. If the
/// writer has not reported yet, this value is None.
bytes_total: Option<NonZeroU64>,
}
impl ConcurrentFsStream {
async fn new(
path: &Path,
seek: usize,
receiver: WatchStream<WritingStatus>,
) -> Result<Self, std::io::Error> {
let mut file = File::open(path).await?;
file.seek(SeekFrom::Start(seek as u64)).await?;
Ok(Self::from_reader(Box::pin(file), receiver))
}
fn from_reader(
reader: Pin<Box<dyn AsyncRead + Send>>,
receiver: WatchStream<WritingStatus>,
) -> Self {
Self {
reader: Box::pin(reader),
receiver: Some(Box::pin(receiver)),
bytes_read: 0,
bytes_total: None,
}
}
}
/// Represents some upstream error. /// Represents some upstream error.
#[derive(Debug)] #[derive(Debug)]
pub struct UpstreamError; pub struct UpstreamError;
@ -481,78 +354,9 @@ impl Display for UpstreamError {
} }
} }
impl Stream for ConcurrentFsStream {
type Item = CacheStreamItem;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.receiver.as_mut().map(|v| v.poll_next_unpin(cx)) {
Some(Poll::Ready(Some(WritingStatus::Writing(n)))) => match self.bytes_total.as_mut() {
Some(v) => *v = unsafe { NonZeroU64::new_unchecked(v.get() + n) },
None => self.bytes_total = unsafe { Some(NonZeroU64::new_unchecked(n)) },
},
Some(Poll::Ready(Some(WritingStatus::Error))) => {
// Upstream errored, abort reading
return Poll::Ready(Some(Err(UpstreamError)));
}
Some(Poll::Ready(None)) => {
// Take the receiver so that we can't poll it again
self.receiver.take();
}
Some(Poll::Pending) | None => (),
}
// We are entirely done if the bytes total equals the bytes read
if Some(self.bytes_read) == self.bytes_total.map(NonZeroU64::get) {
return Poll::Ready(None);
}
// We're not done, so try reading from the file.
// TODO: Might be more efficient to have a larger buffer
let mut bytes = [0; 4 * 1024].to_vec();
let mut buffer = ReadBuf::new(&mut bytes);
match self.reader.as_mut().poll_read(cx, &mut buffer) {
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(_)) => return Poll::Ready(Some(Err(UpstreamError))),
Poll::Pending => return Poll::Pending,
}
// At this point, we know that we "successfully" read some amount of
// data. Let's see if there's actual data in there...
let filled = buffer.filled().len();
if filled == 0 {
// We haven't read enough bytes, but we know there's more to read,
// so just return an empty bytes and have the executor request some
// bytes some time in the future.
//
// This case might be solved by io_uring, but for now this is this
// the best we can do.
Poll::Ready(Some(Ok(Bytes::new())))
} else {
// We have data! Give it to the reader!
self.bytes_read += filled as u64;
bytes.truncate(filled);
Poll::Ready(Some(Ok(bytes.into())))
}
}
}
impl From<UpstreamError> for actix_web::Error { impl From<UpstreamError> for actix_web::Error {
#[inline] #[inline]
fn from(_: UpstreamError) -> Self { fn from(_: UpstreamError) -> Self {
PayloadError::Incomplete(None).into() PayloadError::Incomplete(None).into()
} }
} }
#[derive(Debug, Clone, Copy)]
enum WritingStatus {
Writing(u64),
Error,
}
#[cfg(test)]
mod storage {
#[test]
fn wut() {}
}

7
src/cache/mem.rs vendored
View file

@ -2,8 +2,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use super::{ use super::{
BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata, InnerStream, MemStream,
InnerStream, MemStream,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
@ -164,9 +163,9 @@ where
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<CacheStream, super::CacheError> { ) -> Result<(), super::CacheError> {
self.inner self.inner
.put_with_on_completed_callback(key, image, metadata, self.master_sender.clone()) .put_with_on_completed_callback(key, image, metadata, self.master_sender.clone())
.await .await

27
src/cache/mod.rs vendored
View file

@ -9,7 +9,6 @@ use actix_web::http::HeaderValue;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use chrono::{DateTime, FixedOffset}; use chrono::{DateTime, FixedOffset};
use fs::ConcurrentFsStream;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -148,8 +147,6 @@ impl ImageMetadata {
} }
} }
type BoxedImageStream = Box<dyn Stream<Item = Result<Bytes, CacheError>> + Unpin + Send>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum CacheError { pub enum CacheError {
#[error(transparent)] #[error(transparent)]
@ -170,9 +167,9 @@ pub trait Cache: Send + Sync {
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<CacheStream, CacheError>; ) -> Result<(), CacheError>;
} }
#[async_trait] #[async_trait]
@ -189,9 +186,9 @@ impl<T: Cache> Cache for Arc<T> {
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<CacheStream, CacheError> { ) -> Result<(), CacheError> {
self.as_ref().put(key, image, metadata).await self.as_ref().put(key, image, metadata).await
} }
} }
@ -201,10 +198,10 @@ pub trait CallbackCache: Cache {
async fn put_with_on_completed_callback( async fn put_with_on_completed_callback(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>, on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
) -> Result<CacheStream, CacheError>; ) -> Result<(), CacheError>;
} }
#[async_trait] #[async_trait]
@ -213,10 +210,10 @@ impl<T: CallbackCache> CallbackCache for Arc<T> {
async fn put_with_on_completed_callback( async fn put_with_on_completed_callback(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>, on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
) -> Result<CacheStream, CacheError> { ) -> Result<(), CacheError> {
self.as_ref() self.as_ref()
.put_with_on_completed_callback(key, image, metadata, on_complete) .put_with_on_completed_callback(key, image, metadata, on_complete)
.await .await
@ -233,7 +230,11 @@ impl CacheStream {
Ok(Self { Ok(Self {
inner, inner,
decrypt: header decrypt: header
.and_then(|header| ENCRYPTION_KEY.get().map(|key| SecretStream::init_pull(&header, key))) .and_then(|header| {
ENCRYPTION_KEY
.get()
.map(|key| SecretStream::init_pull(&header, key))
})
.transpose()?, .transpose()?,
}) })
} }
@ -263,7 +264,6 @@ impl Stream for CacheStream {
} }
pub(self) enum InnerStream { pub(self) enum InnerStream {
Concurrent(ConcurrentFsStream),
Memory(MemStream), Memory(MemStream),
Completed(FramedRead<Pin<Box<dyn AsyncRead + Send>>, BytesCodec>), Completed(FramedRead<Pin<Box<dyn AsyncRead + Send>>, BytesCodec>),
} }
@ -281,7 +281,6 @@ impl Stream for InnerStream {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.get_mut() { match self.get_mut() {
Self::Concurrent(stream) => stream.poll_next_unpin(cx),
Self::Memory(stream) => stream.poll_next_unpin(cx), Self::Memory(stream) => stream.poll_next_unpin(cx),
Self::Completed(stream) => stream Self::Completed(stream) => stream
.poll_next_unpin(cx) .poll_next_unpin(cx)

214
src/client.rs Normal file
View file

@ -0,0 +1,214 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use actix_web::{
http::{HeaderMap, HeaderName, HeaderValue},
web::Data,
};
use bytes::Bytes;
use log::{debug, error, warn};
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use reqwest::{
header::{
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH,
CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS,
},
Client, StatusCode,
};
use tokio::sync::{
watch::{channel, Receiver},
Notify,
};
use crate::cache::{Cache, CacheKey, ImageMetadata};
pub static HTTP_CLIENT: Lazy<CachingClient> = Lazy::new(|| CachingClient {
inner: Client::builder()
.pool_idle_timeout(Duration::from_secs(180))
.https_only(true)
.http2_prior_knowledge()
.build()
.expect("Client initialization to work"),
locks: RwLock::new(HashMap::new()),
});
pub struct CachingClient {
inner: Client,
locks: RwLock<HashMap<String, Receiver<FetchResult>>>,
}
#[derive(Clone, Debug)]
pub enum FetchResult {
ServiceUnavailable,
InternalServerError,
Data(StatusCode, HeaderMap, Bytes),
Processing,
}
impl CachingClient {
pub async fn fetch_and_cache(
&'static self,
url: String,
key: CacheKey,
cache: Data<dyn Cache>,
) -> FetchResult {
if let Some(recv) = self.locks.read().get(&url) {
let mut recv = recv.clone();
loop {
if !matches!(*recv.borrow(), FetchResult::Processing) {
break;
}
if recv.changed().await.is_err() {
break;
}
}
return recv.borrow().clone();
}
let url_0 = url.clone();
let notify = Arc::new(Notify::new());
let notify2 = Arc::clone(&notify);
tokio::spawn(async move {
let (tx, rx) = channel(FetchResult::Processing);
self.locks.write().insert(url.clone(), rx);
notify.notify_one();
let resp = self.inner.get(&url).send().await;
let resp = match resp {
Ok(mut resp) => {
let content_type = resp.headers().get(CONTENT_TYPE);
let is_image = content_type
.map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
.unwrap_or_default();
if resp.status() != StatusCode::OK || !is_image {
warn!("Got non-OK or non-image response code from upstream, proxying and not caching result.");
let mut headers = HeaderMap::new();
if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type.clone());
}
headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
headers.insert(
ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("https://mangadex.org"),
);
headers
.insert(ACCESS_CONTROL_EXPOSE_HEADERS, HeaderValue::from_static("*"));
headers.insert(
CACHE_CONTROL,
HeaderValue::from_static("public, max-age=1209600"),
);
headers.insert(
HeaderName::from_static("timing-allow-origin"),
HeaderValue::from_static("https://mangadex.org"),
);
FetchResult::Data(
resp.status(),
headers,
resp.bytes().await.unwrap_or_default(),
)
} else {
let (content_type, length, last_mod) = {
let headers = resp.headers_mut();
(
headers.remove(CONTENT_TYPE),
headers.remove(CONTENT_LENGTH),
headers.remove(LAST_MODIFIED),
)
};
let body = resp.bytes().await.unwrap();
debug!("Inserting into cache");
let metadata = ImageMetadata::new(
content_type.clone(),
length.clone(),
last_mod.clone(),
)
.unwrap();
match cache.put(key, body.clone(), metadata).await {
Ok(()) => {
debug!("Done putting into cache");
let mut headers = HeaderMap::new();
if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type);
}
if let Some(content_length) = length {
headers.insert(CONTENT_LENGTH, content_length);
}
if let Some(last_modified) = last_mod {
headers.insert(LAST_MODIFIED, last_modified);
}
headers.insert(
X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
);
headers.insert(
ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("https://mangadex.org"),
);
headers.insert(
ACCESS_CONTROL_EXPOSE_HEADERS,
HeaderValue::from_static("*"),
);
headers.insert(
CACHE_CONTROL,
HeaderValue::from_static("public, max-age=1209600"),
);
headers.insert(
HeaderName::from_static("timing-allow-origin"),
HeaderValue::from_static("https://mangadex.org"),
);
FetchResult::Data(StatusCode::OK, headers, body)
}
Err(e) => {
warn!("Failed to insert into cache: {}", e);
FetchResult::InternalServerError
}
}
}
}
Err(e) => {
error!("Failed to fetch image from server: {}", e);
FetchResult::ServiceUnavailable
}
};
// This shouldn't happen
tx.send(resp).unwrap();
self.locks.write().remove(&url);
});
notify2.notified().await;
let mut recv = self.locks.read().get(&url_0).unwrap().clone();
loop {
if !matches!(*recv.borrow(), FetchResult::Processing) {
break;
}
if recv.changed().await.is_err() {
break;
}
}
let resp = recv.borrow().clone();
resp
}
#[inline]
pub const fn inner(&self) -> &Client {
&self.inner
}
}

View file

@ -155,10 +155,10 @@ impl Config {
.server_settings .server_settings
.external_ip .external_ip
.map(|ip_addr| SocketAddr::new(ip_addr, external_port)), .map(|ip_addr| SocketAddr::new(ip_addr, external_port)),
ephemeral_disk_encryption: cli_args ephemeral_disk_encryption: cli_args.ephemeral_disk_encryption
.ephemeral_disk_encryption || file_extended_options
.or(file_extended_options.ephemeral_disk_encryption) .ephemeral_disk_encryption
.unwrap_or_default(), .unwrap_or_default(),
network_speed: cli_args network_speed: cli_args
.network_speed .network_speed
.unwrap_or(file_args.server_settings.external_max_kilobits_per_second), .unwrap_or(file_args.server_settings.external_max_kilobits_per_second),
@ -284,7 +284,7 @@ struct CliArgs {
/// encrypted with a key generated at runtime. There are implications to /// encrypted with a key generated at runtime. There are implications to
/// performance, privacy, and usability with this flag enabled. /// performance, privacy, and usability with this flag enabled.
#[clap(short, long)] #[clap(short, long)]
pub ephemeral_disk_encryption: Option<bool>, pub ephemeral_disk_encryption: bool,
#[clap(short, long)] #[clap(short, long)]
pub config_path: Option<PathBuf>, pub config_path: Option<PathBuf>,
#[clap(short = 't', long)] #[clap(short = 't', long)]

View file

@ -30,6 +30,7 @@ use crate::config::{CacheType, UnstableOptions, OFFLINE_MODE};
use crate::state::DynamicServerCert; use crate::state::DynamicServerCert;
mod cache; mod cache;
mod client;
mod config; mod config;
mod metrics; mod metrics;
mod ping; mod ping;

View file

@ -1,5 +1,4 @@
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::Duration;
use actix_web::error::ErrorNotFound; use actix_web::error::ErrorNotFound;
use actix_web::http::header::{ use actix_web::http::header::{
@ -12,16 +11,15 @@ use actix_web::{get, web::Data, HttpRequest, HttpResponse, Responder};
use base64::DecodeError; use base64::DecodeError;
use bytes::Bytes; use bytes::Bytes;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use futures::{Stream, TryStreamExt}; use futures::Stream;
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace};
use once_cell::sync::Lazy;
use prometheus::{Encoder, TextEncoder}; use prometheus::{Encoder, TextEncoder};
use reqwest::{Client, StatusCode};
use serde::Deserialize; use serde::Deserialize;
use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES}; use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES};
use thiserror::Error; use thiserror::Error;
use crate::cache::{Cache, CacheKey, ImageMetadata, UpstreamError}; use crate::cache::{Cache, CacheKey, ImageMetadata, UpstreamError};
use crate::client::{FetchResult, HTTP_CLIENT};
use crate::config::{OFFLINE_MODE, VALIDATE_TOKENS}; use crate::config::{OFFLINE_MODE, VALIDATE_TOKENS};
use crate::metrics::{ use crate::metrics::{
CACHE_HIT_COUNTER, CACHE_MISS_COUNTER, REQUESTS_DATA_COUNTER, REQUESTS_DATA_SAVER_COUNTER, CACHE_HIT_COUNTER, CACHE_MISS_COUNTER, REQUESTS_DATA_COUNTER, REQUESTS_DATA_SAVER_COUNTER,
@ -31,16 +29,7 @@ use crate::state::RwLockServerState;
const BASE64_CONFIG: base64::Config = base64::Config::new(base64::CharacterSet::UrlSafe, false); const BASE64_CONFIG: base64::Config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| { pub enum ServerResponse {
Client::builder()
.pool_idle_timeout(Duration::from_secs(180))
.https_only(true)
.http2_prior_knowledge()
.build()
.expect("Client initialization to work")
});
enum ServerResponse {
TokenValidationError(TokenValidationError), TokenValidationError(TokenValidationError),
HttpResponse(HttpResponse), HttpResponse(HttpResponse),
} }
@ -116,7 +105,7 @@ pub async fn default(state: Data<RwLockServerState>, req: HttpRequest) -> impl R
info!("Got unknown path, just proxying: {}", path); info!("Got unknown path, just proxying: {}", path);
let resp = match HTTP_CLIENT.get(path).send().await { let resp = match HTTP_CLIENT.inner().get(path).send().await {
Ok(resp) => resp, Ok(resp) => resp,
Err(e) => { Err(e) => {
error!("{}", e); error!("{}", e);
@ -145,7 +134,7 @@ pub async fn metrics() -> impl Responder {
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]
enum TokenValidationError { pub enum TokenValidationError {
#[error("Failed to decode base64 token.")] #[error("Failed to decode base64 token.")]
DecodeError(#[from] DecodeError), DecodeError(#[from] DecodeError),
#[error("Nonce was too short.")] #[error("Nonce was too short.")]
@ -208,7 +197,7 @@ fn validate_token(
} }
#[inline] #[inline]
fn push_headers(builder: &mut HttpResponseBuilder) -> &mut HttpResponseBuilder { pub fn push_headers(builder: &mut HttpResponseBuilder) -> &mut HttpResponseBuilder {
builder builder
.insert_header((X_CONTENT_TYPE_OPTIONS, "nosniff")) .insert_header((X_CONTENT_TYPE_OPTIONS, "nosniff"))
.insert_header((ACCESS_CONTROL_ALLOW_ORIGIN, "https://mangadex.org")) .insert_header((ACCESS_CONTROL_ALLOW_ORIGIN, "https://mangadex.org"))
@ -237,7 +226,7 @@ async fn fetch_image(
Some(Err(_)) => { Some(Err(_)) => {
return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish()); return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish());
} }
_ => (), None => (),
} }
CACHE_MISS_COUNTER.inc(); CACHE_MISS_COUNTER.inc();
@ -249,95 +238,35 @@ async fn fetch_image(
); );
} }
// It's important to not get a write lock before this request, else we're let url = if is_data_saver {
// holding the read lock until the await resolves. format!(
"{}/data-saver/{}/{}",
let resp = if is_data_saver { state.0.read().image_server,
HTTP_CLIENT &key.0,
.get(format!( &key.1,
"{}/data-saver/{}/{}", )
state.0.read().image_server,
&key.0,
&key.1
))
.send()
} else { } else {
HTTP_CLIENT format!("{}/data/{}/{}", state.0.read().image_server, &key.0, &key.1)
.get(format!( };
"{}/data/{}/{}",
state.0.read().image_server,
&key.0,
&key.1
))
.send()
}
.await;
match resp { match HTTP_CLIENT.fetch_and_cache(url, key, cache).await {
Ok(mut resp) => { FetchResult::ServiceUnavailable => {
let content_type = resp.headers().get(CONTENT_TYPE); ServerResponse::HttpResponse(HttpResponse::ServiceUnavailable().finish())
let is_image = content_type
.map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
.unwrap_or_default();
if resp.status() != StatusCode::OK || !is_image {
warn!(
"Got non-OK or non-image response code from upstream, proxying and not caching result.",
);
let mut resp_builder = HttpResponseBuilder::new(resp.status());
if let Some(content_type) = content_type {
resp_builder.insert_header((CONTENT_TYPE, content_type));
}
push_headers(&mut resp_builder);
return ServerResponse::HttpResponse(
resp_builder.body(resp.bytes().await.unwrap_or_default()),
);
}
let (content_type, length, last_mod) = {
let headers = resp.headers_mut();
(
headers.remove(CONTENT_TYPE),
headers.remove(CONTENT_LENGTH),
headers.remove(LAST_MODIFIED),
)
};
let body = resp.bytes_stream().map_err(|e| e.into());
debug!("Inserting into cache");
let metadata = ImageMetadata::new(content_type, length, last_mod).unwrap();
let stream = {
match cache.put(key, Box::new(body), metadata).await {
Ok(stream) => stream,
Err(e) => {
warn!("Failed to insert into cache: {}", e);
return ServerResponse::HttpResponse(
HttpResponse::InternalServerError().finish(),
);
}
}
};
debug!("Done putting into cache");
construct_response(stream, &metadata)
} }
Err(e) => { FetchResult::InternalServerError => {
error!("Failed to fetch image from server: {}", e); ServerResponse::HttpResponse(HttpResponse::InternalServerError().finish())
ServerResponse::HttpResponse(
push_headers(&mut HttpResponse::ServiceUnavailable()).finish(),
)
} }
FetchResult::Data(status, headers, data) => {
let mut resp = HttpResponseBuilder::new(status);
let mut resp = resp.body(data);
*resp.headers_mut() = headers;
ServerResponse::HttpResponse(resp)
}
FetchResult::Processing => panic!("Race condition found with fetch result"),
} }
} }
fn construct_response( pub fn construct_response(
data: impl Stream<Item = Result<Bytes, UpstreamError>> + Unpin + 'static, data: impl Stream<Item = Result<Bytes, UpstreamError>> + Unpin + 'static,
metadata: &ImageMetadata, metadata: &ImageMetadata,
) -> ServerResponse { ) -> ServerResponse {