From fa9ab93c77dcce6bf3cd64e274d3bfc79534ebfd Mon Sep 17 00:00:00 2001 From: Edward Shen Date: Thu, 15 Jul 2021 12:29:55 -0400 Subject: [PATCH] Add proxy support --- Cargo.toml | 2 +- src/client.rs | 33 +++++++++++++++++++++++++-------- src/config.rs | 29 +++++++++++++++++++++++++++++ src/main.rs | 9 ++++++++- src/metrics.rs | 6 +++++- src/ping.rs | 9 +++++++-- src/state.rs | 4 +++- src/stop.rs | 8 ++++---- 8 files changed, 82 insertions(+), 18 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b9278ef..31f7974 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ sodiumoxide = "0.2" sqlx = { version = "0.5", features = [ "runtime-actix-rustls", "sqlite", "time", "chrono", "macros", "offline" ] } tar = "0.4" thiserror = "1" -tokio = { version = "1", features = [ "rt-multi-thread", "macros", "fs", "sync", "parking_lot" ] } +tokio = { version = "1", features = [ "rt-multi-thread", "macros", "fs", "time", "sync", "parking_lot" ] } tokio-stream = { version = "0.1", features = [ "sync" ] } tokio-util = { version = "0.6", features = [ "codec" ] } tracing = "0.1" diff --git a/src/client.rs b/src/client.rs index f7f7d8f..22b9a98 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; @@ -11,21 +12,37 @@ use reqwest::header::{ ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH, CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS, }; -use reqwest::{Client, StatusCode}; +use reqwest::{Client, Proxy, StatusCode}; use tokio::sync::watch::{channel, Receiver}; use tokio::sync::Notify; -use tracing::{debug, error, warn}; +use tracing::{debug, error, info, warn}; use crate::cache::{Cache, CacheKey, ImageMetadata}; +use crate::config::{DISABLE_CERT_VALIDATION, USE_PROXY}; -pub static HTTP_CLIENT: Lazy = Lazy::new(|| CachingClient { - inner: Client::builder() +pub static HTTP_CLIENT: Lazy = Lazy::new(|| { + let mut inner = Client::builder() .pool_idle_timeout(Duration::from_secs(180)) .https_only(true) - .http2_prior_knowledge() - .build() - .expect("Client initialization to work"), - locks: RwLock::new(HashMap::new()), + .http2_prior_knowledge(); + + if let Some(socket_addr) = USE_PROXY.get() { + info!( + "Using {} as a proxy for upstream requests.", + socket_addr.as_str() + ); + inner = inner.proxy(Proxy::all(socket_addr.as_str()).unwrap()); + } + + if DISABLE_CERT_VALIDATION.load(Ordering::Acquire) { + inner = inner.danger_accept_invalid_certs(true); + } + + let inner = inner.build().expect("Client initialization to work"); + CachingClient { + inner, + locks: RwLock::new(HashMap::new()), + } }); pub static DEFAULT_HEADERS: Lazy = Lazy::new(|| { diff --git a/src/config.rs b/src/config.rs index 5c4fa2b..ac29b83 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,6 +10,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use clap::{crate_authors, crate_description, crate_version, Clap}; use log::LevelFilter; +use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::level_filters::LevelFilter as TracingLevelFilter; @@ -20,6 +21,8 @@ use crate::units::{KilobitsPerSecond, Mebibytes, Port}; // Validate tokens is an atomic because it's faster than locking on rwlock. pub static VALIDATE_TOKENS: AtomicBool = AtomicBool::new(false); pub static OFFLINE_MODE: AtomicBool = AtomicBool::new(false); +pub static USE_PROXY: OnceCell = OnceCell::new(); +pub static DISABLE_CERT_VALIDATION: AtomicBool = AtomicBool::new(false); #[derive(Error, Debug)] pub enum ConfigError { @@ -70,6 +73,19 @@ pub fn load_config() -> Result { Ordering::Release, ); + config.proxy.clone().map(|socket| { + USE_PROXY + .set(socket) + .expect("USE_PROXY to be set only by this function"); + }); + + DISABLE_CERT_VALIDATION.store( + config + .unstable_options + .contains(&UnstableOptions::DisableCertValidation), + Ordering::Release, + ); + Ok(config) } @@ -92,6 +108,7 @@ pub struct Config { pub override_upstream: Option, pub enable_metrics: bool, pub geoip_license_key: Option, + pub proxy: Option, } impl Config { @@ -192,6 +209,7 @@ impl Config { None } }), + proxy: cli_args.proxy, } } } @@ -326,6 +344,10 @@ struct CliArgs { /// value is "on_disk", other options are "lfu" and "lru". #[clap(short = 't', long)] pub cache_type: Option, + /// Whether or not to use a proxy for upstream requests. This affects all + /// requests except for the shutdown request. + #[clap(short = 'P', long)] + pub proxy: Option, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -343,6 +365,10 @@ pub enum UnstableOptions { /// Serves HTTP in plaintext DisableTls, + + /// Disable certificate validation. Only useful for debugging with a MITM + /// proxy + DisableCertValidation, } impl FromStr for UnstableOptions { @@ -354,6 +380,7 @@ impl FromStr for UnstableOptions { "disable-token-validation" => Ok(Self::DisableTokenValidation), "offline-mode" => Ok(Self::OfflineMode), "disable-tls" => Ok(Self::DisableTls), + "disable-cert-validation" => Ok(Self::DisableCertValidation), _ => Err(format!("Unknown unstable option '{}'", s)), } } @@ -366,6 +393,7 @@ impl Display for UnstableOptions { Self::DisableTokenValidation => write!(f, "disable-token-validation"), Self::OfflineMode => write!(f, "offline-mode"), Self::DisableTls => write!(f, "disable-tls"), + Self::DisableCertValidation => write!(f, "disable-cert-validation"), } } } @@ -407,6 +435,7 @@ mod config { ephemeral_disk_encryption: true, config_path: None, cache_type: Some(CacheType::Lfu), + proxy: None, }; let yaml_args = YamlArgs { diff --git a/src/main.rs b/src/main.rs index 920b596..39b98e8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -232,7 +232,7 @@ async fn main() -> Result<(), Box> { // Waiting for us to finish sending stop message while running.load(Ordering::SeqCst) { - std::thread::sleep(Duration::from_millis(250)); + tokio::time::sleep(Duration::from_millis(250)).await; } Ok(()) @@ -309,6 +309,13 @@ fn print_preamble_and_warnings(args: &Config) -> Result<(), Box> { warn!("Serving insecure traffic! You better be running this for development only."); } + if args + .unstable_options + .contains(&UnstableOptions::DisableCertValidation) + { + error!("Cert validation disabled! You REALLY only better be debugging."); + } + if args.override_upstream.is_some() && !args .unstable_options diff --git a/src/metrics.rs b/src/metrics.rs index 11b1fe6..1de6937 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -11,6 +11,7 @@ use tar::Archive; use thiserror::Error; use tracing::{debug, field::debug, info, warn}; +use crate::client::HTTP_CLIENT; use crate::config::ClientSecret; pub static GEOIP_DATABASE: OnceCell>> = OnceCell::new(); @@ -136,7 +137,10 @@ pub async fn load_geo_ip_data(license_key: ClientSecret) -> Result<(), DbLoadErr } async fn fetch_db(license_key: ClientSecret) -> Result<(), DbLoadError> { - let resp = reqwest::get(format!("https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key={}&suffix=tar.gz", license_key.as_str())) + let resp = HTTP_CLIENT + .inner() + .get(format!("https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key={}&suffix=tar.gz", license_key.as_str())) + .send() .await? .bytes() .await?; diff --git a/src/ping.rs b/src/ping.rs index a29b412..1e7db13 100644 --- a/src/ping.rs +++ b/src/ping.rs @@ -12,6 +12,7 @@ use sodiumoxide::crypto::box_::PrecomputedKey; use tracing::{debug, error, info, warn}; use url::Url; +use crate::client::HTTP_CLIENT; use crate::config::{ClientSecret, Config, UnstableOptions, VALIDATE_TOKENS}; use crate::state::{ RwLockServerState, PREVIOUSLY_COMPROMISED, PREVIOUSLY_PAUSED, TLS_CERTS, @@ -178,8 +179,12 @@ pub async fn update_server_state( ) { let req = Request::from_config_and_state(secret, cli); debug!("Sending ping request: {:?}", req); - let client = reqwest::Client::new(); - let resp = client.post(CONTROL_CENTER_PING_URL).json(&req).send().await; + let resp = HTTP_CLIENT + .inner() + .post(CONTROL_CENTER_PING_URL) + .json(&req) + .send() + .await; match resp { Ok(resp) => match resp.json::().await { Ok(Response::Ok(resp)) => { diff --git a/src/state.rs b/src/state.rs index 61c9d4d..390f658 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,7 @@ use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; +use crate::client::HTTP_CLIENT; use crate::config::{ClientSecret, Config, UnstableOptions, OFFLINE_MODE, VALIDATE_TOKENS}; use crate::ping::{Request, Response, CONTROL_CENTER_PING_URL}; use arc_swap::ArcSwap; @@ -46,7 +47,8 @@ pub enum ServerInitError { impl ServerState { pub async fn init(secret: &ClientSecret, config: &Config) -> Result { - let resp = reqwest::Client::new() + let resp = HTTP_CLIENT + .inner() .post(CONTROL_CENTER_PING_URL) .json(&Request::from((secret, config))) .send() diff --git a/src/stop.rs b/src/stop.rs index b2e3e1b..ccec356 100644 --- a/src/stop.rs +++ b/src/stop.rs @@ -2,6 +2,7 @@ use reqwest::StatusCode; use serde::Serialize; use tracing::{info, warn}; +use crate::client::HTTP_CLIENT; use crate::config::ClientSecret; const CONTROL_CENTER_STOP_URL: &str = "https://api.mangadex.network/ping"; @@ -12,11 +13,10 @@ struct StopRequest<'a> { } pub async fn send_stop(secret: &ClientSecret) { - let request = StopRequest { secret }; - let client = reqwest::Client::new(); - match client + match HTTP_CLIENT + .inner() .post(CONTROL_CENTER_STOP_URL) - .json(&request) + .json(&StopRequest { secret }) .send() .await {