Partial rewrite of encrypted writer

This commit is contained in:
Edward Shen 2021-07-13 13:16:44 -04:00
parent 160f369a72
commit 2ace8d3d66
Signed by: edward
GPG key ID: 19182661E818369F
3 changed files with 196 additions and 75 deletions

7
src/cache/disk.rs vendored
View file

@ -18,12 +18,13 @@ use sqlx::{ConnectOptions, Sqlite, SqlitePool, Transaction};
use tokio::fs::remove_file;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, warn};
use tracing::{debug, error, instrument, warn};
use crate::units::Bytes;
use super::{Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata};
#[derive(Debug)]
pub struct DiskCache {
disk_path: PathBuf,
disk_cur_size: AtomicU64,
@ -215,6 +216,7 @@ async fn db_listener(
}
}
#[instrument(level = "debug", skip(transaction))]
async fn handle_db_get(entry: &Path, transaction: &mut Transaction<'_, Sqlite>) {
let hash = if let Ok(hash) = Md5Hash::try_from(entry) {
hash
@ -242,6 +244,7 @@ async fn handle_db_get(entry: &Path, transaction: &mut Transaction<'_, Sqlite>)
}
}
#[instrument(level = "debug", skip(transaction, cache))]
async fn handle_db_put(
entry: &Path,
size: u64,
@ -264,7 +267,7 @@ async fn handle_db_put(
.await;
if let Err(e) = query {
warn!("Failed to add {:?} to db: {}", key, e);
warn!("Failed to add to db: {}", e);
}
cache.disk_cur_size.fetch_add(size, Ordering::Release);

260
src/cache/fs.rs vendored
View file

@ -22,6 +22,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use actix_web::error::PayloadError;
use async_trait::async_trait;
use bytes::Bytes;
use futures::Future;
use serde::{Deserialize, Serialize};
@ -29,10 +30,13 @@ use sodiumoxide::crypto::secretstream::{
Header, Pull, Push, Stream as SecretStream, Tag, HEADERBYTES,
};
use tokio::fs::{create_dir_all, remove_file, File};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::io::{
AsyncBufRead, AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufReader,
ReadBuf,
};
use tokio::sync::mpsc::Sender;
use tokio_util::codec::{BytesCodec, FramedRead};
use tracing::{debug, warn};
use tracing::{debug, instrument, warn};
use super::compat::LegacyImageMetadata;
use super::{CacheKey, ImageMetadata, InnerStream, ENCRYPTION_KEY};
@ -53,6 +57,7 @@ pub(super) async fn read_file_from_path(
read_file(std::fs::File::open(path).ok()?).await
}
#[instrument(level = "debug")]
async fn read_file(
file: std::fs::File,
) -> Option<Result<(InnerStream, Option<Header>, ImageMetadata), std::io::Error>> {
@ -73,7 +78,7 @@ async fn read_file(
let parsed_metadata;
let mut maybe_header = None;
let mut reader: Option<Pin<Box<dyn AsyncRead + Send>>> = None;
let mut reader: Option<Pin<Box<dyn MetadataFetch + Send>>> = None;
if let Ok(metadata) = maybe_metadata {
// image is decrypted
if ENCRYPTION_KEY.get().is_some() {
@ -83,13 +88,13 @@ async fn read_file(
return None;
}
reader = Some(Box::pin(File::from_std(file_1)));
reader = Some(Box::pin(BufReader::new(File::from_std(file_1))));
parsed_metadata = Some(metadata);
debug!("Found not encrypted file");
} else {
debug!("metadata read failed, trying to see if it's encrypted");
let mut file = File::from_std(file_1);
file.seek(SeekFrom::Start(0)).await.ok()?;
let file_0 = file.try_clone().await.unwrap();
// image is encrypted or corrupt
@ -102,6 +107,10 @@ async fn read_file(
return None;
}
dbg!(&header_bytes);
debug!("header bytes: {:x?}", header_bytes);
let file_header = if let Some(header) = Header::from_slice(&header_bytes) {
header
} else {
@ -110,6 +119,7 @@ async fn read_file(
};
let secret_stream = if let Ok(stream) = SecretStream::init_pull(&file_header, key) {
debug!("Valid header found!");
stream
} else {
warn!("Failed to init secret stream with key and header. Assuming corrupted!");
@ -117,23 +127,24 @@ async fn read_file(
};
maybe_header = Some(file_header);
reader = Some(Box::pin(EncryptedDiskReader::new(file, secret_stream)));
}
let mut deserializer = serde_json::Deserializer::from_reader(file_0.into_std().await);
parsed_metadata = ImageMetadata::deserialize(&mut deserializer).ok();
if parsed_metadata.is_some() {
debug!("Found encrypted file");
}
parsed_metadata = if let Some(reader) = reader.as_mut() {
debug!("trying to read metadata");
dbg!(reader.as_mut().metadata().await.ok())
} else {
debug!("Failed to read encrypted data");
None
};
}
// parsed_metadata is either set or unset here. If it's set then we
// successfully decoded the data; otherwise the file is garbage.
if let Some(reader) = reader {
let stream = InnerStream::Completed(FramedRead::new(reader, BytesCodec::new()));
let stream =
InnerStream::Completed(FramedRead::new(reader as Pin<Box<_>>, BytesCodec::new()));
parsed_metadata.map(|metadata| Ok((stream, maybe_header, metadata)))
} else {
debug!("Reader was invalid, file is corrupt");
@ -144,7 +155,11 @@ async fn read_file(
struct EncryptedDiskReader {
file: Pin<Box<File>>,
stream: SecretStream<Pull>,
buf: Vec<u8>,
// Bytes we read from the secret stream
read_buffer: Box<[u8; 4096]>,
decryption_buffer: Vec<u8>,
// Bytes we write out to the read buf
write_buffer: Vec<u8>,
}
impl EncryptedDiskReader {
@ -152,51 +167,148 @@ impl EncryptedDiskReader {
Self {
file: Box::pin(file),
stream,
buf: vec![],
read_buffer: Box::new([0; 4096]),
decryption_buffer: Vec::with_capacity(4096),
write_buffer: Vec::with_capacity(4096),
}
}
}
/// A buffered async reader from which an [`ImageMetadata`] value can be read.
///
/// A blanket implementation in this file provides this trait for every
/// `AsyncBufRead + Send` type.
#[async_trait]
pub trait MetadataFetch: AsyncBufRead {
/// Attempts to read an [`ImageMetadata`] from this reader.
///
/// Returns `Err(())` on failure; the error carries no further detail.
async fn metadata(mut self: Pin<&mut Self>) -> Result<ImageMetadata, ()>;
}
// Blanket implementation: every `AsyncBufRead + Send` reader can have its
// metadata fetched.
#[async_trait]
impl<R: AsyncBufRead + Send> MetadataFetch for R {
/// Delegates to `MetadataFuture`, which parses the metadata as JSON out of
/// the reader's buffered data and consumes the bytes it parsed.
#[inline]
async fn metadata(mut self: Pin<&mut Self>) -> Result<ImageMetadata, ()> {
MetadataFuture(&mut self).await
}
}
impl AsyncRead for EncryptedDiskReader {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let cursor_start = buf.filled().len();
// First, try and read from the underlying file.
let pinned_self = Pin::into_inner(self);
let mut read_buf = ReadBuf::new(pinned_self.read_buffer.as_mut());
let read_res = pinned_self.file.as_mut().poll_read(cx, &mut read_buf);
let res = self.as_mut().file.as_mut().poll_read(cx, buf);
if res.is_pending() {
// If the file
if read_res.is_pending() {
return Poll::Pending;
}
let cursor_new = buf.filled().len();
// pull_to_vec internally calls vec.clear() and vec.reserve(). Generally
// speaking we should be reading about the same amount of data each time
// so we shouldn't experience too much of a slow down w.r.t resizing the
// buffer...
let new_self = Pin::into_inner(self);
new_self
if pinned_self
.stream
.pull_to_vec(
&buf.filled()[cursor_start..cursor_new],
None,
&mut new_self.buf,
)
.unwrap();
// data is strictly smaller than the encrypted stream, since you need to
// encode tags as well, so this is always safe.
// rewrite encrypted data into decrypted data
let buffer = buf.filled_mut();
for (old, new) in buffer[cursor_start..].iter_mut().zip(&new_self.buf) {
*old = *new;
.pull_to_vec(read_buf.filled(), None, &mut pinned_self.decryption_buffer)
.is_err()
{
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to decrypt data",
)));
}
buf.set_filled(cursor_start + new_self.buf.len());
res
pinned_self
.write_buffer
.extend_from_slice(&pinned_self.decryption_buffer);
// find the amount of bytes we can put into the output buffer.
let bytes_to_write = buf.remaining().min(pinned_self.write_buffer.len());
buf.put_slice(
&pinned_self
.write_buffer
.drain(..bytes_to_write)
.collect::<Vec<_>>(),
);
Poll::Ready(Ok(()))
}
}
/// Buffered access to the decrypted contents of an encrypted cache file.
///
/// Decrypted bytes accumulate in `write_buffer` until they are released with
/// `consume`.
impl AsyncBufRead for EncryptedDiskReader {
    /// Reads the next chunk of ciphertext from the file, decrypts it, appends
    /// the plaintext to the internal buffer and returns the whole buffer.
    ///
    /// Unlike a typical `AsyncBufRead`, every call attempts another read from
    /// the file even when unconsumed data is still buffered, so the returned
    /// slice grows across calls until `consume` is invoked.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from the underlying file, and returns an
    /// `ErrorKind::Other` error if the chunk fails to decrypt.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
        let this = Pin::into_inner(self);

        // First, try and read from the underlying file.
        let mut read_buf = ReadBuf::new(this.read_buffer.as_mut());
        match this.file.as_mut().poll_read(cx, &mut read_buf) {
            // The file isn't ready yet; its waker has been registered.
            Poll::Pending => return Poll::Pending,
            // Surface the real I/O error instead of letting it masquerade as
            // a decryption failure further down.
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {}
        }

        let ciphertext = read_buf.filled();

        // A zero-length read means EOF: there is nothing to decrypt, so just
        // hand back whatever is still buffered (empty once fully consumed).
        if !ciphertext.is_empty() {
            // NOTE(review): this assumes each file read yields exactly one
            // complete secretstream message — confirm against the writer's
            // framing.
            if this
                .stream
                .pull_to_vec(ciphertext, None, &mut this.decryption_buffer)
                .is_err()
            {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "Failed to decrypt data",
                )));
            }

            this.write_buffer.extend_from_slice(&this.decryption_buffer);
        }

        Poll::Ready(Ok(this.write_buffer.as_slice()))
    }

    /// Discards the first `amt` bytes of buffered plaintext.
    ///
    /// `amt` is clamped to the number of buffered bytes so an oversized value
    /// cannot panic.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = Pin::into_inner(self);
        let amt = amt.min(this.write_buffer.len());
        this.write_buffer.drain(..amt);
    }
}
/// Future that parses one JSON-encoded `ImageMetadata` out of the wrapped
/// buffered reader and consumes the bytes that made up the parsed value.
struct MetadataFuture<'a, R>(&'a mut Pin<&'a mut R>);
impl<'a, R: AsyncBufRead> Future for MetadataFuture<'a, R> {
type Output = Result<ImageMetadata, ()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut filled = 0;
loop {
let buf = match self.0.as_mut().poll_fill_buf(cx) {
Poll::Ready(Ok(buffer)) => buffer,
Poll::Ready(Err(e)) => {
dbg!(e);
return Poll::Ready(Err(()));
}
Poll::Pending => return Poll::Pending,
};
if filled == buf.len() {
dbg!(line!());
return Poll::Ready(Err(()));
} else {
filled = buf.len();
}
let mut reader = serde_json::Deserializer::from_slice(buf).into_iter();
let (res, bytes_consumed) = match reader.next() {
Some(Ok(metadata)) => (Poll::Ready(Ok(metadata)), reader.byte_offset()),
Some(Err(e)) if e.is_eof() => {
continue;
}
Some(Err(_)) | None => {
dbg!(line!());
return Poll::Ready(Err(()));
}
};
// This needs to be outside the loop because we need to drop the
// reader ref, since that depends on a mut self.
self.0.as_mut().consume(bytes_consumed);
return res;
}
}
}
@ -216,35 +328,27 @@ where
Fut: 'static + Send + Sync + Future<Output = ()>,
DbCallback: 'static + Send + Sync + FnOnce(u64) -> Fut,
{
let file = {
let mut file = {
let parent = path.parent().expect("The path to have a parent");
create_dir_all(parent).await?;
let file = File::create(path).await?; // we need to make sure the file exists and is truncated.
file
};
let mut writer: Pin<Box<dyn AsyncWrite + Send>> = if let Some((enc, header)) = ENCRYPTION_KEY
.get()
.map(|key| SecretStream::init_push(key).expect("Failed to init enc stream"))
{
file.write_all(dbg!(header.as_ref())).await?;
Box::pin(EncryptedDiskWriter::new(file, enc))
} else {
Box::pin(file)
};
let metadata_string = serde_json::to_string(&metadata).expect("serialization to work");
let metadata_size = metadata_string.len();
let (mut writer, maybe_header): (Pin<Box<dyn AsyncWrite + Send>>, _) =
if let Some((enc, header)) = ENCRYPTION_KEY
.get()
.map(|key| SecretStream::init_push(key).expect("Failed to init enc stream"))
{
(Box::pin(EncryptedDiskWriter::new(file, enc)), Some(header))
} else {
(Box::pin(file), None)
};
let mut error = if let Some(header) = maybe_header {
writer.write_all(header.as_ref()).await.err()
} else {
None
};
if error.is_none() {
error = writer.write_all(metadata_string.as_bytes()).await.err();
}
let mut error = writer.write_all(metadata_string.as_bytes()).await.err();
if error.is_none() {
error = writer.write_all(&bytes).await.err();
}
@ -300,13 +404,11 @@ impl AsyncWrite for EncryptedDiskWriter {
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let new_self = Pin::into_inner(self);
{
let encryption_buffer = &mut new_self.encryption_buffer;
if let Some(stream) = new_self.stream.as_mut() {
stream
.push_to_vec(buf, None, Tag::Message, encryption_buffer)
.expect("Failed to write encrypted data to buffer");
}
if let Some(stream) = new_self.stream.as_mut() {
stream
.push_to_vec(buf, None, Tag::Message, &mut new_self.encryption_buffer)
.expect("Failed to write encrypted data to buffer");
}
new_self.write_buffer.extend(&new_self.encryption_buffer);
@ -353,10 +455,26 @@ impl AsyncWrite for EncryptedDiskWriter {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.as_mut()
let maybe_bytes = self
.as_mut()
.stream
.take()
.map(|stream| stream.finalize(None));
// If we've yet to finalize the stream, finalize it and add the bytes to
// our writer buffer.
if let Some(Ok(bytes)) = maybe_bytes {
// We just need to push it into the buffer, we don't really care
// about the result, since we can check later
let _ = self.as_mut().poll_write(cx, &bytes);
}
// Now wait for us to fully flush out our write buffer
if !self.write_buffer.is_empty() {
return self.poll_flush(cx);
}
// Write buffer is empty, flush file
self.file.as_mut().poll_shutdown(cx)
}
}

4
src/cache/mod.rs vendored
View file

@ -15,7 +15,6 @@ use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use sodiumoxide::crypto::secretstream::{Header, Key, Pull, Stream as SecretStream};
use thiserror::Error;
use tokio::io::AsyncRead;
use tokio::sync::mpsc::Sender;
use tokio_util::codec::{BytesCodec, FramedRead};
@ -24,6 +23,7 @@ pub use fs::UpstreamError;
pub use mem::MemoryCache;
use self::compat::LegacyImageMetadata;
use self::fs::MetadataFetch;
pub static ENCRYPTION_KEY: OnceCell<Key> = OnceCell::new();
@ -277,7 +277,7 @@ impl Stream for CacheStream {
pub(self) enum InnerStream {
Memory(MemStream),
Completed(FramedRead<Pin<Box<dyn AsyncRead + Send>>, BytesCodec>),
Completed(FramedRead<Pin<Box<dyn MetadataFetch + Send>>, BytesCodec>),
}
impl From<CachedImage> for InnerStream {