Compare commits


2 commits

SHA1        Message                Date
d4d22ec674  Reduce tokio features  2021-07-14 21:56:46 -04:00
353ee72713  Add unit tests         2021-07-14 21:56:29 -04:00
7 changed files with 231 additions and 25 deletions

Cargo.toml

@@ -42,7 +42,7 @@ serde_yaml = "0.8"
 sodiumoxide = "0.2"
 sqlx = { version = "0.5", features = [ "runtime-actix-rustls", "sqlite", "time", "chrono", "macros", "offline" ] }
 thiserror = "1"
-tokio = { version = "1", features = [ "full", "parking_lot" ] }
+tokio = { version = "1", features = [ "rt-multi-thread", "macros", "fs", "sync", "parking_lot" ] }
 tokio-stream = { version = "0.1", features = [ "sync" ] }
 tokio-util = { version = "0.6", features = [ "codec" ] }
 tracing = "0.1"
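
The swap from `full` to an explicit feature list keeps only the tokio pieces this crate exercises: `rt-multi-thread` for the runtime (and `tokio::spawn`), `macros` for `#[tokio::main]`/`#[tokio::test]`, `fs` for async file I/O, and `sync` for channels. A minimal sketch of what still compiles under the reduced set (illustrative code, not from this repo):

```rust
// "macros" provides #[tokio::main]; "rt-multi-thread" provides the runtime.
#[tokio::main]
async fn main() {
    // "sync": bounded mpsc channels, like the db_tx/db_rx pair in disk.rs.
    let (tx, mut rx) = tokio::sync::mpsc::channel::<u64>(128);

    // tokio::spawn ships with the runtime feature; disk.rs uses it for db_listener.
    tokio::spawn(async move {
        tx.send(42).await.ok();
    });
    assert_eq!(rx.recv().await, Some(42));

    // "fs": async filesystem calls such as tokio::fs::metadata.
    let _meta = tokio::fs::metadata(".").await;
}
```

Anything outside the list, e.g. `tokio::time` or `tokio::net`, no longer compiles, which is the point of the reduction.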

db_queries/insert_image.sql (new file)

@@ -0,0 +1 @@
+insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing
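
Moving the query into its own file lets production code and tests share one statement via `sqlx::query_file!`, and `on conflict do nothing` makes it idempotent: re-inserting an existing image id is a silent no-op rather than a UNIQUE-constraint error. A hedged sketch of the call shape (the wrapper name is invented; argument names mirror `handle_db_put` below):

```rust
// Hypothetical wrapper around the new query file. query_file! reads the SQL
// at compile time and type-checks the three `?` placeholders against the
// bound arguments; duplicate keys are swallowed by the ON CONFLICT clause.
async fn record_image(
    conn: &mut sqlx::SqliteConnection,
    key: &str,
    size: i64,
    accessed: chrono::DateTime<chrono::Utc>,
) {
    if let Err(e) = sqlx::query_file!("./db_queries/insert_image.sql", key, size, accessed)
        .execute(conn)
        .await
    {
        tracing::warn!("Failed to add to db: {}", e);
    }
}
```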

src/cache/compat.rs vendored (73 changed lines)

@@ -77,3 +77,76 @@ impl<'de> Deserialize<'de> for LegacyImageContentType {
         deserializer.deserialize_str(LegacyImageContentTypeVisitor)
     }
 }
+
+#[cfg(test)]
+mod parse {
+    use std::error::Error;
+
+    use chrono::DateTime;
+
+    use crate::cache::ImageContentType;
+
+    use super::LegacyImageMetadata;
+
+    #[test]
+    fn from_valid_legacy_format() -> Result<(), Box<dyn Error>> {
+        let legacy_header = r#"{"content_type":"image/jpeg","last_modified":"Sat, 10 Apr 2021 10:55:22 GMT","size":117888}"#;
+        let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
+
+        assert_eq!(
+            metadata.content_type.map(|v| v.0),
+            Some(ImageContentType::Jpeg)
+        );
+        assert_eq!(metadata.size, Some(117888));
+        assert_eq!(
+            metadata.last_modified.map(|v| v.0),
+            Some(DateTime::parse_from_rfc2822(
+                "Sat, 10 Apr 2021 10:55:22 GMT"
+            )?)
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn empty_metadata() -> Result<(), Box<dyn Error>> {
+        let legacy_header = "{}";
+        let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
+
+        assert!(metadata.content_type.is_none());
+        assert!(metadata.size.is_none());
+        assert!(metadata.last_modified.is_none());
+
+        Ok(())
+    }
+
+    #[test]
+    fn invalid_image_mime_value() {
+        let legacy_header = r#"{"content_type":"image/not-a-real-image"}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+
+    #[test]
+    fn invalid_date_time() {
+        let legacy_header = r#"{"last_modified":"idk last tuesday?"}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+
+    #[test]
+    fn invalid_size() {
+        let legacy_header = r#"{"size":-1}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+
+    #[test]
+    fn wrong_image_type() {
+        let legacy_header = r#"{"content_type":25}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+
+    #[test]
+    fn wrong_date_time_type() {
+        let legacy_header = r#"{"last_modified":false}"#;
+        assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
+    }
+}

src/cache/disk.rs vendored (90 changed lines)

@@ -41,28 +41,30 @@ impl DiskCache {
     /// This internally spawns a task that will wait for filesystem
     /// notifications when a file has been written.
     pub async fn new(disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
-        let (db_tx, db_rx) = channel(128);
         let db_pool = {
             let db_url = format!("sqlite:{}/metadata.sqlite", disk_path.to_string_lossy());
             let mut options = SqliteConnectOptions::from_str(&db_url)
                 .unwrap()
                 .create_if_missing(true);
             options.log_statements(LevelFilter::Trace);
-            let db = SqlitePool::connect_with(options).await.unwrap();
-
-            // Run db init
-            sqlx::query_file!("./db_queries/init.sql")
-                .execute(&mut db.acquire().await.unwrap())
-                .await
-                .unwrap();
-
-            db
+            SqlitePool::connect_with(options).await.unwrap()
         };

+        Self::from_db_pool(db_pool, disk_max_size, disk_path).await
+    }
+
+    async fn from_db_pool(pool: SqlitePool, disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
+        let (db_tx, db_rx) = channel(128);
+        // Run db init
+        sqlx::query_file!("./db_queries/init.sql")
+            .execute(&mut pool.acquire().await.unwrap())
+            .await
+            .unwrap();
+
         // This is intentional.
         #[allow(clippy::cast_sign_loss)]
         let disk_cur_size = {
-            let mut conn = db_pool.acquire().await.unwrap();
+            let mut conn = pool.acquire().await.unwrap();
             sqlx::query!("SELECT IFNULL(SUM(size), 0) AS size FROM Images")
                 .fetch_one(&mut conn)
                 .await
@@ -80,7 +82,7 @@ impl DiskCache {
         tokio::spawn(db_listener(
             Arc::clone(&new_self),
             db_rx,
-            db_pool,
+            pool,
            disk_max_size.get() as u64 / 20 * 19,
         ));
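
The last argument to `db_listener` is the eviction threshold, computed with plain integer arithmetic: dividing by 20 first and multiplying by 19 after yields 19/20 of the configured maximum, i.e. eviction starts once the cache passes ~95% of its size budget. A one-line illustration:

```rust
// Integer division runs first, so the result is an exact multiple of 19.
let disk_max_size: u64 = 1000;
assert_eq!(disk_max_size / 20 * 19, 950); // start evicting at 95% of capacity
```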
@@ -239,14 +241,9 @@ async fn handle_db_put(
     // This is intentional.
     #[allow(clippy::cast_possible_wrap)]
     let casted_size = size as i64;
-    let query = sqlx::query!(
-        "insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing",
-        key,
-        casted_size,
-        now,
-    )
-    .execute(transaction)
-    .await;
+    let query = sqlx::query_file!("./db_queries/insert_image.sql", key, casted_size, now)
+        .execute(transaction)
+        .await;

     if let Err(e) = query {
         warn!("Failed to add to db: {}", e);
@@ -369,6 +366,59 @@ impl CallbackCache for DiskCache {
     }
 }

+#[cfg(test)]
+mod disk_cache {
+    use std::error::Error;
+    use std::path::PathBuf;
+    use std::sync::atomic::Ordering;
+
+    use chrono::Utc;
+    use sqlx::SqlitePool;
+
+    use crate::units::Bytes;
+
+    use super::DiskCache;
+
+    #[tokio::test]
+    async fn db_is_initialized() -> Result<(), Box<dyn Error>> {
+        let conn = SqlitePool::connect("sqlite::memory:").await?;
+        let _cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
+        let res = sqlx::query("select * from Images").execute(&conn).await;
+        assert!(res.is_ok());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn db_initializes_empty() -> Result<(), Box<dyn Error>> {
+        let conn = SqlitePool::connect("sqlite::memory:").await?;
+        let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
+        assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 0);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn db_can_load_from_existing() -> Result<(), Box<dyn Error>> {
+        let conn = SqlitePool::connect("sqlite::memory:").await?;
+        sqlx::query_file!("./db_queries/init.sql")
+            .execute(&conn)
+            .await?;
+
+        let now = Utc::now();
+        sqlx::query_file!("./db_queries/insert_image.sql", "a", 4, now)
+            .execute(&conn)
+            .await?;
+
+        let now = Utc::now();
+        sqlx::query_file!("./db_queries/insert_image.sql", "b", 15, now)
+            .execute(&conn)
+            .await?;
+
+        let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
+        assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 19);
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod db {
     use chrono::{DateTime, Utc};

src/config.rs

@@ -198,7 +198,7 @@ impl std::fmt::Debug for ClientSecret {
     }
 }

-#[derive(Deserialize, Copy, Clone, Debug)]
+#[derive(Deserialize, Copy, Clone, Debug, PartialEq, Eq)]
 #[serde(rename_all = "snake_case")]
 pub enum CacheType {
     OnDisk,
@@ -348,7 +348,69 @@ mod sample_yaml {
     use crate::config::YamlArgs;

     #[test]
-    fn sample_yaml_parses() {
+    fn parses() {
         assert!(serde_yaml::from_str::<YamlArgs>(include_str!("../settings.sample.yaml")).is_ok());
     }
 }
+
+#[cfg(test)]
+mod config {
+    use std::path::PathBuf;
+
+    use log::LevelFilter;
+    use tracing::level_filters::LevelFilter as TracingLevelFilter;
+
+    use crate::config::{CacheType, ClientSecret, Config, YamlExtendedOptions, YamlServerSettings};
+    use crate::units::{KilobitsPerSecond, Mebibytes, Port};
+
+    use super::{CliArgs, YamlArgs};
+
+    #[test]
+    fn cli_has_priority() {
+        let cli_config = CliArgs {
+            port: Port::new(1234),
+            memory_quota: Some(Mebibytes::new(10)),
+            disk_quota: Some(Mebibytes::new(10)),
+            cache_path: Some(PathBuf::from("a")),
+            network_speed: KilobitsPerSecond::new(10),
+            verbose: 1,
+            quiet: 0,
+            unstable_options: vec![],
+            override_upstream: None,
+            ephemeral_disk_encryption: true,
+            config_path: None,
+            cache_type: Some(CacheType::Lfu),
+        };
+
+        let yaml_args = YamlArgs {
+            max_cache_size_in_mebibytes: Mebibytes::new(50),
+            server_settings: YamlServerSettings {
+                secret: ClientSecret(String::new()),
+                port: Port::new(4321).expect("to work?"),
+                external_max_kilobits_per_second: KilobitsPerSecond::new(50).expect("to work?"),
+                external_port: None,
+                graceful_shutdown_wait_seconds: None,
+                hostname: None,
+                external_ip: None,
+            },
+            extended_options: Some(YamlExtendedOptions {
+                memory_quota: Some(Mebibytes::new(50)),
+                cache_type: Some(CacheType::Lru),
+                ephemeral_disk_encryption: Some(false),
+                enable_metrics: None,
+                logging_level: Some(LevelFilter::Error),
+                cache_path: Some(PathBuf::from("b")),
+            }),
+        };
+
+        let config = Config::from_cli_and_file(cli_config, yaml_args);
+        assert_eq!(Some(config.port), Port::new(1234));
+        assert_eq!(config.memory_quota, Mebibytes::new(10));
+        assert_eq!(config.disk_quota, Mebibytes::new(10));
+        assert_eq!(config.cache_path, PathBuf::from("a"));
+        assert_eq!(Some(config.network_speed), KilobitsPerSecond::new(10));
+        assert_eq!(config.log_level, TracingLevelFilter::DEBUG);
+        assert_eq!(config.ephemeral_disk_encryption, true);
+        assert_eq!(config.cache_type, CacheType::Lfu);
+    }
+}

src/main.rs

@@ -36,6 +36,7 @@ mod metrics;
 mod ping;
 mod routes;
 mod state;
+#[cfg(not(tarpaulin_include))]
 mod stop;
 mod units;
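
`tarpaulin_include` is the cfg marker cargo-tarpaulin recognizes; it is never set during a normal build, so `#[cfg(not(tarpaulin_include))]` always compiles the module, while tarpaulin's source analysis treats the attribute as "exclude from coverage". The same marker works on individual items; a hedged sketch (function name invented):

```rust
// Compiles in every build (the cfg is never true); cargo-tarpaulin merely
// skips this item when computing line coverage.
#[cfg(not(tarpaulin_include))]
fn shutdown_signal_handler() {
    // Hard-to-exercise paths like ctrl-c handling are typical candidates.
}
```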

src/units.rs

@@ -5,13 +5,18 @@ use std::str::FromStr;
 use serde::{Deserialize, Serialize};

 /// Wrapper type for a port number.
-#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
+#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
 pub struct Port(NonZeroU16);

 impl Port {
     pub const fn get(self) -> u16 {
         self.0.get()
     }
+
+    #[cfg(test)]
+    pub fn new(amt: u16) -> Option<Self> {
+        NonZeroU16::new(amt).map(Self)
+    }
 }

 impl Default for Port {
impl Default for Port { impl Default for Port {
@@ -37,6 +42,13 @@ impl Display for Port {
 #[derive(Copy, Clone, Serialize, Deserialize, Default, Debug, Hash, Eq, PartialEq)]
 pub struct Mebibytes(usize);

+impl Mebibytes {
+    #[cfg(test)]
+    pub fn new(size: usize) -> Self {
+        Mebibytes(size)
+    }
+}
+
 impl FromStr for Mebibytes {
     type Err = ParseIntError;
@@ -45,7 +57,7 @@ impl FromStr for Mebibytes {
     }
 }

-pub struct Bytes(usize);
+pub struct Bytes(pub usize);

 impl Bytes {
     pub const fn get(&self) -> usize {
@@ -62,6 +74,13 @@ impl From<Mebibytes> for Bytes {
 #[derive(Copy, Clone, Deserialize, Debug, Hash, Eq, PartialEq)]
 pub struct KilobitsPerSecond(NonZeroU64);

+impl KilobitsPerSecond {
+    #[cfg(test)]
+    pub fn new(size: u64) -> Option<Self> {
+        NonZeroU64::new(size).map(Self)
+    }
+}
+
 impl FromStr for KilobitsPerSecond {
     type Err = ParseIntError;
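
The new constructors are `#[cfg(test)]`-gated, so production code still has to go through parsing or deserialization while tests can build values directly without weakening the `NonZero*` invariants. A small sketch of the resulting behavior (assertion values are illustrative):

```rust
#[cfg(test)]
mod constructors {
    use super::{KilobitsPerSecond, Mebibytes, Port};

    #[test]
    fn nonzero_invariants_hold() {
        // Port and KilobitsPerSecond wrap NonZero types, so 0 is rejected.
        assert!(Port::new(0).is_none());
        assert_eq!(Port::new(8080).map(|p| p.get()), Some(8080));
        assert!(KilobitsPerSecond::new(0).is_none());

        // Mebibytes wraps a plain usize, so construction is infallible.
        let _quota = Mebibytes::new(10);
    }
}
```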