diff --git a/src/fs.rs b/src/fs.rs index 91d9fb5..4aae4b8 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -11,67 +11,96 @@ use futures::{Future, Stream, StreamExt}; use once_cell::sync::Lazy; use parking_lot::RwLock; use reqwest::Error; -use tokio::fs::{remove_file, File, OpenOptions}; +use tokio::fs::{remove_file, File}; use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf}; use tokio::time::Sleep; +/// Keeps track of files that are currently being written to. +/// +/// Why is this necessary? Consider the following situation: +/// +/// Client A requests file `foo.png`. We construct a transparent file stream, +/// and now the file is being written to and read from at the same time. +/// +/// Client B requests the same file `foo.png`. A naive implementation would +/// read the file directly because it sees that the file exists. This is +/// problematic as the file could still be being written to. If Client B catches +/// up to Client A's request, then Client B could receive a broken image, as it +/// thinks it's done reading the file. +/// +/// We effectively use WRITING_STATUS as a status relay to ensure concurrent +/// reads to the file while it's being written to will wait for writing to be +/// completed. static WRITING_STATUS: Lazy>>> = Lazy::new(|| RwLock::new(HashMap::new())); +/// Tries to read from the file, returning a byte stream if it exists. +pub async fn read_file( + path: &Path, +) -> Option>, std::io::Error>> { + if path.exists() { + if let Some(status) = WRITING_STATUS.read().get(path) { + Some(FromFsStream::new(path, Arc::clone(status)).await) + } else { + Some(FromFsStream::new(path, Arc::new(CacheStatus::done())).await) + } + } else { + None + } +} + +/// Maps the input byte stream into one that writes to disk, returning +/// a stream that reads from disk instead. 
pub async fn transparent_file_stream( path: &Path, mut byte_stream: impl Stream> + Unpin + Send + 'static, ) -> Result>, std::io::Error> { - if let Some(arc) = WRITING_STATUS.read().get(path) { - FromFsStream::new(path, Arc::clone(arc)).await - } else { - let done_writing_flag = Arc::new(CacheStatus::new()); + let done_writing_flag = Arc::new(CacheStatus::new()); - { - let mut write_lock = WRITING_STATUS.write(); - File::create(path).await?; // we need to make sure the file exists and is truncated. - write_lock.insert(path.to_path_buf(), Arc::clone(&done_writing_flag)); + let mut file = { + let mut write_lock = WRITING_STATUS.write(); + let file = File::create(path).await?; // we need to make sure the file exists and is truncated. + write_lock.insert(path.to_path_buf(), Arc::clone(&done_writing_flag)); + file + }; + + let write_flag = Arc::clone(&done_writing_flag); + // need an owned path because the spawned task must be 'static + let path_buf = path.to_path_buf(); + tokio::spawn(async move { + let path_buf = path_buf; // moves path buf into async + let mut was_errored = false; + while let Some(bytes) = byte_stream.next().await { + match bytes { + Ok(bytes) => file.write_all(&bytes).await?, + Err(_) => was_errored = true, + } } - let write_flag = Arc::clone(&done_writing_flag); - // need owned variant because async lifetime - let mut file = OpenOptions::new().write(true).open(path).await?; - let path_buf = path.to_path_buf(); - tokio::spawn(async move { - let path_buf = path_buf; // moves path buf into async - let mut was_errored = false; - while let Some(bytes) = byte_stream.next().await { - match bytes { - Ok(bytes) => file.write_all(&bytes).await?, - Err(_) => was_errored = true, - } - } + if was_errored { + // It's ok if deleting the file fails, since we truncate on + // create anyways + let _ = remove_file(&path_buf).await; + } else { + file.flush().await?; + file.sync_all().await?; // we need metadata + } - if was_errored { - // It's ok if the deleting the file fails, 
since we truncate on - // create anyways - let _ = remove_file(&path_buf).await; - } else { - file.flush().await?; - file.sync_all().await?; - } + let mut write_lock = WRITING_STATUS.write(); + // This needs to be written atomically with the write lock, else + // it's possible we have an inconsistent state + if was_errored { + write_flag.store(WritingStatus::Error); + } else { + write_flag.store(WritingStatus::Done); + } + write_lock.remove(&path_buf); - let mut write_lock = WRITING_STATUS.write(); - // This needs to be written atomically with the write lock, else - // it's possible we have an inconsistent state - if was_errored { - write_flag.store(WritingStatus::Error); - } else { - write_flag.store(WritingStatus::Done); - } - write_lock.remove(&path_buf); + // We don't ever check this, so the return value doesn't matter + Ok::<_, std::io::Error>(()) + }); - // We don't ever check this, so the return value doesn't matter - Ok::<_, std::io::Error>(()) - }); - - FromFsStream::new(path, Arc::clone(&done_writing_flag)).await - } + Ok(FromFsStream::new(path, done_writing_flag).await?) } struct FromFsStream { @@ -125,6 +154,11 @@ impl CacheStatus { Self(AtomicU8::new(WritingStatus::NotDone as u8)) } + #[inline] + const fn done() -> Self { + Self(AtomicU8::new(WritingStatus::Done as u8)) + } + #[inline] fn store(&self, status: WritingStatus) { self.0.store(status as u8, Ordering::Release);