diff --git a/src/config.rs b/src/config.rs index 614272c..eeb44f5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,7 +10,7 @@ pub static VALIDATE_TOKENS: AtomicBool = AtomicBool::new(false); // everywhere. pub static SEND_SERVER_VERSION: AtomicBool = AtomicBool::new(false); -#[derive(Clap)] +#[derive(Clap, Clone)] pub struct CliArgs { /// The port to listen on. #[clap(short, long, default_value = "42069", env = "PORT")] diff --git a/src/main.rs b/src/main.rs index d180091..ad901a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,10 +12,11 @@ use std::{num::ParseIntError, sync::atomic::Ordering}; use actix_web::rt::{spawn, time, System}; use actix_web::web::{self, Data}; use actix_web::{App, HttpServer}; +use cache::Cache; use clap::Clap; use config::CliArgs; use log::{debug, error, warn, LevelFilter}; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use rustls::{NoClientAuth, ServerConfig}; use simple_logger::SimpleLogger; use state::{RwLockServerState, ServerState}; @@ -50,6 +51,9 @@ async fn main() -> Result<(), std::io::Error> { dotenv::dotenv().ok(); let cli_args = CliArgs::parse(); let port = cli_args.port; + let memory_max_size = cli_args.memory_quota.get(); + let disk_quota = cli_args.disk_quota; + let cache_path = cli_args.cache_path.clone(); SimpleLogger::new() .with_level(LevelFilter::Info) @@ -107,6 +111,11 @@ async fn main() -> Result<(), std::io::Error> { .service(routes::token_data_saver) .route("{tail:.*}", web::get().to(routes::default)) .app_data(Data::from(Arc::clone(&data_1))) + .app_data(Data::new(Mutex::new(Cache::new( + memory_max_size, + disk_quota, + cache_path.clone(), + )))) }) .shutdown_timeout(60) .bind_rustls(format!("0.0.0.0:{}", port), tls_config)? diff --git a/src/ping.rs b/src/ping.rs index c251b79..0a3e6b0 100644 --- a/src/ping.rs +++ b/src/ping.rs @@ -1,11 +1,19 @@ -use std::sync::Arc; +use std::{io::BufReader, sync::Arc}; use std::{ num::{NonZeroU16, NonZeroUsize}, sync::atomic::Ordering, }; use log::{error, info, warn}; -use serde::{Deserialize, Serialize}; +use rustls::{ + internal::pemfile::{certs, rsa_private_keys}, + sign::RSASigningKey, +}; +use rustls::{sign::SigningKey, Certificate}; +use serde::{ + de::{MapAccess, Visitor}, + Deserialize, Serialize, +}; use sodiumoxide::crypto::box_::PrecomputedKey; use url::Url; @@ -69,11 +77,72 @@ pub struct Response { pub tls: Option, } -#[derive(Deserialize, Debug)] pub struct Tls { pub created_at: String, - pub private_key: String, - pub certificate: String, + pub priv_key: Arc>, + pub certs: Vec, +} + +impl<'de> Deserialize<'de> for Tls { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct TlsVisitor; + + impl<'de> Visitor<'de> for TlsVisitor { + type Value = Tls; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a tls struct") + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + let mut created_at = None; + let mut priv_key = None; + let mut certificates = None; + + while let Some((key, value)) = map.next_entry::<&str, String>()? { + match key { + "created_at" => created_at = Some(value.to_string()), + "private_key" => { + priv_key = rsa_private_keys(&mut BufReader::new(value.as_bytes())) + .ok() + .and_then(|mut v| { + v.pop().and_then(|key| RSASigningKey::new(&key).ok()) + }) + } + "certificate" => { + certificates = certs(&mut BufReader::new(value.as_bytes())).ok() + } + _ => (), // Ignore extra fields + } + } + + match (created_at, priv_key, certificates) { + (Some(created_at), Some(priv_key), Some(certificates)) => Ok(Tls { + created_at, + priv_key: Arc::new(Box::new(priv_key)), + certs: certificates, + }), + _ => Err(serde::de::Error::custom("Could not deserialize tls info")), + } + } + } + + deserializer.deserialize_map(TlsVisitor) + } +} + +impl std::fmt::Debug for Tls { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Tls") + .field("created_at", &self.created_at) + .finish() + } } pub async fn update_server_state(secret: &str, req: &CliArgs, data: &mut Arc) { diff --git a/src/routes.rs b/src/routes.rs index 260742e..4f7dc71 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -13,11 +13,12 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::stream; use log::{error, info, warn}; +use parking_lot::Mutex; use serde::Deserialize; use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES}; use thiserror::Error; -use crate::cache::{CacheKey, CachedImage}; +use crate::cache::{Cache, CacheKey, CachedImage}; use crate::client_api_version; use crate::config::{SEND_SERVER_VERSION, VALIDATE_TOKENS}; use crate::state::RwLockServerState; @@ -51,6 +52,7 @@ impl Responder for ServerResponse { #[get("/{token}/data/{chapter_hash}/{file_name}")] async fn token_data( state: Data, + cache: Data>, path: Path<(String, String, String)>, ) -> impl Responder { let (token, chapter_hash, file_name) = path.into_inner(); @@ -60,12 +62,13 @@ async fn token_data( } } - fetch_image(state, chapter_hash, file_name, false).await + fetch_image(state, cache, chapter_hash, file_name, false).await } #[get("/{token}/data-saver/{chapter_hash}/{file_name}")] async fn token_data_saver( state: Data, + cache: Data>, path: Path<(String, String, String)>, ) -> impl Responder { let (token, chapter_hash, file_name) = path.into_inner(); @@ -74,7 +77,7 @@ async fn token_data_saver( return ServerResponse::TokenValidationError(e); } } - fetch_image(state, chapter_hash, file_name, true).await + fetch_image(state, cache, chapter_hash, file_name, true).await } pub async fn default(state: Data, req: HttpRequest) -> impl Responder { @@ -172,13 +175,14 @@ fn push_headers(builder: &mut HttpResponseBuilder) -> &mut HttpResponseBuilder { async fn fetch_image( state: Data, + cache: Data>, chapter_hash: String, file_name: String, is_data_saver: bool, ) -> ServerResponse { let key = CacheKey(chapter_hash, file_name, is_data_saver); - if let Some(cached) = state.0.write().cache.get(&key).await { + if let Some(cached) = cache.lock().get(&key).await { return construct_response(cached); } @@ -243,7 +247,7 @@ async fn fetch_image( last_modified, }; let resp = construct_response(&cached); - state.0.write().cache.put(key, cached).await; + cache.lock().put(key, cached).await; return resp; } Err(e) => { diff --git a/src/state.rs b/src/state.rs index 48ccb40..6a48699 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,13 +1,10 @@ -use std::io::BufReader; use std::sync::{atomic::Ordering, Arc}; -use crate::config::{SEND_SERVER_VERSION, VALIDATE_TOKENS}; +use crate::config::{CliArgs, SEND_SERVER_VERSION, VALIDATE_TOKENS}; use crate::ping::{Request, Response, Tls, CONTROL_CENTER_PING_URL}; -use crate::{cache::Cache, config::CliArgs}; use log::{error, info, warn}; use parking_lot::RwLock; -use rustls::internal::pemfile::{certs, rsa_private_keys}; -use rustls::sign::{CertifiedKey, RSASigningKey}; +use rustls::sign::CertifiedKey; use rustls::ResolvesServerCert; use sodiumoxide::crypto::box_::PrecomputedKey; use url::Url; @@ -17,7 +14,6 @@ pub struct ServerState { pub image_server: Url, pub tls_config: Tls, pub url: String, - pub cache: Cache, pub log_state: LogState, } @@ -79,11 +75,6 @@ impl ServerState { image_server: resp.image_server, tls_config: resp.tls.unwrap(), url: resp.url, - cache: Cache::new( - config.memory_quota.get(), - config.disk_quota, - config.cache_path.clone(), - ), log_state: LogState { was_paused_before: resp.paused, }, @@ -113,21 +104,9 @@ pub struct RwLockServerState(pub RwLock); impl ResolvesServerCert for RwLockServerState { fn resolve(&self, _: rustls::ClientHello) -> Option { let read_guard = self.0.read(); - let priv_key = rsa_private_keys(&mut BufReader::new( - read_guard.tls_config.private_key.as_bytes(), - )) - .ok()? - .pop() - .unwrap(); - - let certs = certs(&mut BufReader::new( - read_guard.tls_config.certificate.as_bytes(), - )) - .ok()?; - Some(CertifiedKey { - cert: certs, - key: Arc::new(Box::new(RSASigningKey::new(&priv_key).unwrap())), + cert: read_guard.tls_config.certs.clone(), + key: Arc::clone(&read_guard.tls_config.priv_key), ocsp: None, sct_list: None, })