From 353ee7271328f327f4e78ee1df3c8a889ea2bce3 Mon Sep 17 00:00:00 2001
From: Edward Shen
Date: Wed, 14 Jul 2021 21:56:29 -0400
Subject: [PATCH] Add unit tests

---
 db_queries/insert_image.sql |  1 +
 src/cache/compat.rs         | 73 ++++++++++++++++++++++++++++++
 src/cache/disk.rs           | 90 ++++++++++++++++++++++++++++---------
 src/config.rs               | 66 ++++++++++++++++++++++++++-
 src/main.rs                 |  1 +
 src/units.rs                | 23 +++++++++-
 6 files changed, 230 insertions(+), 24 deletions(-)
 create mode 100644 db_queries/insert_image.sql

diff --git a/db_queries/insert_image.sql b/db_queries/insert_image.sql
new file mode 100644
index 0000000..266fde6
--- /dev/null
+++ b/db_queries/insert_image.sql
@@ -0,0 +1 @@
+insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing
\ No newline at end of file
diff --git a/src/cache/compat.rs b/src/cache/compat.rs
index baade7c..e45980b 100644
--- a/src/cache/compat.rs
+++ b/src/cache/compat.rs
@@ -77,3 +77,76 @@ impl<'de> Deserialize<'de> for LegacyImageContentType {
         deserializer.deserialize_str(LegacyImageContentTypeVisitor)
     }
 }
+
+#[cfg(test)]
+mod parse {
+    use std::error::Error;
+
+    use chrono::DateTime;
+
+    use crate::cache::ImageContentType;
+
+    use super::LegacyImageMetadata;
+
+    #[test]
+    fn from_valid_legacy_format() -> Result<(), Box<dyn Error>> {
+        let legacy_header = r#"{"content_type":"image/jpeg","last_modified":"Sat, 10 Apr 2021 10:55:22 GMT","size":117888}"#;
+        let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
+
+        assert_eq!(
+            metadata.content_type.map(|v| v.0),
+            Some(ImageContentType::Jpeg)
+        );
+        assert_eq!(metadata.size, Some(117888));
+        assert_eq!(
+            metadata.last_modified.map(|v| v.0),
+            Some(DateTime::parse_from_rfc2822(
+                "Sat, 10 Apr 2021 10:55:22 GMT"
+            )?)
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn empty_metadata() -> Result<(), Box<dyn Error>> {
+        let legacy_header = "{}";
+        let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
+
+        assert!(metadata.content_type.is_none());
+        assert!(metadata.size.is_none());
+        assert!(metadata.last_modified.is_none());
+
+        Ok(())
+    }
+
+    #[test]
+    fn invalid_image_mime_value() {
+        let legacy_header = r#"{"content_type":"image/not-a-real-image"}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+
+    #[test]
+    fn invalid_date_time() {
+        let legacy_header = r#"{"last_modified":"idk last tuesday?"}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+
+    #[test]
+    fn invalid_size() {
+        let legacy_header = r#"{"size":-1}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+
+    #[test]
+    fn wrong_image_type() {
+        let legacy_header = r#"{"content_type":25}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+
+    #[test]
+    fn wrong_date_time_type() {
+        let legacy_header = r#"{"last_modified":false}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+}
diff --git a/src/cache/disk.rs b/src/cache/disk.rs
index a5f37fe..e2bb215 100644
--- a/src/cache/disk.rs
+++ b/src/cache/disk.rs
@@ -41,28 +41,30 @@ impl DiskCache {
     /// This internally spawns a task that will wait for filesystem
     /// notifications when a file has been written.
     pub async fn new(disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
-        let (db_tx, db_rx) = channel(128);
         let db_pool = {
             let db_url = format!("sqlite:{}/metadata.sqlite", disk_path.to_string_lossy());
             let mut options = SqliteConnectOptions::from_str(&db_url)
                 .unwrap()
                 .create_if_missing(true);
             options.log_statements(LevelFilter::Trace);
-            let db = SqlitePool::connect_with(options).await.unwrap();
-
-            // Run db init
-            sqlx::query_file!("./db_queries/init.sql")
-                .execute(&mut db.acquire().await.unwrap())
-                .await
-                .unwrap();
-
-            db
+            SqlitePool::connect_with(options).await.unwrap()
         };
 
+        Self::from_db_pool(db_pool, disk_max_size, disk_path).await
+    }
+
+    async fn from_db_pool(pool: SqlitePool, disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
+        let (db_tx, db_rx) = channel(128);
+        // Run db init
+        sqlx::query_file!("./db_queries/init.sql")
+            .execute(&mut pool.acquire().await.unwrap())
+            .await
+            .unwrap();
+
         // This is intentional.
         #[allow(clippy::cast_sign_loss)]
         let disk_cur_size = {
-            let mut conn = db_pool.acquire().await.unwrap();
+            let mut conn = pool.acquire().await.unwrap();
             sqlx::query!("SELECT IFNULL(SUM(size), 0) AS size FROM Images")
                 .fetch_one(&mut conn)
                 .await
@@ -80,7 +82,7 @@ impl DiskCache {
         tokio::spawn(db_listener(
             Arc::clone(&new_self),
             db_rx,
-            db_pool,
+            pool,
             disk_max_size.get() as u64 / 20 * 19,
         ));
 
@@ -239,14 +241,9 @@ async fn handle_db_put(
     // This is intentional.
     #[allow(clippy::cast_possible_wrap)]
     let casted_size = size as i64;
-    let query = sqlx::query!(
-        "insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing",
-        key,
-        casted_size,
-        now,
-    )
-    .execute(transaction)
-    .await;
+    let query = sqlx::query_file!("./db_queries/insert_image.sql", key, casted_size, now)
+        .execute(transaction)
+        .await;
 
     if let Err(e) = query {
         warn!("Failed to add to db: {}", e);
@@ -369,6 +366,59 @@ impl CallbackCache for DiskCache {
     }
 }
 
+#[cfg(test)]
+mod disk_cache {
+    use std::error::Error;
+    use std::path::PathBuf;
+    use std::sync::atomic::Ordering;
+
+    use chrono::Utc;
+    use sqlx::SqlitePool;
+
+    use crate::units::Bytes;
+
+    use super::DiskCache;
+
+    #[tokio::test]
+    async fn db_is_initialized() -> Result<(), Box<dyn Error>> {
+        let conn = SqlitePool::connect("sqlite::memory:").await?;
+        let _cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
+        let res = sqlx::query("select * from Images").execute(&conn).await;
+        assert!(res.is_ok());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn db_initializes_empty() -> Result<(), Box<dyn Error>> {
+        let conn = SqlitePool::connect("sqlite::memory:").await?;
+        let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
+        assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 0);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn db_can_load_from_existing() -> Result<(), Box<dyn Error>> {
+        let conn = SqlitePool::connect("sqlite::memory:").await?;
+        sqlx::query_file!("./db_queries/init.sql")
+            .execute(&conn)
+            .await?;
+
+        let now = Utc::now();
+        sqlx::query_file!("./db_queries/insert_image.sql", "a", 4, now)
+            .execute(&conn)
+            .await?;
+
+        let now = Utc::now();
+        sqlx::query_file!("./db_queries/insert_image.sql", "b", 15, now)
+            .execute(&conn)
+            .await?;
+
+        let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
+        assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 19);
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod db {
     use chrono::{DateTime, Utc};
diff --git a/src/config.rs b/src/config.rs
index de3510e..af36cde 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -198,7 +198,7 @@ impl std::fmt::Debug for ClientSecret {
     }
 }
 
-#[derive(Deserialize, Copy, Clone, Debug)]
+#[derive(Deserialize, Copy, Clone, Debug, PartialEq, Eq)]
 #[serde(rename_all = "snake_case")]
 pub enum CacheType {
     OnDisk,
@@ -348,7 +348,69 @@ mod sample_yaml {
     use crate::config::YamlArgs;
 
     #[test]
-    fn sample_yaml_parses() {
+    fn parses() {
         assert!(serde_yaml::from_str::<YamlArgs>(include_str!("../settings.sample.yaml")).is_ok());
     }
 }
+
+#[cfg(test)]
+mod config {
+    use std::path::PathBuf;
+
+    use log::LevelFilter;
+    use tracing::level_filters::LevelFilter as TracingLevelFilter;
+
+    use crate::config::{CacheType, ClientSecret, Config, YamlExtendedOptions, YamlServerSettings};
+    use crate::units::{KilobitsPerSecond, Mebibytes, Port};
+
+    use super::{CliArgs, YamlArgs};
+
+    #[test]
+    fn cli_has_priority() {
+        let cli_config = CliArgs {
+            port: Port::new(1234),
+            memory_quota: Some(Mebibytes::new(10)),
+            disk_quota: Some(Mebibytes::new(10)),
+            cache_path: Some(PathBuf::from("a")),
+            network_speed: KilobitsPerSecond::new(10),
+            verbose: 1,
+            quiet: 0,
+            unstable_options: vec![],
+            override_upstream: None,
+            ephemeral_disk_encryption: true,
+            config_path: None,
+            cache_type: Some(CacheType::Lfu),
+        };
+
+        let yaml_args = YamlArgs {
+            max_cache_size_in_mebibytes: Mebibytes::new(50),
+            server_settings: YamlServerSettings {
+                secret: ClientSecret(String::new()),
+                port: Port::new(4321).expect("to work?"),
+                external_max_kilobits_per_second: KilobitsPerSecond::new(50).expect("to work?"),
+                external_port: None,
+                graceful_shutdown_wait_seconds: None,
+                hostname: None,
+                external_ip: None,
+            },
+            extended_options: Some(YamlExtendedOptions {
+                memory_quota: Some(Mebibytes::new(50)),
+                cache_type: Some(CacheType::Lru),
+                ephemeral_disk_encryption: Some(false),
+                enable_metrics: None,
+                logging_level: Some(LevelFilter::Error),
+                cache_path: Some(PathBuf::from("b")),
+            }),
+        };
+
+        let config = Config::from_cli_and_file(cli_config, yaml_args);
+        assert_eq!(Some(config.port), Port::new(1234));
+        assert_eq!(config.memory_quota, Mebibytes::new(10));
+        assert_eq!(config.disk_quota, Mebibytes::new(10));
+        assert_eq!(config.cache_path, PathBuf::from("a"));
+        assert_eq!(Some(config.network_speed), KilobitsPerSecond::new(10));
+        assert_eq!(config.log_level, TracingLevelFilter::DEBUG);
+        assert_eq!(config.ephemeral_disk_encryption, true);
+        assert_eq!(config.cache_type, CacheType::Lfu);
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index b4d3485..3441c15 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -36,6 +36,7 @@ mod metrics;
 mod ping;
 mod routes;
 mod state;
+#[cfg(not(tarpaulin_include))]
 mod stop;
 mod units;
 
diff --git a/src/units.rs b/src/units.rs
index 213dc08..a1c77df 100644
--- a/src/units.rs
+++ b/src/units.rs
@@ -5,13 +5,18 @@ use std::str::FromStr;
 use serde::{Deserialize, Serialize};
 
 /// Wrapper type for a port number.
-#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
+#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
 pub struct Port(NonZeroU16);
 
 impl Port {
     pub const fn get(self) -> u16 {
         self.0.get()
     }
+
+    #[cfg(test)]
+    pub fn new(amt: u16) -> Option<Self> {
+        NonZeroU16::new(amt).map(Self)
+    }
 }
 
 impl Default for Port {
@@ -37,6 +42,13 @@ impl Display for Port {
 #[derive(Copy, Clone, Serialize, Deserialize, Default, Debug, Hash, Eq, PartialEq)]
 pub struct Mebibytes(usize);
 
+impl Mebibytes {
+    #[cfg(test)]
+    pub fn new(size: usize) -> Self {
+        Mebibytes(size)
+    }
+}
+
 impl FromStr for Mebibytes {
     type Err = ParseIntError;
 
@@ -45,7 +57,7 @@ impl FromStr for Mebibytes {
     }
 }
 
-pub struct Bytes(usize);
+pub struct Bytes(pub usize);
 
 impl Bytes {
     pub const fn get(&self) -> usize {
@@ -62,6 +74,13 @@ impl From<Mebibytes> for Bytes {
 #[derive(Copy, Clone, Deserialize, Debug, Hash, Eq, PartialEq)]
 pub struct KilobitsPerSecond(NonZeroU64);
 
+impl KilobitsPerSecond {
+    #[cfg(test)]
+    pub fn new(size: u64) -> Option<Self> {
+        NonZeroU64::new(size).map(Self)
+    }
+}
+
 impl FromStr for KilobitsPerSecond {
     type Err = ParseIntError;
 
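---

A note on the pattern these tests rely on: `DiskCache::new` now only builds the
connection pool and hands it to `from_db_pool`, so tests can inject a
`sqlite::memory:` database instead of touching the filesystem. Below is a
minimal, self-contained sketch of that injection pattern, assuming sqlx with
the "sqlite" feature plus tokio; `ImageDb`, `from_pool`, and `total_size` are
illustrative names, not items from this crate:

use sqlx::SqlitePool;

struct ImageDb {
    pool: SqlitePool,
}

impl ImageDb {
    // Accept any pool, like `DiskCache::from_db_pool`: production passes an
    // on-disk database, tests pass `sqlite::memory:`.
    async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
        // Idempotent schema init, standing in for ./db_queries/init.sql.
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS Images \
             (id TEXT PRIMARY KEY, size INTEGER, accessed TEXT)",
        )
        .execute(&pool)
        .await?;
        Ok(Self { pool })
    }

    // The same aggregate the cache runs at startup to seed `disk_cur_size`.
    async fn total_size(&self) -> Result<i64, sqlx::Error> {
        sqlx::query_scalar("SELECT IFNULL(SUM(size), 0) FROM Images")
            .fetch_one(&self.pool)
            .await
    }
}

#[tokio::test]
async fn starts_empty() -> Result<(), Box<dyn std::error::Error>> {
    let pool = SqlitePool::connect("sqlite::memory:").await?;
    let db = ImageDb::from_pool(pool).await?;
    assert_eq!(db.total_size().await?, 0);
    Ok(())
}

The same motivation explains the `#[cfg(test)]` constructors added to
`src/units.rs`: `Port`, `Mebibytes`, and `KilobitsPerSecond` keep their fields
private in production code, while tests get a narrow, test-only way to build
values of those types.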