From 1ac7b619cf3ef2774caaad862a6194ea85888f9b Mon Sep 17 00:00:00 2001 From: Edward Shen Date: Sun, 18 Apr 2021 23:06:18 -0400 Subject: [PATCH] debug streaming --- .gitignore | 3 +- Cargo.lock | 1 + Cargo.toml | 1 + src/cache/fs.rs | 81 ++++++++++++++++++++++++++++---------------- src/cache/low_mem.rs | 8 ++--- src/cache/mod.rs | 22 ++++++++---- src/config.rs | 14 ++++++++ src/main.rs | 64 ++++++++++++++++++++-------------- src/ping.rs | 15 +++++--- src/routes.rs | 29 ++++++++++++---- src/state.rs | 27 ++++++++++----- 11 files changed, 178 insertions(+), 87 deletions(-) diff --git a/.gitignore b/.gitignore index 0b745e2..00f8953 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target -.env \ No newline at end of file +.env +cache \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index ea69bde..0519afd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1000,6 +1000,7 @@ dependencies = [ "ssri", "thiserror", "tokio", + "tokio-util", "url", ] diff --git a/Cargo.toml b/Cargo.toml index 092e936..a3f8f79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ sodiumoxide = "0.2" ssri = "5" thiserror = "1" tokio = { version = "1", features = [ "full", "parking_lot" ] } +tokio-util = { version = "0.6", features = [ "codec" ] } url = { version = "2", features = [ "serde" ] } [profile.release] diff --git a/src/cache/fs.rs b/src/cache/fs.rs index 2274915..b8b8156 100644 --- a/src/cache/fs.rs +++ b/src/cache/fs.rs @@ -1,20 +1,22 @@ use actix_web::error::PayloadError; -use bytes::BytesMut; -use futures::{Future, Stream, StreamExt}; +use futures::{Stream, StreamExt}; +use log::debug; use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::fmt::Display; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use std::{collections::HashMap, fmt::Display}; -use tokio::fs::{remove_file, File}; +use tokio::fs::{create_dir_all, remove_file, File}; use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf}; use tokio::sync::RwLock; -use tokio::time::Sleep; +use tokio::time::Interval; +use tokio_util::codec::{BytesCodec, FramedRead}; -use super::{BoxedImageStream, CacheStreamItem}; +use super::{BoxedImageStream, CacheStream, CacheStreamItem}; /// Keeps track of files that are currently being written to. /// @@ -36,15 +38,23 @@ static WRITING_STATUS: Lazy>>> = Lazy::new(|| RwLock::new(HashMap::new())); /// Tries to read from the file, returning a byte stream if it exists -pub async fn read_file(path: &Path) -> Option> { +pub async fn read_file(path: &Path) -> Option> { if path.exists() { - let status = WRITING_STATUS - .read() - .await - .get(path) - .map_or_else(|| Arc::new(CacheStatus::done()), Arc::clone); + let status = WRITING_STATUS.read().await.get(path).map(Arc::clone); - Some(FsStream::new(path, status).await) + if let Some(status) = status { + Some( + ConcurrentFsStream::new(path, status) + .await + .map(CacheStream::Concurrent), + ) + } else { + Some( + File::open(path) + .await + .map(|f| CacheStream::Completed(FramedRead::new(f, BytesCodec::new()))), + ) + } } else { None } @@ -55,11 +65,13 @@ pub async fn read_file(path: &Path) -> Option> pub async fn write_file( path: &Path, mut byte_stream: BoxedImageStream, -) -> Result { +) -> Result { let done_writing_flag = Arc::new(CacheStatus::new()); let mut file = { let mut write_lock = WRITING_STATUS.write().await; + let parent = path.parent().unwrap(); + create_dir_all(parent).await?; let file = File::create(path).await?; // we need to make sure the file exists and is truncated. write_lock.insert(path.to_path_buf(), Arc::clone(&done_writing_flag)); file @@ -87,6 +99,7 @@ pub async fn write_file( } else { file.flush().await?; file.sync_all().await?; // we need metadata + debug!("writing to file done"); } let mut write_lock = WRITING_STATUS.write().await; @@ -103,21 +116,23 @@ pub async fn write_file( Ok::<_, std::io::Error>(()) }); - Ok(FsStream::new(path, done_writing_flag).await?) + Ok(CacheStream::Concurrent( + ConcurrentFsStream::new(path, done_writing_flag).await?, + )) } -pub struct FsStream { +pub struct ConcurrentFsStream { file: Pin>, - sleep: Pin>, + sleep: Pin>, is_file_done_writing: Arc, } -impl FsStream { +impl ConcurrentFsStream { async fn new(path: &Path, is_done: Arc) -> Result { Ok(Self { file: Box::pin(File::open(path).await?), // 0.5ms - sleep: Box::pin(tokio::time::sleep(Duration::from_micros(500))), + sleep: Box::pin(tokio::time::interval(Duration::from_micros(500))), is_file_done_writing: is_done, }) } @@ -135,25 +150,35 @@ impl Display for UpstreamError { } } -impl Stream for FsStream { +impl Stream for ConcurrentFsStream { type Item = CacheStreamItem; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let status = self.is_file_done_writing.load(); - let mut bytes = BytesMut::with_capacity(1460); + let mut bytes = [0; 1460].to_vec(); let mut buffer = ReadBuf::new(&mut bytes); let polled_result = self.file.as_mut().poll_read(cx, &mut buffer); - - match (status, buffer.filled().len()) { + let filled = buffer.filled().len(); + match (status, filled) { // Prematurely reached EOF, schedule a poll in the future (WritingStatus::NotDone, 0) => { - let _ = self.sleep.as_mut().poll(cx); + let _ = self.sleep.as_mut().poll_tick(cx); Poll::Pending } // We got an error, abort the read. (WritingStatus::Error, _) => Poll::Ready(Some(Err(UpstreamError))), - _ => polled_result.map(|_| Some(Ok(bytes.split().into()))), + _ => { + bytes.truncate(filled); + polled_result.map(|_| { + if bytes.is_empty() { + dbg!(line!()); + None + } else { + Some(Ok(bytes.into())) + } + }) + } } } } @@ -173,11 +198,6 @@ impl CacheStatus { Self(AtomicU8::new(WritingStatus::NotDone as u8)) } - #[inline] - const fn done() -> Self { - Self(AtomicU8::new(WritingStatus::Done as u8)) - } - #[inline] fn store(&self, status: WritingStatus) { self.0.store(status as u8, Ordering::Release); @@ -189,6 +209,7 @@ impl CacheStatus { } } +#[derive(Debug)] enum WritingStatus { NotDone = 0, Done, diff --git a/src/cache/low_mem.rs b/src/cache/low_mem.rs index caf23ac..485a699 100644 --- a/src/cache/low_mem.rs +++ b/src/cache/low_mem.rs @@ -34,10 +34,9 @@ impl Cache for LowMemCache { ) -> Option> { let metadata = self.on_disk.get(key)?; let path = self.disk_path.clone().join(PathBuf::from(key.clone())); - super::fs::read_file(&path).await.map(|res| { - res.map(|stream| (CacheStream::Fs(stream), metadata)) - .map_err(Into::into) - }) + super::fs::read_file(&path) + .await + .map(|res| res.map(|stream| (stream, metadata)).map_err(Into::into)) } async fn put( @@ -50,7 +49,6 @@ impl Cache for LowMemCache { self.on_disk.put(key.clone(), metadata); super::fs::write_file(&path, image) .await - .map(CacheStream::Fs) .map(move |stream| (stream, self.on_disk.get(&key).unwrap())) .map_err(Into::into) } diff --git a/src/cache/mod.rs b/src/cache/mod.rs index 5b097a0..b224460 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -8,7 +8,7 @@ use actix_web::http::HeaderValue; use async_trait::async_trait; use bytes::Bytes; use chrono::{DateTime, FixedOffset}; -use fs::FsStream; +use fs::ConcurrentFsStream; use futures::{Stream, StreamExt}; use log::debug; use thiserror::Error; @@ -16,6 +16,8 @@ use thiserror::Error; pub use fs::UpstreamError; pub use generational::GenerationalCache; pub use low_mem::LowMemCache; +use tokio::fs::File; +use tokio_util::codec::{BytesCodec, FramedRead}; mod fs; mod generational; @@ -163,8 +165,9 @@ pub trait Cache: Send + Sync { } pub enum CacheStream { - Fs(FsStream), + Concurrent(ConcurrentFsStream), Memory(MemStream), + Completed(FramedRead), } impl From for CacheStream { @@ -180,8 +183,12 @@ impl Stream for CacheStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Self::Fs(stream) => stream.poll_next_unpin(cx), + Self::Concurrent(stream) => stream.poll_next_unpin(cx), Self::Memory(stream) => stream.poll_next_unpin(cx), + Self::Completed(stream) => stream + .poll_next_unpin(cx) + .map_ok(|v| v.freeze()) + .map_err(|_| UpstreamError), } } } @@ -192,9 +199,12 @@ impl Stream for MemStream { type Item = CacheStreamItem; fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - let mut new_bytes = Bytes::new(); - std::mem::swap(&mut self.0, &mut new_bytes); - Poll::Ready(Some(Ok(new_bytes))) + let new_bytes = self.0.split_to(1460); + if new_bytes.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(new_bytes))) + } } } diff --git a/src/config.rs b/src/config.rs index 6ecd731..77185b2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,6 +3,7 @@ use std::path::PathBuf; use std::sync::atomic::AtomicBool; use clap::{crate_authors, crate_description, crate_version, Clap}; +use url::Url; // Validate tokens is an atomic because it's faster than locking on rwlock. pub static VALIDATE_TOKENS: AtomicBool = AtomicBool::new(false); @@ -42,7 +43,20 @@ pub struct CliArgs { env = "LOW_MEMORY_MODE", takes_value = false )] + /// Changes the caching behavior to avoid buffering images in memory, and + /// instead use the filesystem as the buffer backing. This is useful for + /// clients in low (< 1GB) RAM environments. pub low_memory: bool, + /// Changes verbosity. Default verbosity is INFO, while increasing counts + /// of verbose flags increases to DEBUG and TRACE, respectively. #[clap(short, long, parse(from_occurrences))] pub verbose: usize, + /// Overrides the upstream URL to fetch images from. Don't use this unless + /// you know what you're dealing with. + #[clap(long)] + pub override_upstream: Option, + /// Disables token validation. Don't use this unless you know the + /// ramifications of this command. + #[clap(long)] + pub disable_token_validation: bool, } diff --git a/src/main.rs b/src/main.rs index 4ee3c7e..98f94d3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -51,29 +51,6 @@ async fn main() -> Result<(), std::io::Error> { dotenv::dotenv().ok(); let cli_args = CliArgs::parse(); - println!(concat!( - env!("CARGO_PKG_NAME"), - " ", - env!("CARGO_PKG_VERSION"), - " Copyright (C) 2021 ", - env!("CARGO_PKG_AUTHORS"), - "\n\n", - env!("CARGO_PKG_NAME"), - " is free software: you can redistribute it and/or modify\n\ - it under the terms of the GNU General Public License as published by\n\ - the Free Software Foundation, either version 3 of the License, or\n\ - (at your option) any later version.\n\n", - env!("CARGO_PKG_NAME"), - " is distributed in the hope that it will be useful,\n\ - but WITHOUT ANY WARRANTY; without even the implied warranty of\n\ - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n\ - GNU General Public License for more details.\n\n\ - You should have received a copy of the GNU General Public License\n\ - along with ", - env!("CARGO_PKG_NAME"), - ". If not, see .\n" - )); - let port = cli_args.port; let memory_max_size = cli_args.memory_quota.get(); let disk_quota = cli_args.disk_quota; @@ -88,6 +65,8 @@ async fn main() -> Result<(), std::io::Error> { .init() .unwrap(); + print_preamble_and_warnings(); + let client_secret = if let Ok(v) = env::var("CLIENT_SECRET") { v } else { @@ -111,13 +90,21 @@ async fn main() -> Result<(), std::io::Error> { // Set ctrl+c to send a stop message let running = Arc::new(AtomicBool::new(true)); - let r = running.clone(); + let running_1 = running.clone(); + let system = System::current(); ctrlc::set_handler(move || { + let system = &system; let client_secret = client_secret.clone(); + let running_2 = Arc::clone(&running_1); System::new().block_on(async move { - send_stop(&client_secret).await; + if running_2.load(Ordering::SeqCst) { + send_stop(&client_secret).await; + } else { + warn!("Got second ctrl-c, forcefully exiting"); + system.stop() + } }); - r.store(false, Ordering::SeqCst); + running_1.store(false, Ordering::SeqCst); }) .expect("Error setting Ctrl-C handler"); @@ -174,3 +161,28 @@ async fn main() -> Result<(), std::io::Error> { Ok(()) } + +fn print_preamble_and_warnings() { + println!(concat!( + env!("CARGO_PKG_NAME"), + " ", + env!("CARGO_PKG_VERSION"), + " Copyright (C) 2021 ", + env!("CARGO_PKG_AUTHORS"), + "\n\n", + env!("CARGO_PKG_NAME"), + " is free software: you can redistribute it and/or modify\n\ + it under the terms of the GNU General Public License as published by\n\ + the Free Software Foundation, either version 3 of the License, or\n\ + (at your option) any later version.\n\n", + env!("CARGO_PKG_NAME"), + " is distributed in the hope that it will be useful,\n\ + but WITHOUT ANY WARRANTY; without even the implied warranty of\n\ + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n\ + GNU General Public License for more details.\n\n\ + You should have received a copy of the GNU General Public License\n\ + along with ", + env!("CARGO_PKG_NAME"), + ". If not, see .\n" + )); +} diff --git a/src/ping.rs b/src/ping.rs index c45f630..61e2c67 100644 --- a/src/ping.rs +++ b/src/ping.rs @@ -68,7 +68,7 @@ impl<'a> From<(&'a str, &CliArgs)> for Request<'a> { pub struct Response { pub image_server: Url, pub latest_build: usize, - pub url: String, + pub url: Url, pub token_key: Option, pub compromised: bool, pub paused: bool, @@ -145,8 +145,8 @@ impl std::fmt::Debug for Tls { } } -pub async fn update_server_state(secret: &str, req: &CliArgs, data: &mut Arc) { - let req = Request::from_config_and_state(secret, req, data); +pub async fn update_server_state(secret: &str, cli: &CliArgs, data: &mut Arc) { + let req = Request::from_config_and_state(secret, cli, data); let client = reqwest::Client::new(); let resp = client.post(CONTROL_CENTER_PING_URL).json(&req).send().await; match resp { @@ -154,7 +154,10 @@ pub async fn update_server_state(secret: &str, req: &CliArgs, data: &mut Arc { let mut write_guard = data.0.write(); - write_guard.image_server = resp.image_server; + if !write_guard.url_overridden && write_guard.image_server != resp.image_server { + warn!("Ignoring new upstream url!"); + write_guard.image_server = resp.image_server; + } if let Some(key) = resp.token_key { if let Some(key) = base64::decode(&key) @@ -167,7 +170,9 @@ pub async fn update_server_state(secret: &str, req: &CliArgs, data: &mut Arc, req: HttpRequest) -> impl R req.path().chars().skip(1).collect::() ); info!("Got unknown path, just proxying: {}", path); - let resp = reqwest::get(path).await.unwrap(); + let resp = match reqwest::get(path).await { + Ok(resp) => resp, + Err(e) => { + error!("{}", e); + return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish()); + } + }; let content_type = resp.headers().get(CONTENT_TYPE); let mut resp_builder = HttpResponseBuilder::new(resp.status()); if let Some(content_type) = content_type { @@ -153,6 +160,8 @@ fn validate_token( return Err(TokenValidationError::InvalidChapterHash); } + debug!("Token validated!"); + Ok(()) } @@ -194,15 +203,15 @@ async fn fetch_image( reqwest::get(format!( "{}/data-saver/{}/{}", state.0.read().image_server, - &key.1, - &key.2 + &key.0, + &key.1 )) } else { reqwest::get(format!( "{}/data/{}/{}", state.0.read().image_server, - &key.1, - &key.2 + &key.0, + &key.1 )) } .await; @@ -214,6 +223,7 @@ async fn fetch_image( let is_image = content_type .map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/")) .unwrap_or_default(); + if resp.status() != 200 || !is_image { warn!( "Got non-OK or non-image response code from upstream, proxying and not caching result.", @@ -241,6 +251,9 @@ async fn fetch_image( }; let body = resp.bytes_stream().map_err(|e| e.into()); + + debug!("Inserting into cache"); + let metadata = ImageMetadata::new(content_type, length, last_mod).unwrap(); let (stream, metadata) = { match cache.lock().put(key, Box::new(body), metadata).await { @@ -254,6 +267,8 @@ async fn fetch_image( } }; + debug!("Done putting into cache"); + return construct_response(stream, &metadata); } Err(e) => { @@ -269,6 +284,8 @@ fn construct_response( data: impl Stream> + Unpin + 'static, metadata: &ImageMetadata, ) -> ServerResponse { + debug!("Constructing response"); + let mut resp = HttpResponse::Ok(); if let Some(content_type) = metadata.content_type { resp.append_header((CONTENT_TYPE, content_type.as_ref())); diff --git a/src/state.rs b/src/state.rs index 6a48699..21a9a0e 100644 --- a/src/state.rs +++ b/src/state.rs @@ -13,7 +13,8 @@ pub struct ServerState { pub precomputed_key: PrecomputedKey, pub image_server: Url, pub tls_config: Tls, - pub url: String, + pub url: Url, + pub url_overridden: bool, pub log_state: LogState, } @@ -36,7 +37,7 @@ impl ServerState { match resp { Ok(resp) => match resp.json::().await { - Ok(resp) => { + Ok(mut resp) => { let key = resp .token_key .and_then(|key| { @@ -60,21 +61,31 @@ impl ServerState { warn!("Control center has paused this node!"); } - info!("This client's URL has been set to {}", resp.url); - - if resp.force_tokens { - info!("This client will validate tokens."); + if let Some(ref override_url) = config.override_upstream { + resp.image_server = override_url.clone(); + warn!("Upstream URL overridden to: {}", resp.image_server); } else { - info!("This client will not validate tokens."); } - VALIDATE_TOKENS.store(resp.force_tokens, Ordering::Release); + info!("This client's URL has been set to {}", resp.url); + + if config.disable_token_validation { + warn!("Token validation is explicitly disabled!"); + } else { + if resp.force_tokens { + info!("This client will validate tokens."); + } else { + info!("This client will not validate tokens."); + } + VALIDATE_TOKENS.store(resp.force_tokens, Ordering::Release); + } Ok(Self { precomputed_key: key, image_server: resp.image_server, tls_config: resp.tls.unwrap(), url: resp.url, + url_overridden: config.override_upstream.is_some(), log_state: LogState { was_paused_before: resp.paused, },