diff --git a/src/cache/disk.rs b/src/cache/disk.rs index 1ba4546..75f7ab0 100644 --- a/src/cache/disk.rs +++ b/src/cache/disk.rs @@ -33,6 +33,7 @@ pub struct DiskCache { db_update_channel_sender: Sender, } +#[derive(Debug)] enum DbMessage { Get(Arc), Put(Arc, u64), @@ -159,21 +160,10 @@ async fn db_listener( let on_disk_size = (cache.disk_cur_size.load(Ordering::Acquire) + 4095) / 4096 * 4096; if on_disk_size >= max_on_disk_size { - let mut conn = match db_pool.acquire().await { - Ok(conn) => conn, - Err(e) => { - error!( - "Failed to get a DB connection and cannot prune disk cache: {}", - e - ); - continue; - } - }; - let items = { let request = sqlx::query!("select id, size from Images order by accessed asc limit 1000") - .fetch_all(&mut conn) + .fetch_all(&db_pool) .await; match request { Ok(items) => items, @@ -408,6 +398,62 @@ impl CallbackCache for DiskCache { } } +#[cfg(test)] +mod db_listener { + use super::{db_listener, DbMessage}; + use crate::DiskCache; + use futures::TryStreamExt; + use sqlx::{Row, SqlitePool}; + use std::error::Error; + use std::path::PathBuf; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use tokio::sync::mpsc::channel; + + #[tokio::test] + async fn can_handle_multiple_events() -> Result<(), Box> { + let (mut cache, rx) = DiskCache::in_memory(); + let (mut tx, _) = channel(1); + // Swap the tx with the new one, else the receiver will never end + std::mem::swap(&mut cache.db_update_channel_sender, &mut tx); + assert_eq!(tx.capacity(), 128); + let cache = Arc::new(cache); + let db = SqlitePool::connect("sqlite::memory:").await?; + sqlx::query_file!("./db_queries/init.sql") + .execute(&db) + .await?; + + // Populate the queue with messages + for c in 'a'..='z' { + tx.send(DbMessage::Put(Arc::new(PathBuf::from(c.to_string())), 10)) + .await?; + tx.send(DbMessage::Get(Arc::new(PathBuf::from(c.to_string())))) + .await?; + } + + // Explicitly close the channel so that the listener terminates + std::mem::drop(tx); + + db_listener(cache, rx, db.clone(), u64::MAX).await; + + let count = Arc::new(AtomicUsize::new(0)); + sqlx::query("select * from Images") + .fetch(&db) + .try_for_each_concurrent(None, |row| { + let count = Arc::clone(&count); + async move { + assert_eq!(row.get::("size"), 10); + count.fetch_add(1, Ordering::Release); + Ok(()) + } + }) + .await?; + + assert_eq!(count.load(Ordering::Acquire), 26); + Ok(()) + } +} + #[cfg(test)] mod remove_file_handler { diff --git a/src/metrics.rs b/src/metrics.rs index e219766..4db39e5 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -132,7 +132,7 @@ pub async fn load_geo_ip_data(license_key: ClientSecret) -> Result<(), DbLoadErr // Result literally cannot panic here, buuuuuut if it does we'll panic GEOIP_DATABASE .set(maxminddb::Reader::open_readfile(DB_PATH)?) - .map_err(|_| ()) + .map_err(|_| ()) // Need to map err here or can't expect .expect("to set the geo ip db singleton"); Ok(())