diff --git a/src/cache/disk_cache.rs b/src/cache/disk_cache.rs
index 87985db..e86aef2 100644
--- a/src/cache/disk_cache.rs
+++ b/src/cache/disk_cache.rs
@@ -8,7 +8,7 @@ use std::sync::Arc;
 use async_trait::async_trait;
 use bytes::Bytes;
 use futures::StreamExt;
-use log::{warn, LevelFilter};
+use log::{error, warn, LevelFilter};
 use sqlx::sqlite::SqliteConnectOptions;
 use sqlx::{ConnectOptions, SqlitePool};
 use tokio::fs::remove_file;
@@ -36,7 +36,7 @@ impl DiskCache {
     pub async fn new(disk_max_size: u64, disk_path: PathBuf) -> Arc<Box<dyn Cache>> {
         let (db_tx, db_rx) = channel(128);
         let db_pool = {
-            let db_url = format!("sqlite:{}/metadata.sqlite", disk_path.to_str().unwrap());
+            let db_url = format!("sqlite:{}/metadata.sqlite", disk_path.to_string_lossy());
             let mut options = SqliteConnectOptions::from_str(&db_url)
                 .unwrap()
                 .create_if_missing(true);
@@ -80,7 +80,13 @@ async fn db_listener(
     let mut recv_stream = ReceiverStream::new(db_rx).ready_chunks(128);
     while let Some(messages) = recv_stream.next().await {
         let now = chrono::Utc::now();
-        let mut transaction = db_pool.begin().await.unwrap();
+        let mut transaction = match db_pool.begin().await {
+            Ok(transaction) => transaction,
+            Err(e) => {
+                error!("Failed to start a transaction to DB, cannot update DB. Disk cache may be losing track of files! {}", e);
+                continue;
+            }
+        };
         for message in messages {
             match message {
                 DbMessage::Get(entry) => {
@@ -111,15 +117,42 @@ async fn db_listener(
                 }
             }
         }
-        transaction.commit().await.unwrap();
+
+        if let Err(e) = transaction.commit().await {
+            error!(
+                "Failed to commit transaction to DB. Disk cache may be losing track of files! {}",
+                e
+            );
+        }
 
         if cache.on_disk_size() >= max_on_disk_size {
-            let mut conn = db_pool.acquire().await.unwrap();
-            let items =
-                sqlx::query!("select id, size from Images order by accessed asc limit 1000")
-                    .fetch_all(&mut conn)
-                    .await
-                    .unwrap();
+            let mut conn = match db_pool.acquire().await {
+                Ok(conn) => conn,
+                Err(e) => {
+                    error!(
+                        "Failed to get a DB connection and cannot prune disk cache: {}",
+                        e
+                    );
+                    continue;
+                }
+            };
+
+            let items = {
+                let request =
+                    sqlx::query!("select id, size from Images order by accessed asc limit 1000")
+                        .fetch_all(&mut conn)
+                        .await;
+                match request {
+                    Ok(items) => items,
+                    Err(e) => {
+                        error!(
+                            "Failed to fetch oldest images and cannot prune disk cache: {}",
+                            e
+                        );
+                        continue;
+                    }
+                }
+            };
 
             let mut size_freed = 0;
             for item in items {
diff --git a/src/cache/fs.rs b/src/cache/fs.rs
index efc8038..9154c28 100644
--- a/src/cache/fs.rs
+++ b/src/cache/fs.rs
@@ -84,14 +84,14 @@ pub async fn write_file<
 
     let mut file = {
         let mut write_lock = WRITING_STATUS.write().await;
-        let parent = path.parent().unwrap();
+        let parent = path.parent().expect("The path to have a parent");
         create_dir_all(parent).await?;
         let file = File::create(path).await?; // we need to make sure the file exists and is truncated.
         write_lock.insert(path.to_path_buf(), rx.clone());
         file
     };
 
-    let metadata_string = serde_json::to_string(&metadata).unwrap();
+    let metadata_string = serde_json::to_string(&metadata).expect("serialization to work");
     let metadata_size = metadata_string.len();
     // need owned variant because async lifetime
     let path_buf = path.to_path_buf();
@@ -151,9 +151,8 @@
     }
     tokio::spawn(db_callback(bytes_written));
 
-    if accumulate {
+    if let Some(sender) = on_complete {
         tokio::spawn(async move {
-            let sender = on_complete.unwrap();
             sender
                 .send((
                     cache_key,
@@ -244,7 +243,7 @@ impl Stream for ConcurrentFsStream {
         if let Poll::Ready(Some(WritingStatus::Done(n))) =
             self.receiver.as_mut().poll_next_unpin(cx)
         {
-            self.bytes_total = Some(NonZeroU32::new(n).unwrap())
+            self.bytes_total = Some(NonZeroU32::new(n).expect("Stored a 0 byte image?"))
         }
 
         // Okay, now we know if we've read enough bytes or not. If the
diff --git a/src/ping.rs b/src/ping.rs
index 5ee1968..613528f 100644
--- a/src/ping.rs
+++ b/src/ping.rs
@@ -39,7 +39,9 @@ impl<'a> Request<'a> {
             port: config.port,
             disk_space: config.disk_quota,
             network_speed: config.network_speed,
-            build_version: client_api_version!().parse().unwrap(),
+            build_version: client_api_version!()
+                .parse()
+                .expect("to parse the build version"),
             tls_created_at: Some(state.0.read().tls_config.created_at.clone()),
         }
     }
@@ -53,7 +55,9 @@ impl<'a> From<(&'a str, &CliArgs)> for Request<'a> {
             port: config.port,
             disk_space: config.disk_quota,
             network_speed: config.network_speed,
-            build_version: client_api_version!().parse().unwrap(),
+            build_version: client_api_version!()
+                .parse()
+                .expect("to parse the build version"),
             tls_created_at: None,
         }
     }