From 2ace8d3d669fc795dc5274f1f04cc32f89f6293d Mon Sep 17 00:00:00 2001 From: Edward Shen Date: Tue, 13 Jul 2021 13:16:44 -0400 Subject: [PATCH] Partial rewrite of encrypted writer --- src/cache/disk.rs | 7 +- src/cache/fs.rs | 260 +++++++++++++++++++++++++++++++++------------- src/cache/mod.rs | 4 +- 3 files changed, 196 insertions(+), 75 deletions(-) diff --git a/src/cache/disk.rs b/src/cache/disk.rs index 589f085..99729bf 100644 --- a/src/cache/disk.rs +++ b/src/cache/disk.rs @@ -18,12 +18,13 @@ use sqlx::{ConnectOptions, Sqlite, SqlitePool, Transaction}; use tokio::fs::remove_file; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio_stream::wrappers::ReceiverStream; -use tracing::{debug, error, warn}; +use tracing::{debug, error, instrument, warn}; use crate::units::Bytes; use super::{Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata}; +#[derive(Debug)] pub struct DiskCache { disk_path: PathBuf, disk_cur_size: AtomicU64, @@ -215,6 +216,7 @@ async fn db_listener( } } +#[instrument(level = "debug", skip(transaction))] async fn handle_db_get(entry: &Path, transaction: &mut Transaction<'_, Sqlite>) { let hash = if let Ok(hash) = Md5Hash::try_from(entry) { hash @@ -242,6 +244,7 @@ async fn handle_db_get(entry: &Path, transaction: &mut Transaction<'_, Sqlite>) } } +#[instrument(level = "debug", skip(transaction, cache))] async fn handle_db_put( entry: &Path, size: u64, @@ -264,7 +267,7 @@ async fn handle_db_put( .await; if let Err(e) = query { - warn!("Failed to add {:?} to db: {}", key, e); + warn!("Failed to add to db: {}", e); } cache.disk_cur_size.fetch_add(size, Ordering::Release); diff --git a/src/cache/fs.rs b/src/cache/fs.rs index 5bffcf4..03acb52 100644 --- a/src/cache/fs.rs +++ b/src/cache/fs.rs @@ -22,6 +22,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use actix_web::error::PayloadError; +use async_trait::async_trait; use bytes::Bytes; use futures::Future; use serde::{Deserialize, Serialize}; @@ -29,10 +30,13 @@ use sodiumoxide::crypto::secretstream::{ Header, Pull, Push, Stream as SecretStream, Tag, HEADERBYTES, }; use tokio::fs::{create_dir_all, remove_file, File}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{ + AsyncBufRead, AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufReader, + ReadBuf, +}; use tokio::sync::mpsc::Sender; use tokio_util::codec::{BytesCodec, FramedRead}; -use tracing::{debug, warn}; +use tracing::{debug, instrument, warn}; use super::compat::LegacyImageMetadata; use super::{CacheKey, ImageMetadata, InnerStream, ENCRYPTION_KEY}; @@ -53,6 +57,7 @@ pub(super) async fn read_file_from_path( read_file(std::fs::File::open(path).ok()?).await } +#[instrument(level = "debug")] async fn read_file( file: std::fs::File, ) -> Option, ImageMetadata), std::io::Error>> { @@ -73,7 +78,7 @@ async fn read_file( let parsed_metadata; let mut maybe_header = None; - let mut reader: Option>> = None; + let mut reader: Option>> = None; if let Ok(metadata) = maybe_metadata { // image is decrypted if ENCRYPTION_KEY.get().is_some() { @@ -83,13 +88,13 @@ async fn read_file( return None; } - reader = Some(Box::pin(File::from_std(file_1))); + reader = Some(Box::pin(BufReader::new(File::from_std(file_1)))); parsed_metadata = Some(metadata); debug!("Found not encrypted file"); } else { + debug!("metadata read failed, trying to see if it's encrypted"); let mut file = File::from_std(file_1); file.seek(SeekFrom::Start(0)).await.ok()?; - let file_0 = file.try_clone().await.unwrap(); // image is encrypted or corrupt @@ -102,6 +107,10 @@ async fn read_file( return None; } + dbg!(&header_bytes); + + debug!("header bytes: {:x?}", header_bytes); + let file_header = if let Some(header) = Header::from_slice(&header_bytes) { header } else { @@ -110,6 +119,7 @@ async fn read_file( }; let secret_stream = if let Ok(stream) = SecretStream::init_pull(&file_header, key) { + debug!("Valid header found!"); stream } else { warn!("Failed to init secret stream with key and header. Assuming corrupted!"); @@ -117,23 +127,24 @@ async fn read_file( }; maybe_header = Some(file_header); - reader = Some(Box::pin(EncryptedDiskReader::new(file, secret_stream))); } - let mut deserializer = serde_json::Deserializer::from_reader(file_0.into_std().await); - parsed_metadata = ImageMetadata::deserialize(&mut deserializer).ok(); - - if parsed_metadata.is_some() { - debug!("Found encrypted file"); - } + parsed_metadata = if let Some(reader) = reader.as_mut() { + debug!("trying to read metadata"); + dbg!(reader.as_mut().metadata().await.ok()) + } else { + debug!("Failed to read encrypted data"); + None + }; } // parsed_metadata is either set or unset here. If it's set then we // successfully decoded the data; otherwise the file is garbage. if let Some(reader) = reader { - let stream = InnerStream::Completed(FramedRead::new(reader, BytesCodec::new())); + let stream = + InnerStream::Completed(FramedRead::new(reader as Pin>, BytesCodec::new())); parsed_metadata.map(|metadata| Ok((stream, maybe_header, metadata))) } else { debug!("Reader was invalid, file is corrupt"); @@ -144,7 +155,11 @@ async fn read_file( struct EncryptedDiskReader { file: Pin>, stream: SecretStream, - buf: Vec, + // Bytes we read from the secret stream + read_buffer: Box<[u8; 4096]>, + decryption_buffer: Vec, + // Bytes we write out to the read buf + write_buffer: Vec, } impl EncryptedDiskReader { @@ -152,51 +167,148 @@ impl EncryptedDiskReader { Self { file: Box::pin(file), stream, - buf: vec![], + read_buffer: Box::new([0; 4096]), + decryption_buffer: Vec::with_capacity(4096), + write_buffer: Vec::with_capacity(4096), } } } +#[async_trait] +pub trait MetadataFetch: AsyncBufRead { + async fn metadata(mut self: Pin<&mut Self>) -> Result; +} + +#[async_trait] +impl MetadataFetch for R { + #[inline] + async fn metadata(mut self: Pin<&mut Self>) -> Result { + MetadataFuture(&mut self).await + } +} + impl AsyncRead for EncryptedDiskReader { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let cursor_start = buf.filled().len(); + // First, try and read from the underlying file. + let pinned_self = Pin::into_inner(self); + let mut read_buf = ReadBuf::new(pinned_self.read_buffer.as_mut()); + let read_res = pinned_self.file.as_mut().poll_read(cx, &mut read_buf); - let res = self.as_mut().file.as_mut().poll_read(cx, buf); - if res.is_pending() { + // If the file + if read_res.is_pending() { return Poll::Pending; } - let cursor_new = buf.filled().len(); - - // pull_to_vec internally calls vec.clear() and vec.reserve(). Generally - // speaking we should be reading about the same amount of data each time - // so we shouldn't experience too much of a slow down w.r.t resizing the - // buffer... - let new_self = Pin::into_inner(self); - new_self + if pinned_self .stream - .pull_to_vec( - &buf.filled()[cursor_start..cursor_new], - None, - &mut new_self.buf, - ) - .unwrap(); - - // data is strictly smaller than the encrypted stream, since you need to - // encode tags as well, so this is always safe. - - // rewrite encrypted data into decrypted data - let buffer = buf.filled_mut(); - for (old, new) in buffer[cursor_start..].iter_mut().zip(&new_self.buf) { - *old = *new; + .pull_to_vec(read_buf.filled(), None, &mut pinned_self.decryption_buffer) + .is_err() + { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to decrypt data", + ))); } - buf.set_filled(cursor_start + new_self.buf.len()); - res + pinned_self + .write_buffer + .extend_from_slice(&pinned_self.decryption_buffer); + + // find the amount of bytes we can put into the output buffer. + let bytes_to_write = buf.remaining().min(pinned_self.write_buffer.len()); + buf.put_slice( + &pinned_self + .write_buffer + .drain(..bytes_to_write) + .collect::>(), + ); + + Poll::Ready(Ok(())) + } +} + +impl AsyncBufRead for EncryptedDiskReader { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // First, try and read from the underlying file. + let pinned_self = Pin::into_inner(self); + let mut read_buf = ReadBuf::new(pinned_self.read_buffer.as_mut()); + let read_res = pinned_self.file.as_mut().poll_read(cx, &mut read_buf); + + // If the file + if read_res.is_pending() { + return Poll::Pending; + } + + dbg!(&read_buf.filled().len()); + if pinned_self + .stream + .pull_to_vec(read_buf.filled(), None, &mut pinned_self.decryption_buffer) + .is_err() + { + dbg!(line!()); + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to decrypt data", + ))); + } + + pinned_self + .write_buffer + .extend_from_slice(&pinned_self.decryption_buffer); + + Poll::Ready(Ok(&pinned_self.write_buffer)) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + self.as_mut().write_buffer.drain(..amt); + } +} + +struct MetadataFuture<'a, R>(&'a mut Pin<&'a mut R>); + +impl<'a, R: AsyncBufRead> Future for MetadataFuture<'a, R> { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut filled = 0; + loop { + let buf = match self.0.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(buffer)) => buffer, + Poll::Ready(Err(e)) => { + dbg!(e); + return Poll::Ready(Err(())); + } + Poll::Pending => return Poll::Pending, + }; + + if filled == buf.len() { + dbg!(line!()); + return Poll::Ready(Err(())); + } else { + filled = buf.len(); + } + + let mut reader = serde_json::Deserializer::from_slice(buf).into_iter(); + let (res, bytes_consumed) = match reader.next() { + Some(Ok(metadata)) => (Poll::Ready(Ok(metadata)), reader.byte_offset()), + Some(Err(e)) if e.is_eof() => { + continue; + } + Some(Err(_)) | None => { + dbg!(line!()); + return Poll::Ready(Err(())); + } + }; + + // This needs to be outside the loop because we need to drop the + // reader ref, since that depends on a mut self. + self.0.as_mut().consume(bytes_consumed); + return res; + } } } @@ -216,35 +328,27 @@ where Fut: 'static + Send + Sync + Future, DbCallback: 'static + Send + Sync + FnOnce(u64) -> Fut, { - let file = { + let mut file = { let parent = path.parent().expect("The path to have a parent"); create_dir_all(parent).await?; let file = File::create(path).await?; // we need to make sure the file exists and is truncated. file }; + let mut writer: Pin> = if let Some((enc, header)) = ENCRYPTION_KEY + .get() + .map(|key| SecretStream::init_push(key).expect("Failed to init enc stream")) + { + file.write_all(dbg!(header.as_ref())).await?; + Box::pin(EncryptedDiskWriter::new(file, enc)) + } else { + Box::pin(file) + }; + let metadata_string = serde_json::to_string(&metadata).expect("serialization to work"); let metadata_size = metadata_string.len(); - let (mut writer, maybe_header): (Pin>, _) = - if let Some((enc, header)) = ENCRYPTION_KEY - .get() - .map(|key| SecretStream::init_push(key).expect("Failed to init enc stream")) - { - (Box::pin(EncryptedDiskWriter::new(file, enc)), Some(header)) - } else { - (Box::pin(file), None) - }; - - let mut error = if let Some(header) = maybe_header { - writer.write_all(header.as_ref()).await.err() - } else { - None - }; - - if error.is_none() { - error = writer.write_all(metadata_string.as_bytes()).await.err(); - } + let mut error = writer.write_all(metadata_string.as_bytes()).await.err(); if error.is_none() { error = writer.write_all(&bytes).await.err(); } @@ -300,13 +404,11 @@ impl AsyncWrite for EncryptedDiskWriter { buf: &[u8], ) -> Poll> { let new_self = Pin::into_inner(self); - { - let encryption_buffer = &mut new_self.encryption_buffer; - if let Some(stream) = new_self.stream.as_mut() { - stream - .push_to_vec(buf, None, Tag::Message, encryption_buffer) - .expect("Failed to write encrypted data to buffer"); - } + + if let Some(stream) = new_self.stream.as_mut() { + stream + .push_to_vec(buf, None, Tag::Message, &mut new_self.encryption_buffer) + .expect("Failed to write encrypted data to buffer"); } new_self.write_buffer.extend(&new_self.encryption_buffer); @@ -353,10 +455,26 @@ impl AsyncWrite for EncryptedDiskWriter { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - self.as_mut() + let maybe_bytes = self + .as_mut() .stream .take() .map(|stream| stream.finalize(None)); + + // If we've yet to finalize the stream, finalize it and add the bytes to + // our writer buffer. + if let Some(Ok(bytes)) = maybe_bytes { + // We just need to push it into the buffer, we don't really care + // about the result, since we can check later + let _ = self.as_mut().poll_write(cx, &bytes); + } + + // Now wait for us to fully flush out our write buffer + if !self.write_buffer.is_empty() { + return self.poll_flush(cx); + } + + // Write buffer is empty, flush file self.file.as_mut().poll_shutdown(cx) } } diff --git a/src/cache/mod.rs b/src/cache/mod.rs index be37b54..36ee35c 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -15,7 +15,6 @@ use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use sodiumoxide::crypto::secretstream::{Header, Key, Pull, Stream as SecretStream}; use thiserror::Error; -use tokio::io::AsyncRead; use tokio::sync::mpsc::Sender; use tokio_util::codec::{BytesCodec, FramedRead}; @@ -24,6 +23,7 @@ pub use fs::UpstreamError; pub use mem::MemoryCache; use self::compat::LegacyImageMetadata; +use self::fs::MetadataFetch; pub static ENCRYPTION_KEY: OnceCell = OnceCell::new(); @@ -277,7 +277,7 @@ impl Stream for CacheStream { pub(self) enum InnerStream { Memory(MemStream), - Completed(FramedRead>, BytesCodec>), + Completed(FramedRead>, BytesCodec>), } impl From for InnerStream {