Compare commits

..

1 commit

Author SHA1 Message Date
ab53a6ac24
Add support for v32 2021-07-09 14:32:43 -04:00
31 changed files with 1705 additions and 4675 deletions

View file

@ -1,10 +0,0 @@
# Ignore everything
*
# Only include necessary paths (This should be synchronized with `Cargo.toml`)
!db_queries/
!src/
!settings.sample.yaml
!sqlx-data.json
!Cargo.toml
!Cargo.lock

View file

@ -1,53 +0,0 @@
name: Build and test
on:
push:
branches: [ master ]
paths-ignore:
- "docs/**"
- "settings.sample.yaml"
- "README.md"
- "LICENSE"
pull_request:
branches: [ master ]
paths-ignore:
- "docs/**"
- "settings.sample.yaml"
- "README.md"
- "LICENSE"
env:
CARGO_TERM_COLOR: always
DATABASE_URL: sqlite:./cache/metadata.sqlite
SQLX_OFFLINE: true
jobs:
clippy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- run: rustup component add clippy
- uses: actions-rs/clippy-check@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
args: --all-features
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose
sqlx-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Install sqlx-cli
run: cargo install sqlx-cli
- name: Initialize database
run: mkdir -p cache && sqlite3 cache/metadata.sqlite < db_queries/init.sql
- name: Check sqlx statements
run: cargo sqlx prepare --check

View file

@ -1,22 +0,0 @@
name: coverage
on: [push]
jobs:
test:
name: coverage
runs-on: ubuntu-latest
container:
image: xd009642/tarpaulin:develop-nightly
options: --security-opt seccomp=unconfined
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Generate code coverage
run: |
cargo +nightly tarpaulin --verbose --all-features --workspace --timeout 120 --avoid-cfg-tarpaulin --out Xml
- name: Upload to codecov.io
uses: codecov/codecov-action@v1
with:
fail_ci_if_error: true

View file

@ -1,14 +0,0 @@
name: Security audit
on:
push:
paths:
- '**/Cargo.toml'
- '**/Cargo.lock'
jobs:
security_audit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/audit-check@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}

5
.gitignore vendored
View file

@ -3,7 +3,4 @@
/cache /cache
flamegraph*.svg flamegraph*.svg
perf.data* perf.data*
dhat.out.* dhat.out.*
settings.yaml
tarpaulin-report.html
GeoLite2-Country.mmdb

1574
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,69 +1,50 @@
[package] [package]
name = "mangadex-home" name = "mangadex-home"
version = "0.5.4" version = "0.3.0"
license = "GPL-3.0-or-later" license = "GPL-3.0-or-later"
authors = ["Edward Shen <code@eddie.sh>"] authors = ["Edward Shen <code@eddie.sh>"]
edition = "2018" edition = "2018"
include = [ include = ["src/**/*", "db_queries", "LICENSE", "README.md"]
"src/**/*",
"db_queries",
"LICENSE",
"README.md",
"sqlx-data.json",
"settings.sample.yaml"
]
description = "A MangaDex@Home implementation in Rust." description = "A MangaDex@Home implementation in Rust."
repository = "https://github.com/edward-shen/mangadex-home-rs" repository = "https://github.com/edward-shen/mangadex-home-rs"
[profile.release] [profile.release]
lto = true lto = true
codegen-units = 1 codegen-units = 1
debug = 1 # debug = 1
[dependencies] [dependencies]
# Pin because we're using unstable versions actix-web = { version = "4.0.0-beta.4", features = [ "rustls" ] }
actix-web = { version = "4", features = [ "rustls" ] }
arc-swap = "1" arc-swap = "1"
async-trait = "0.1" async-trait = "0.1"
base64 = "0.13" base64 = "0.13"
bincode = "1" bincode = "1"
bytes = { version = "1", features = [ "serde" ] } bytes = "1"
chacha20 = "0.7"
chrono = { version = "0.4", features = [ "serde" ] } chrono = { version = "0.4", features = [ "serde" ] }
clap = { version = "3", features = [ "wrap_help", "derive", "cargo" ] } clap = { version = "3.0.0-beta.2", features = [ "wrap_help" ] }
ctrlc = "3" ctrlc = "3"
dotenv = "0.15" dotenv = "0.15"
flate2 = { version = "1", features = [ "tokio" ] }
futures = "0.3" futures = "0.3"
once_cell = "1" once_cell = "1"
log = { version = "0.4", features = [ "serde" ] } log = "0.4"
lfu_cache = "1" lfu_cache = "1"
lru = "0.7" lru = "0.6"
maxminddb = "0.20"
md-5 = "0.9"
parking_lot = "0.11" parking_lot = "0.11"
prometheus = { version = "0.12", features = [ "process" ] } prometheus = { version = "0.12", features = [ "process" ] }
redis = "0.21"
reqwest = { version = "0.11", default_features = false, features = [ "json", "stream", "rustls-tls" ] } reqwest = { version = "0.11", default_features = false, features = [ "json", "stream", "rustls-tls" ] }
rustls = "0.20" ring = "*"
rustls-pemfile = "0.2" rustls = "0.19"
serde = "1" serde = "1"
serde_json = "1" serde_json = "1"
serde_repr = "0.1" serde_repr = "0.1"
serde_yaml = "0.8" simple_logger = "1"
sodiumoxide = "0.2" sodiumoxide = "0.2"
sqlx = { version = "0.5", features = [ "runtime-actix-rustls", "sqlite", "time", "chrono", "macros", "offline" ] } sqlx = { version = "0.5", features = [ "runtime-actix-rustls", "sqlite", "time", "chrono", "macros" ] }
tar = "0.4"
thiserror = "1" thiserror = "1"
tokio = { version = "1", features = [ "rt-multi-thread", "macros", "fs", "time", "sync", "parking_lot" ] } tokio = { version = "1", features = [ "full", "parking_lot" ] }
tokio-stream = { version = "0.1", features = [ "sync" ] } tokio-stream = { version = "0.1", features = [ "sync" ] }
tokio-util = { version = "0.6", features = [ "codec" ] } tokio-util = { version = "0.6", features = [ "codec" ] }
tracing = "0.1"
tracing-subscriber = { version = "0.2", features = [ "parking_lot" ] }
url = { version = "2", features = [ "serde" ] } url = { version = "2", features = [ "serde" ] }
[build-dependencies] [build-dependencies]
vergen = "5" vergen = "5"
[dev-dependencies]
tempfile = "3"

View file

@ -1,11 +0,0 @@
# syntax=docker/dockerfile:1
FROM rust:alpine as builder
COPY . .
RUN apk add --no-cache file make musl-dev \
&& cargo install --path . \
&& strip /usr/local/cargo/bin/mangadex-home
FROM alpine:latest
COPY --from=builder /usr/local/cargo/bin/mangadex-home /usr/local/bin/mangadex-home
CMD ["mangadex-home"]

107
README.md
View file

@ -2,86 +2,75 @@ A Rust implementation of a MangaDex@Home client.
This client contains the following features: This client contains the following features:
- Easy migration from the official client - Multi-threaded
- Fully compliant with MangaDex@Home specifications - HTTP/2 support
- Multi-threaded, high performance, and low overhead client - No support for TLS 1.1 or 1.0
- HTTP/2 support for API users, HTTP/2 only for upstream connections
- Secure and privacy oriented features:
- Only supports TLS 1.2 or newer; HTTP is not enabled by default
- Options for no logging and no metrics
- Support for on-disk XChaCha20 encryption with ephemeral key generation
- Supports an internal LFU, LRU, or a redis instance for in-memory caching
## Building ## Building
Since we use SQLx there are a few things you'll need to do. First, you'll need
to run the init cache script, which initializes the db cache at
`./cache/metadata.sqlite`. Then you'll need to add the location of that to a
`.env` file:
```sh ```sh
# In the project root
./init_cache.sh
echo "DATABASE_URL=sqlite:./cache/metadata.sqlite" >> .env
cargo build cargo build
cargo test
``` ```
You may need to set a client secret, see Configuration for more information. ## Cache implementation
# Migration This client implements a multi-tier in-memory and on-disk LRU cache constrained
by quotas. In essence, it acts as an unified LRU, where in-memory items are
evicted and pushed into the on-disk LRU and fetching a item from the on-disk LRU
promotes it to the in-memory LRU.
Migration from the official client was made to be as painless as possible. There Note that the capacity of each LRU is dynamic, depending on the maximum byte
are caveats though: capacity that you permit each cache to be. A large item may evict multiple
- If you ever want to return to using the official client, you will need to smaller items to fit within this constraint, for example.
clear your cache.
- As this is an unofficial client implementation, the only support you can
probably get is from me.
Otherwise, the steps to migration is easy: Note that these quotas are closer to a rough estimate, and is not guaranteed to
1. Place the binary in the same folder as your `images` folder and be strictly below these values, so it's recommended to under set your config
`settings.yaml`. values to make sure you don't exceed the actual quota.
2. Rename `images` to `cache`.
# Client implementation
This client follows a secure-first approach. As such, your statistics may report
a _ever-so-slightly_ higher-than-average failure rate. Specifically, this client
choses to:
- Not support TLS 1.1 or 1.0, which would be a primary source of
incompatibility.
- Not provide a server identification string in the header of served requests.
- HTTPS by enabled by default, HTTP is provided (and unsupported).
That being said, this client should be backwards compatibility with the official
client data and config. That means you should be able to replace the binary and
preserve all your settings and cache.
## Installation ## Installation
Either build it from source or run `cargo install mangadex-home`. Either build it from source or run `cargo install mangadex-home`.
## Running
Run `mangadex-home`, and make sure the advertised port is open on your firewall.
Do note that some configuration fields are required. See the next section for
details.
## Configuration ## Configuration
Most configuration options can be either provided on the command line or sourced Most configuration options can be either provided on the command line, sourced
from a file named `settings.yaml` from the directory you ran the command from, from a `.env` file, or sourced directly from the environment. Do not that the
which will be created on first run. client secret is an exception. You must provide the client secret from the
environment or from the `.env` file, as providing client secrets in a shell is a
operation security risk.
Note that the client secret (`CLIENT_SECRET`) is the only configuration option The following options are required:
that can only can be provided from the environment, an `.env` file, or the
`settings.yaml` file. In other words, you _cannot_ provide this value from the
command line.
## Special thanks - Client Secret
- Memory cache quota
- Disk cache quota
- Advertised network speed
This project could not have been completed without the assistance of the The following are optional as a default value will be set for you:
following:
#### Development Assistance (Alphabetical Order) - Port
- Disk cache path
- carbotaniuman#6974 ### Advanced configuration
- LFlair#1337
- Plykiya#1738
- Tristan 9#6752
- The Rust Discord community
#### Beta testers This implementation prefers to act more secure by default. As a result, some
features that the official specification requires are not enabled by default.
If you don't know why these features are disabled by default, then don't enable
these, as they may generally weaken the security stance of the client for more
compatibility.
- NigelVH#7162 - Sending Server version string
---
If using the geo IP logging feature, then this product includes GeoLite2 data
created by MaxMind, available from https://www.maxmind.com.

View file

@ -3,10 +3,8 @@ use std::error::Error;
use vergen::{vergen, Config, ShaKind}; use vergen::{vergen, Config, ShaKind};
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
// Initialize vergen stuff
let mut config = Config::default(); let mut config = Config::default();
*config.git_mut().sha_kind_mut() = ShaKind::Short; *config.git_mut().sha_kind_mut() = ShaKind::Short;
vergen(config)?; vergen(config)?;
Ok(()) Ok(())
} }

View file

@ -1 +0,0 @@
insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing

View file

@ -1,9 +0,0 @@
version: "3.9"
services:
mangadex-home:
build: .
ports:
- "443:443"
volumes:
- ./cache:/cache
- ./settings.yaml:/settings.yaml

View file

@ -1,14 +0,0 @@
# Ciphers
This client relies on rustls, which only supports a subset of TLS ciphers.
Specifically, only TLS 1.2 ECDSA GCM ciphers as well as all TLS 1.3 ciphers are
supported. This means that clients that only support older, more insecure
ciphers may not be able to connect to this client.
In practice, this means this client's failure rate may be higher than expected.
This is okay, and within specifications.
## Why even bother?
Well, Australia has officially banned hentai... so I gotta make sure my mates
over there won't get in trouble if I'm connecting to them.

View file

@ -1,14 +0,0 @@
# Unstable Options
Unstable options are options that are either experimental, dangerous, for
development only, or a mix of the three. The following table describes each
option. Generally speaking, you should never need to enable these unless you
know what you're doing.
| Option | Experimental? | Dangerous? | For development? |
| -------------------------- | ------------- | ---------- | ---------------- |
| `override-upstream` | | | Yes |
| `use-lfu` | Yes | | |
| `disable-token-validation` | | Yes | Yes |
| `offline-mode` | | | Yes |
| `disable-tls` | | Yes | Yes |

10
init_cache.sh Executable file
View file

@ -0,0 +1,10 @@
#!/usr/bin/env bash
# This script needs to be run once in order for compile time macros to not
# complain about a missing DB
# We can trust that our program will initialize the db at runtime the same way
# as it pulls from the same file for initialization
mkdir cache
sqlite3 cache/metadata.sqlite < db_queries/init.sql

View file

@ -1,114 +0,0 @@
---
# ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⢼⣈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⠀⠀⠀⠀⠀⣀⣀⣀⣀⡀⠀⠀⠀⢸⡧⣀⠀⣄⡠⠒⠉⠀⠀⠀⢀⡈⢑⢦⠀⠀⠀⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⣴⡗⠒⠉⠉⠉⠉⠉⠀⠀⠀⠀⠈⠉⠉⠉⠑⠣⡀⠉⠚⡷⡆⠀⣀⣀⣀⠤⡺⠈⠫⠗⠢⡄⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠔⠊⠁⢰⠃⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢰⠢⣧⠘⣎⠐⠠⠟⠋⠀⠀⠀⠀⢄⢸⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠴⠋⠀⠀⠀⠀⡜⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⡆⡏⠀⠸⡄⢀⠀⠀⠪⠴⠒⠒⠚⠘⢤⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠔⠁⠀⠀⠀⠀⠀⠀⡇⠀⠀⢣⡀⢀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⢱⠃⠀⠀⡇⡇⠉⠖⡤⠤⠤⠤⠴⢒⠺⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠋⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠱⡼⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⣆⡎⠀⠀⢀⡇⠙⠦⣌⣹⣏⠉⠉⠁⣀⠠⣂⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢧⠀⠀⠀⣇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⢻⣿⣿⣿⡟⠓⠤⣀⡀⠉⠓⠭⠭⠔⠒⠉⠈⡆⠀⠀
#⠀⠀⠀⠀⠀⠀⢀⣀⣀⣀⣀⣀⡸⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⣠⠴⢻⠀⠀⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣧⢸⡈⠉⠉⢣⠀⠀⠀⠉⠑⠢⢄⣀⣀⡤⠖⠋⠳⡀⠀
#⠀⠀⠀⠀⠀⡔⠁⠀⢇⠀⠀⠀⠈⠳⡀⠀⠀⢀⠂⠀⠀⠀⠀⢀⠃⢠⠏⠀⠀⠈⡆⢠⢷⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⣿⡎⡇⠀⡠⠈⡇⠀⠀⠀⠀⠀⠀⠈⣦⠃⠀⠀⠈⢦⠀
#⠀⠀⠀⠀⡰⠃⠀⠀⠘⡄⠀⠀⠀⠀⢱⠀⠀⡜⠀⠀⠀⠀⠀⢸⠐⡮⢀⠀⠀⠀⢱⡸⡄⢧⠀⠀⠀⡀⠀⠀⠀⢸⣇⡿⢳⡧⠊⠀⠀⡇⡇⠀⠆⠀⠀⠀⠀⢱⡀⡠⠂⠀⠈⡇
#⠀⠀⠀⢰⠁⠀⠀⡀⠀⠘⡄⠀⠀⠀⢸⠀⢠⠃⠀⠀⢀⠀⠀⣼⢠⠃⠀⠁⠢⢄⠀⠳⣇⠀⠣⡀⠀⢣⡀⠀⠀⢸⢹⡧⢺⠀⠀⠀⠀⡷⢹⠀⢠⠀⠀⠀⠀⠈⡏⠳⡄⠀⠀⢳
#⠀⠀⠀⢸⠀⠀⠀⠈⠢⢄⠈⣢⠔⠒⠙⠒⢼⠀⠀⠀⢸⠀⢀⠿⣸⠀⠀⠀⠀⠀⠉⠢⢌⠀⠀⠈⠉⠒⠯⠉⠒⠚⠚⣠⠾⠶⠿⠷⢶⣧⡈⢆⢸⠀⠀⠀⠀⠀⢣⠀⢱⠀⠀⡎
#⠀⠀⠀⢸⠀⠀⠸⠀⠀⠀⢹⠁⠀⠀⠀⡀⡞⠀⠀⠀⢸⠀⢸⠀⢿⠀⣠⣴⠾⠛⠛⠓⠢⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠁⡀⠘⣿⣶⣄⠈⢻⣆⢻⠀⡆⠀⠀⠀⢸⠠⣸⠂⢠⠃
#⠀⠀⠀⠘⡄⠀⠀⢡⠀⠀⡼⠀⡠⠐⠁⠀⡇⠀⠀⠀⠈⡆⢸⠀⢨⡾⠋⠀⠀⢻⣿⣿⣷⣦⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⡿⢻⣿⣿⠻⣇⠀⢻⣾⠀⡇⠀⠀⠀⠈⡞⠁⣠⠋⠀
#⠀⠀⠀⠀⠱⡀⠀⠀⠑⢄⡑⣅⠀⠀⠀⠀⡇⠀⠀⠀⠀⠘⣼⠀⣿⠁⠀⢠⡷⢾⣿⣿⡟⠛⡇⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠈⠛⠁⠀⢸⠀⠈⢹⡸⠀⠀⠀⠀⠀⡧⠚⡇⠀⠀
#⠀⠀⠀⠀⠀⠈⠢⢄⣀⠀⠈⠉⢑⣶⠴⠒⡇⠀⠀⠀⠀⠀⡟⠧⡇⠀⠀⠸⡁⠀⠙⠋⠀⠀⡞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⠦⣤⣤⡴⠃⠀⠀⠼⠣⡀⠀⡇⠀⠀⡷⢄⣇⠀⠀
#⠀⠀⠀⠀⠀⠀⢀⠞⠀⡏⠉⢉⣵⣳⠀⠀⡇⠀⠀⠀⠀⠀⢱⠀⠁⠀⠀⠀⠑⠤⠤⡠⠤⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠠⠡⢁⠀⠀⢱⠀⡇⠀⢠⡇⠀⢻⡀⠀
#⠀⠀⠀⠀⠀⢠⠎⠀⢸⣇⡔⠉⠀⢹⡀⠀⡇⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠐⡀⢀⠀⠄⡀⠀⠀⠀⠀⠀⠀⢀⣀⣀⣀⣀⣀⣀⣤⠈⠠⠡⠁⠂⠌⠀⢸⠀⡗⠀⢸⠇⠀⢀⡇⠀
#⠀⠀⠀⠀⢠⠃⠀⠀⡎⡏⠀⠀⠀⠀⡇⠀⡇⠀⡆⠀⠀⠀⠘⡄⠀⠀⠈⠌⠠⠂⠌⠐⠀⢀⠎⠉⠒⠉⠉⠉⠉⠙⠛⠧⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⡼⢀⠇⠀⢸⣀⠴⠋⢱⠀
#⠀⠀⠀⢠⠃⠀⠀⢰⠙⢣⡀⠀⠀⣇⡇⠀⢧⠀⡇⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⡿⠀⠀⠀⠀⠀⢀⡼⠃⡜⠀⠀⡏⢱⠀⠐⠈⡇
#⠀⠀⢠⢃⠀⠀⢠⠇⠀⢸⡉⠓⠶⢿⠃⠀⢸⠀⡇⠀⠀⠀⡄⢹⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡴⠁⠀⠀⠀⢀⣴⡟⠁⡰⠁⠀⢰⢧⠈⡆⠀⠇⢇
#⠀⠀⡜⡄⠀⢀⡎⠀⠀⠀⡇⠀⠀⢸⠀⠀⠈⡇⣿⠀⠀⠀⢧⠈⡗⠢⢤⣀⠀⠀⠀⠀⠀⠀⠙⢄⡀⠀⠀⠀⠀⠀⣀⡤⠚⠁⢀⣠⡤⠞⠋⢀⠇⡴⠁⠀⠀⠾⣼⠀⢱⠀⢸⢸
#⠀⠀⡇⡇⠀⡜⠀⠀⠀⠀⡇⠀⠀⣾⠀⠀⠀⢹⡏⡆⠀⠀⢸⢆⠸⡄⠀⠀⠉⢑⣦⣤⡀⠀⠀⠀⠉⠑⠒⣒⣋⣉⣡⣤⠒⠊⠉⡇⠀⠀⠀⣾⣊⠀⠀⠀⠈⢠⢻⠀⢸⣀⣿⡜
#⠀⠀⣷⢇⢸⠁⠀⠀⠀⠀⡇⠀⢰⢹⠀⠀⠀⠀⢿⠹⡀⠀⠸⡀⠳⣵⡀⡠⠚⠉⠙⢿⣿⣷⣦⣀⠀⠀⠀⣱⣿⣿⠀⠈⠉⠲⣄⢧⣠⠒⢌⡇⡠⣃⣀⡠⠔⠁⠀⡇⢸⡟⢸⠇
#⠀⠀⢻⠘⣼⠀⠀⠀⠀⢰⠁⣠⠃⢸⠀⠀⠀⠀⠘⠀⠳⡀⠀⡇⠀⢀⠟⠦⠀⡀⠀⢸⣛⣻⣿⣿⣿⣶⣭⣿⣿⣻⡆⠀⠀⠀⠈⢦⠸⣽⢝⠿⡫⡁⢸⡇⠀⠀⠀⢣⠘⠁⠘⠀
#⠀⠀⠘⠆⠸⠄⠀⠀⢠⠏⡰⠁⠀⡞⠀⠀⠀⠀⠀⠀⠀⠙⢄⣸⣶⣷⣶⣶⣶⣤⣤⣼⣿⣽⣯⣿⣿⣿⣷⣾⣿⣿⣿⣾⣤⣴⣶⣾⣷⣇⠺⠤⠕⠈⢉⠇⠀⠀⠀⠘⡄
#
# MangaDex@Home configuration file
#
# Thanks for contributing to MangaDex@Home, friend!
# Beat up a pineapple, and don't forget your AsaCoco!
#
# Default values are commented out.
# The size in mebibytes of the cache You can use megabytes instead in a pinch,
# but just know the two are **NOT** the same.
max_cache_size_in_mebibytes: 0
server_settings:
# The client secret. Keep this secret at all costs :P
secret: suichan wa kyou mo kawaii!
# The port for the webserver to listen on. 443 is recommended for max appeal.
# port: 443
# This controls the value the server receives for your upload speed.
external_max_kilobits_per_second: 1
#
# Advanced settings
#
# The external hostname to listen on. Keep this at 0.0.0.0 unless you know
# what you're doing.
# hostname: 0.0.0.0
# The external port to broadcast to the backend. Keep this at 0 unless you
# know what you're doing. 0 means broadcast the same value as `port`.
# external_port: 0
# How long to wait at most for the graceful shutdown (Ctrl-C or SIGINT).
# graceful_shutdown_wait_seconds: 60
# The external ip to broadcast to the webserver. The default of null (~) means
# the backend will infer it from where it was sent from, which may fail in the
# presence of multiple IPs.
# external_ip: ~
# Settings for geo IP analytics
metric_settings:
# Whether to enable geo IP analytics
# enable_geoip: false
# geoip_license_key: none
# These settings are unique to the Rust client, and may be ignored or behave
# differently from the official client.
extended_options:
# Which cache type to use. By default, this is `on_disk`, but one can select
# `lfu`, `lru`, or `redis` to use a LFU, LRU, or redis instance in addition
# to the on-disk cache to improve lookup times. Generally speaking, using one
# is almost always better, but by how much depends on how much memory you let
# the node use, how large is your node, and which caching implementation you
# use.
# cache_type: on_disk
# The redis url to connect with. Does nothing if the cache type isn't`redis`.
# redis_url: "redis://127.0.0.1/"
# The amount of memory the client should use when using an in-memory cache.
# This does nothing if only the on-disk cache is used.
# memory_quota: 0
# Whether or not to expose a prometheus endpoint at /metrics. This is a
# completely open endpoint, so best practice is to make sure this is only
# readable from the internal network.
# enable_metrics: false
# If you'd like to specify a different path location for the cache, you can do
# so here.
# cache_path: "./cache"
# What logging level to use. Valid options are "error", "warn", "info",
# "debug", "trace", and "off", which disables logging.
# logging_level: info
# Enables disk encryption where the key is stored in memory. In other words,
# when the MD@H program is stopped, all cached files are irrecoverable.
# Practically speaking, this isn't all too useful (and definitely hurts
# performance), but for peace of mind, this may be useful.
# ephemeral_disk_encryption: false

View file

@ -1,75 +0,0 @@
{
"db": "SQLite",
"24b536161a0ed44d0595052ad069c023631ffcdeadb15a01ee294717f87cdd42": {
"query": "update Images set accessed = ? where id = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 2
},
"nullable": []
}
},
"2a8aa6dd2c59241a451cd73f23547d0e94930e35654692839b5d11bb8b87703e": {
"query": "insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing",
"describe": {
"columns": [],
"parameters": {
"Right": 3
},
"nullable": []
}
},
"311721fec7824c2fc3ecf53f714949a49245c11a6b622efdb04fdac24be41ba3": {
"query": "SELECT IFNULL(SUM(size), 0) AS size FROM Images",
"describe": {
"columns": [
{
"name": "size",
"ordinal": 0,
"type_info": "Int"
}
],
"parameters": {
"Right": 0
},
"nullable": [
true
]
}
},
"44234188e873a467ecf2c60dfb4731011e0b7afc4472339ed2ae33aee8b0c9dd": {
"query": "select id, size from Images order by accessed asc limit 1000",
"describe": {
"columns": [
{
"name": "id",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "size",
"ordinal": 1,
"type_info": "Int64"
}
],
"parameters": {
"Right": 0
},
"nullable": [
false,
false
]
}
},
"a60501a30fd75b2a2a59f089e850343af075436a5c543a267ecb4fa841593ce9": {
"query": "create table if not exists Images(\n id varchar primary key not null,\n size integer not null,\n accessed timestamp not null default CURRENT_TIMESTAMP\n);\ncreate index if not exists Images_accessed on Images(accessed);",
"describe": {
"columns": [],
"parameters": {
"Right": 0
},
"nullable": []
}
}
}

152
src/cache/compat.rs vendored
View file

@ -1,152 +0,0 @@
//! These structs have alternative deserialize and serializations
//! implementations to assist reading from the official client file format.
use std::str::FromStr;
use chrono::{DateTime, FixedOffset};
use serde::de::{Unexpected, Visitor};
use serde::{Deserialize, Serialize};
use super::ImageContentType;
#[derive(Copy, Clone, Deserialize)]
pub struct LegacyImageMetadata {
pub(crate) content_type: Option<LegacyImageContentType>,
pub(crate) size: Option<u32>,
pub(crate) last_modified: Option<LegacyDateTime>,
}
#[derive(Copy, Clone, Serialize)]
pub struct LegacyDateTime(pub DateTime<FixedOffset>);
impl<'de> Deserialize<'de> for LegacyDateTime {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct LegacyDateTimeVisitor;
impl<'de> Visitor<'de> for LegacyDateTimeVisitor {
type Value = LegacyDateTime;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a valid image type")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
DateTime::parse_from_rfc2822(v)
.map(LegacyDateTime)
.map_err(|_| E::invalid_value(Unexpected::Str(v), &"a valid image type"))
}
}
deserializer.deserialize_str(LegacyDateTimeVisitor)
}
}
#[derive(Copy, Clone)]
pub struct LegacyImageContentType(pub ImageContentType);
impl<'de> Deserialize<'de> for LegacyImageContentType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct LegacyImageContentTypeVisitor;
impl<'de> Visitor<'de> for LegacyImageContentTypeVisitor {
type Value = LegacyImageContentType;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a valid image type")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
ImageContentType::from_str(v)
.map(LegacyImageContentType)
.map_err(|_| E::invalid_value(Unexpected::Str(v), &"a valid image type"))
}
}
deserializer.deserialize_str(LegacyImageContentTypeVisitor)
}
}
#[cfg(test)]
mod parse {
use std::error::Error;
use chrono::DateTime;
use crate::cache::ImageContentType;
use super::LegacyImageMetadata;
#[test]
fn from_valid_legacy_format() -> Result<(), Box<dyn Error>> {
let legacy_header = r#"{"content_type":"image/jpeg","last_modified":"Sat, 10 Apr 2021 10:55:22 GMT","size":117888}"#;
let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
assert_eq!(
metadata.content_type.map(|v| v.0),
Some(ImageContentType::Jpeg)
);
assert_eq!(metadata.size, Some(117_888));
assert_eq!(
metadata.last_modified.map(|v| v.0),
Some(DateTime::parse_from_rfc2822(
"Sat, 10 Apr 2021 10:55:22 GMT"
)?)
);
Ok(())
}
#[test]
fn empty_metadata() -> Result<(), Box<dyn Error>> {
let legacy_header = "{}";
let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
assert!(metadata.content_type.is_none());
assert!(metadata.size.is_none());
assert!(metadata.last_modified.is_none());
Ok(())
}
#[test]
fn invalid_image_mime_value() {
let legacy_header = r#"{"content_type":"image/not-a-real-image"}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn invalid_date_time() {
let legacy_header = r#"{"last_modified":"idk last tuesday?"}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn invalid_size() {
let legacy_header = r#"{"size":-1}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn wrong_image_type() {
let legacy_header = r#"{"content_type":25}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn wrong_date_time_type() {
let legacy_header = r#"{"last_modified":false}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
}

595
src/cache/disk.rs vendored
View file

@ -1,39 +1,30 @@
//! Low memory caching stuff //! Low memory caching stuff
use std::convert::TryFrom; use std::path::PathBuf;
use std::hint::unreachable_unchecked;
use std::os::unix::prelude::OsStrExt;
use std::path::{Path, PathBuf};
use std::str::FromStr; use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt; use futures::StreamExt;
use log::LevelFilter; use log::{error, warn, LevelFilter};
use md5::digest::generic_array::GenericArray;
use md5::{Digest, Md5};
use sodiumoxide::hex;
use sqlx::sqlite::SqliteConnectOptions; use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, Sqlite, SqlitePool, Transaction}; use sqlx::{ConnectOptions, SqlitePool};
use tokio::fs::{create_dir_all, remove_file, rename, File}; use tokio::fs::remove_file;
use tokio::join;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, instrument, warn};
use crate::units::Bytes; use super::{
BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
};
use super::{Cache, CacheEntry, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata};
#[derive(Debug)]
pub struct DiskCache { pub struct DiskCache {
disk_path: PathBuf, disk_path: PathBuf,
disk_cur_size: AtomicU64, disk_cur_size: AtomicU64,
db_update_channel_sender: Sender<DbMessage>, db_update_channel_sender: Sender<DbMessage>,
} }
#[derive(Debug)]
enum DbMessage { enum DbMessage {
Get(Arc<PathBuf>), Get(Arc<PathBuf>),
Put(Arc<PathBuf>, u64), Put(Arc<PathBuf>, u64),
@ -43,47 +34,29 @@ impl DiskCache {
/// Constructs a new low memory cache at the provided path and capacity. /// Constructs a new low memory cache at the provided path and capacity.
/// This internally spawns a task that will wait for filesystem /// This internally spawns a task that will wait for filesystem
/// notifications when a file has been written. /// notifications when a file has been written.
pub async fn new(disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> { pub async fn new(disk_max_size: u64, disk_path: PathBuf) -> Arc<Self> {
if let Err(e) = create_dir_all(&disk_path).await { let (db_tx, db_rx) = channel(128);
error!("Failed to create cache folder: {}", e);
}
let cache_path = disk_path.to_string_lossy();
// Migrate old to new path
if rename(
format!("{}/metadata.sqlite", cache_path),
format!("{}/metadata.db", cache_path),
)
.await
.is_ok()
{
info!("Found old metadata file, migrating to new location.");
}
let db_pool = { let db_pool = {
let db_url = format!("sqlite:{}/metadata.db", cache_path); let db_url = format!("sqlite:{}/metadata.sqlite", disk_path.to_string_lossy());
let mut options = SqliteConnectOptions::from_str(&db_url) let mut options = SqliteConnectOptions::from_str(&db_url)
.unwrap() .unwrap()
.create_if_missing(true); .create_if_missing(true);
options.log_statements(LevelFilter::Trace); options.log_statements(LevelFilter::Trace);
SqlitePool::connect_with(options).await.unwrap() let db = SqlitePool::connect_with(options).await.unwrap();
// Run db init
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut db.acquire().await.unwrap())
.await
.unwrap();
db
}; };
Self::from_db_pool(db_pool, disk_max_size, disk_path).await
}
async fn from_db_pool(pool: SqlitePool, disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
let (db_tx, db_rx) = channel(128);
// Run db init
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut pool.acquire().await.unwrap())
.await
.unwrap();
// This is intentional. // This is intentional.
#[allow(clippy::cast_sign_loss)] #[allow(clippy::cast_sign_loss)]
let disk_cur_size = { let disk_cur_size = {
let mut conn = pool.acquire().await.unwrap(); let mut conn = db_pool.acquire().await.unwrap();
sqlx::query!("SELECT IFNULL(SUM(size), 0) AS size FROM Images") sqlx::query!("SELECT IFNULL(SUM(size), 0) AS size FROM Images")
.fetch_one(&mut conn) .fetch_one(&mut conn)
.await .await
@ -101,25 +74,12 @@ impl DiskCache {
tokio::spawn(db_listener( tokio::spawn(db_listener(
Arc::clone(&new_self), Arc::clone(&new_self),
db_rx, db_rx,
pool, db_pool,
disk_max_size.get() as u64 / 20 * 19, disk_max_size / 20 * 19,
)); ));
new_self new_self
} }
#[cfg(test)]
fn in_memory() -> (Self, Receiver<DbMessage>) {
let (db_tx, db_rx) = channel(128);
(
Self {
disk_path: PathBuf::new(),
disk_cur_size: AtomicU64::new(0),
db_update_channel_sender: db_tx,
},
db_rx,
)
}
} }
/// Spawn a new task that will listen for updates to the db, pruning if the size /// Spawn a new task that will listen for updates to the db, pruning if the size
@ -130,10 +90,9 @@ async fn db_listener(
db_pool: SqlitePool, db_pool: SqlitePool,
max_on_disk_size: u64, max_on_disk_size: u64,
) { ) {
// This is in a receiver stream to process up to 128 simultaneous db updates
// in one transaction
let mut recv_stream = ReceiverStream::new(db_rx).ready_chunks(128); let mut recv_stream = ReceiverStream::new(db_rx).ready_chunks(128);
while let Some(messages) = recv_stream.next().await { while let Some(messages) = recv_stream.next().await {
let now = chrono::Utc::now();
let mut transaction = match db_pool.begin().await { let mut transaction = match db_pool.begin().await {
Ok(transaction) => transaction, Ok(transaction) => transaction,
Err(e) => { Err(e) => {
@ -141,12 +100,38 @@ async fn db_listener(
continue; continue;
} }
}; };
for message in messages { for message in messages {
match message { match message {
DbMessage::Get(entry) => handle_db_get(&entry, &mut transaction).await, DbMessage::Get(entry) => {
let key = entry.as_os_str().to_str();
let query =
sqlx::query!("update Images set accessed = ? where id = ?", now, key)
.execute(&mut transaction)
.await;
if let Err(e) = query {
warn!("Failed to update timestamp in db for {:?}: {}", key, e);
}
}
DbMessage::Put(entry, size) => { DbMessage::Put(entry, size) => {
handle_db_put(&entry, size, &cache, &mut transaction).await; let key = entry.as_os_str().to_str();
{
// This is intentional.
#[allow(clippy::cast_possible_wrap)]
let size = size as i64;
let query = sqlx::query!(
"insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing",
key,
size,
now,
)
.execute(&mut transaction)
.await;
if let Err(e) = query {
warn!("Failed to add {:?} to db: {}", key, e);
}
}
cache.disk_cur_size.fetch_add(size, Ordering::Release);
} }
} }
} }
@ -160,10 +145,21 @@ async fn db_listener(
let on_disk_size = (cache.disk_cur_size.load(Ordering::Acquire) + 4095) / 4096 * 4096; let on_disk_size = (cache.disk_cur_size.load(Ordering::Acquire) + 4095) / 4096 * 4096;
if on_disk_size >= max_on_disk_size { if on_disk_size >= max_on_disk_size {
let mut conn = match db_pool.acquire().await {
Ok(conn) => conn,
Err(e) => {
error!(
"Failed to get a DB connection and cannot prune disk cache: {}",
e
);
continue;
}
};
let items = { let items = {
let request = let request =
sqlx::query!("select id, size from Images order by accessed asc limit 1000") sqlx::query!("select id, size from Images order by accessed asc limit 1000")
.fetch_all(&db_pool) .fetch_all(&mut conn)
.await; .await;
match request { match request {
Ok(items) => items, Ok(items) => items,
@ -180,9 +176,8 @@ async fn db_listener(
let mut size_freed = 0; let mut size_freed = 0;
#[allow(clippy::cast_sign_loss)] #[allow(clippy::cast_sign_loss)]
for item in items { for item in items {
debug!("deleting file due to exceeding cache size");
size_freed += item.size as u64; size_freed += item.size as u64;
tokio::spawn(remove_file_handler(item.id)); tokio::spawn(remove_file(item.id));
} }
cache.disk_cur_size.fetch_sub(size_freed, Ordering::Release); cache.disk_cur_size.fetch_sub(size_freed, Ordering::Release);
@ -190,126 +185,6 @@ async fn db_listener(
} }
} }
/// Returns if a file was successfully deleted.
async fn remove_file_handler(key: String) -> bool {
let error = if let Err(e) = remove_file(&key).await {
e
} else {
return true;
};
if error.kind() != std::io::ErrorKind::NotFound {
warn!("Failed to delete file `{}` from cache: {}", &key, error);
return false;
}
if let Ok(bytes) = hex::decode(&key) {
if bytes.len() != 16 {
warn!("Failed to delete file `{}`; invalid hash size.", &key);
return false;
}
let hash = Md5Hash(*GenericArray::from_slice(&bytes));
let path: PathBuf = hash.into();
if let Err(e) = remove_file(&path).await {
warn!(
"Failed to delete file `{}` from cache: {}",
path.to_string_lossy(),
e
);
false
} else {
true
}
} else {
warn!("Failed to delete file `{}`; not a md5hash.", &key);
false
}
}
#[instrument(level = "debug", skip(transaction))]
async fn handle_db_get(entry: &Path, transaction: &mut Transaction<'_, Sqlite>) {
let key = entry.as_os_str().to_str();
let now = chrono::Utc::now();
let query = sqlx::query!("update Images set accessed = ? where id = ?", now, key)
.execute(transaction)
.await;
if let Err(e) = query {
warn!("Failed to update timestamp in db for {:?}: {}", key, e);
}
}
#[instrument(level = "debug", skip(transaction, cache))]
async fn handle_db_put(
entry: &Path,
size: u64,
cache: &DiskCache,
transaction: &mut Transaction<'_, Sqlite>,
) {
let key = entry.as_os_str().to_str();
let now = chrono::Utc::now();
// This is intentional.
#[allow(clippy::cast_possible_wrap)]
let casted_size = size as i64;
let query = sqlx::query_file!("./db_queries/insert_image.sql", key, casted_size, now)
.execute(transaction)
.await;
if let Err(e) = query {
warn!("Failed to add to db: {}", e);
}
cache.disk_cur_size.fetch_add(size, Ordering::Release);
}
/// Represents a Md5 hash that can be converted to and from a path. This is used
/// for compatibility with the official client, where the image id and on-disk
/// path is determined by file path.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Md5Hash(GenericArray<u8, <Md5 as md5::Digest>::OutputSize>);
impl Md5Hash {
fn to_hex_string(self) -> String {
format!("{:x}", self.0)
}
}
impl TryFrom<&Path> for Md5Hash {
type Error = ();
fn try_from(path: &Path) -> Result<Self, Self::Error> {
let mut iter = path.iter();
let file_name = iter.next_back().ok_or(())?;
let chapter_hash = iter.next_back().ok_or(())?;
let is_data_saver = iter.next_back().ok_or(())? == "saver";
let mut hasher = Md5::new();
if is_data_saver {
hasher.update("saver");
}
hasher.update(chapter_hash.as_bytes());
hasher.update(".");
hasher.update(file_name.as_bytes());
Ok(Self(hasher.finalize()))
}
}
impl From<Md5Hash> for PathBuf {
fn from(hash: Md5Hash) -> Self {
let hex_value = hash.to_hex_string();
let path = hex_value[0..3]
.chars()
.rev()
.map(|char| Self::from(char.to_string()))
.reduce(|first, second| first.join(second));
match path {
Some(p) => p.join(hex_value),
None => unsafe { unreachable_unchecked() }, // literally not possible
}
}
}
#[async_trait] #[async_trait]
impl Cache for DiskCache { impl Cache for DiskCache {
async fn get( async fn get(
@ -321,33 +196,12 @@ impl Cache for DiskCache {
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(key))); let path = Arc::new(self.disk_path.clone().join(PathBuf::from(key)));
let path_0 = Arc::clone(&path); let path_0 = Arc::clone(&path);
let legacy_path = Md5Hash::try_from(path_0.as_path()) tokio::spawn(async move { channel.send(DbMessage::Get(path_0)).await });
.map(PathBuf::from)
.map(|path| self.disk_path.clone().join(path))
.map(Arc::new);
// Get file and path of first existing location path super::fs::read_file(&path).await.map(|res| {
let (file, path) = if let Ok(legacy_path) = legacy_path { let (inner, maybe_header, metadata) = res?;
let maybe_files = join!( CacheStream::new(inner, maybe_header)
File::open(legacy_path.as_path()), .map(|stream| (stream, metadata))
File::open(path.as_path()),
);
match maybe_files {
(Ok(f), _) => Some((f, legacy_path)),
(_, Ok(f)) => Some((f, path)),
_ => return None,
}
} else {
File::open(path.as_path())
.await
.ok()
.map(|file| (file, path))
}?;
tokio::spawn(async move { channel.send(DbMessage::Get(path)).await });
super::fs::read_file(file).await.map(|res| {
res.map(|(stream, _, metadata)| (stream, metadata))
.map_err(|_| CacheError::DecryptionFailure) .map_err(|_| CacheError::DecryptionFailure)
}) })
} }
@ -355,9 +209,9 @@ impl Cache for DiskCache {
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: bytes::Bytes, image: BoxedImageStream,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<(), CacheError> { ) -> Result<CacheStream, CacheError> {
let channel = self.db_update_channel_sender.clone(); let channel = self.db_update_channel_sender.clone();
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key))); let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@ -370,6 +224,9 @@ impl Cache for DiskCache {
super::fs::write_file(&path, key, image, metadata, db_callback, None) super::fs::write_file(&path, key, image, metadata, db_callback, None)
.await .await
.map_err(CacheError::from) .map_err(CacheError::from)
.and_then(|(inner, maybe_header)| {
CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
})
} }
} }
@ -378,10 +235,10 @@ impl CallbackCache for DiskCache {
async fn put_with_on_completed_callback( async fn put_with_on_completed_callback(
&self, &self,
key: CacheKey, key: CacheKey,
image: bytes::Bytes, image: BoxedImageStream,
metadata: ImageMetadata, metadata: ImageMetadata,
on_complete: Sender<CacheEntry>, on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
) -> Result<(), CacheError> { ) -> Result<CacheStream, CacheError> {
let channel = self.db_update_channel_sender.clone(); let channel = self.db_update_channel_sender.clone();
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key))); let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@ -395,298 +252,8 @@ impl CallbackCache for DiskCache {
super::fs::write_file(&path, key, image, metadata, db_callback, Some(on_complete)) super::fs::write_file(&path, key, image, metadata, db_callback, Some(on_complete))
.await .await
.map_err(CacheError::from) .map_err(CacheError::from)
} .and_then(|(inner, maybe_header)| {
} CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
#[cfg(test)]
mod db_listener {
use super::{db_listener, DbMessage};
use crate::DiskCache;
use futures::TryStreamExt;
use sqlx::{Row, SqlitePool};
use std::error::Error;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc::channel;
#[tokio::test]
async fn can_handle_multiple_events() -> Result<(), Box<dyn Error>> {
let (mut cache, rx) = DiskCache::in_memory();
let (mut tx, _) = channel(1);
// Swap the tx with the new one, else the receiver will never end
std::mem::swap(&mut cache.db_update_channel_sender, &mut tx);
assert_eq!(tx.capacity(), 128);
let cache = Arc::new(cache);
let db = SqlitePool::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&db)
.await?;
// Populate the queue with messages
for c in 'a'..='z' {
tx.send(DbMessage::Put(Arc::new(PathBuf::from(c.to_string())), 10))
.await?;
tx.send(DbMessage::Get(Arc::new(PathBuf::from(c.to_string()))))
.await?;
}
// Explicitly close the channel so that the listener terminates
std::mem::drop(tx);
db_listener(cache, rx, db.clone(), u64::MAX).await;
let count = Arc::new(AtomicUsize::new(0));
sqlx::query("select * from Images")
.fetch(&db)
.try_for_each_concurrent(None, |row| {
let count = Arc::clone(&count);
async move {
assert_eq!(row.get::<i32, _>("size"), 10);
count.fetch_add(1, Ordering::Release);
Ok(())
}
}) })
.await?;
assert_eq!(count.load(Ordering::Acquire), 26);
Ok(())
}
}
#[cfg(test)]
mod remove_file_handler {
use std::error::Error;
use tempfile::tempdir;
use tokio::fs::{create_dir_all, remove_dir_all};
use super::{remove_file_handler, File};
#[tokio::test]
async fn should_not_panic_on_invalid_path() {
assert!(!remove_file_handler("/this/is/a/non-existent/path/".to_string()).await);
}
#[tokio::test]
async fn should_not_panic_on_invalid_hash() {
assert!(!remove_file_handler("68b329da9893e34099c7d8ad5cb9c940".to_string()).await);
}
#[tokio::test]
async fn should_not_panic_on_malicious_hashes() {
assert!(!remove_file_handler("68b329da9893e34".to_string()).await);
assert!(
!remove_file_handler("68b329da9893e34099c7d8ad5cb9c940aaaaaaaaaaaaaaaaaa".to_string())
.await
);
}
#[tokio::test]
async fn should_delete_existing_file() -> Result<(), Box<dyn Error>> {
let temp_dir = tempdir()?;
let mut dir_path = temp_dir.path().to_path_buf();
dir_path.push("abc123.png");
// create a file, it can be empty
File::create(&dir_path).await?;
assert!(remove_file_handler(dir_path.to_string_lossy().into_owned()).await);
Ok(())
}
#[tokio::test]
async fn should_delete_existing_hash() -> Result<(), Box<dyn Error>> {
create_dir_all("b/8/6").await?;
File::create("b/8/6/68b329da9893e34099c7d8ad5cb9c900").await?;
assert!(remove_file_handler("68b329da9893e34099c7d8ad5cb9c900".to_string()).await);
remove_dir_all("b").await?;
Ok(())
}
}
#[cfg(test)]
mod disk_cache {
use std::error::Error;
use std::path::PathBuf;
use std::sync::atomic::Ordering;
use chrono::Utc;
use sqlx::SqlitePool;
use crate::units::Bytes;
use super::DiskCache;
#[tokio::test]
async fn db_is_initialized() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
let _cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
let res = sqlx::query("select * from Images").execute(&conn).await;
assert!(res.is_ok());
Ok(())
}
#[tokio::test]
async fn db_initializes_empty() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 0);
Ok(())
}
#[tokio::test]
async fn db_can_load_from_existing() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&conn)
.await?;
let now = Utc::now();
sqlx::query_file!("./db_queries/insert_image.sql", "a", 4, now)
.execute(&conn)
.await?;
let now = Utc::now();
sqlx::query_file!("./db_queries/insert_image.sql", "b", 15, now)
.execute(&conn)
.await?;
let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 19);
Ok(())
}
}
#[cfg(test)]
mod db {
use chrono::{DateTime, Utc};
use sqlx::{Connection, Row, SqliteConnection};
use std::error::Error;
use super::{handle_db_get, handle_db_put, DiskCache, FromStr, Ordering, PathBuf, StreamExt};
#[tokio::test]
#[cfg_attr(miri, ignore)]
async fn get() -> Result<(), Box<dyn Error>> {
let (cache, _) = DiskCache::in_memory();
let path = PathBuf::from_str("a/b/c")?;
let mut conn = SqliteConnection::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut conn)
.await?;
// Add an entry
let mut transaction = conn.begin().await?;
handle_db_put(&path, 10, &cache, &mut transaction).await;
transaction.commit().await?;
let time_fence = Utc::now();
let mut transaction = conn.begin().await?;
handle_db_get(&path, &mut transaction).await;
transaction.commit().await?;
let mut rows: Vec<_> = sqlx::query("select * from Images")
.fetch(&mut conn)
.collect()
.await;
assert_eq!(rows.len(), 1);
let entry = rows.pop().unwrap()?;
assert!(time_fence < entry.get::<'_, DateTime<Utc>, _>("accessed"));
Ok(())
}
#[tokio::test]
#[cfg_attr(miri, ignore)]
async fn put() -> Result<(), Box<dyn Error>> {
let (cache, _) = DiskCache::in_memory();
let path = PathBuf::from_str("a/b/c")?;
let mut conn = SqliteConnection::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut conn)
.await?;
let mut transaction = conn.begin().await?;
let transaction_time = Utc::now();
handle_db_put(&path, 10, &cache, &mut transaction).await;
transaction.commit().await?;
let mut rows: Vec<_> = sqlx::query("select * from Images")
.fetch(&mut conn)
.collect()
.await;
assert_eq!(rows.len(), 1);
let entry = rows.pop().unwrap()?;
assert_eq!(entry.get::<'_, &str, _>("id"), "a/b/c");
assert_eq!(entry.get::<'_, i64, _>("size"), 10);
let accessed: DateTime<Utc> = entry.get("accessed");
assert!(transaction_time < accessed);
assert!(accessed < Utc::now());
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 10);
Ok(())
}
}
#[cfg(test)]
mod md5_hash {
use super::{Digest, GenericArray, Md5, Md5Hash, Path, PathBuf, TryFrom};
#[test]
fn to_cache_path() {
let hash = Md5Hash(
*GenericArray::<_, <Md5 as md5::Digest>::OutputSize>::from_slice(&[
0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd,
0xef, 0xab,
]),
);
assert_eq!(
PathBuf::from(hash).to_str(),
Some("c/b/a/abcdefabcdefabcdefabcdefabcdefab")
)
}
#[test]
fn from_data_path() {
let mut expected_hasher = Md5::new();
expected_hasher.update("foo.bar.png");
assert_eq!(
Md5Hash::try_from(Path::new("data/foo/bar.png")),
Ok(Md5Hash(expected_hasher.finalize()))
);
}
#[test]
fn from_data_saver_path() {
let mut expected_hasher = Md5::new();
expected_hasher.update("saverfoo.bar.png");
assert_eq!(
Md5Hash::try_from(Path::new("saver/foo/bar.png")),
Ok(Md5Hash(expected_hasher.finalize()))
);
}
#[test]
fn can_handle_long_paths() {
assert_eq!(
Md5Hash::try_from(Path::new("a/b/c/d/e/f/g/saver/foo/bar.png")),
Md5Hash::try_from(Path::new("saver/foo/bar.png")),
);
}
#[test]
fn from_invalid_paths() {
assert!(Md5Hash::try_from(Path::new("foo/bar.png")).is_err());
assert!(Md5Hash::try_from(Path::new("bar.png")).is_err());
assert!(Md5Hash::try_from(Path::new("")).is_err());
} }
} }

884
src/cache/fs.rs vendored

File diff suppressed because it is too large Load diff

698
src/cache/mem.rs vendored
View file

@ -1,74 +1,19 @@
use std::borrow::Cow;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use super::{Cache, CacheEntry, CacheKey, CacheStream, CallbackCache, ImageMetadata, MemStream}; use super::{
BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
InnerStream, MemStream,
};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use futures::FutureExt; use futures::FutureExt;
use lfu_cache::LfuCache; use lfu_cache::LfuCache;
use lru::LruCache; use lru::LruCache;
use redis::{ use tokio::sync::mpsc::{channel, Sender};
Client as RedisClient, Commands, FromRedisValue, RedisError, RedisResult, ToRedisArgs,
};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::warn;
#[derive(Clone, Serialize, Deserialize)] type CacheValue = (Bytes, ImageMetadata, u64);
pub struct CacheValue {
data: Bytes,
metadata: ImageMetadata,
on_disk_size: u64,
}
impl CacheValue {
#[inline]
fn new(data: Bytes, metadata: ImageMetadata, on_disk_size: u64) -> Self {
Self {
data,
metadata,
on_disk_size,
}
}
}
impl FromRedisValue for CacheValue {
fn from_redis_value(v: &redis::Value) -> RedisResult<Self> {
use bincode::ErrorKind;
if let redis::Value::Data(data) = v {
bincode::deserialize(data).map_err(|err| match *err {
ErrorKind::Io(e) => RedisError::from(e),
ErrorKind::Custom(e) => RedisError::from((
redis::ErrorKind::ResponseError,
"bincode deserialize failed",
e,
)),
e => RedisError::from((
redis::ErrorKind::ResponseError,
"bincode deserialized failed",
e.to_string(),
)),
})
} else {
Err(RedisError::from((
redis::ErrorKind::TypeError,
"Got non data type from redis db",
)))
}
}
}
impl ToRedisArgs for CacheValue {
fn write_redis_args<W>(&self, out: &mut W)
where
W: ?Sized + redis::RedisWrite,
{
out.write_arg(&bincode::serialize(self).expect("serialization to work"));
}
}
/// Use LRU as the eviction strategy /// Use LRU as the eviction strategy
pub type Lru = LruCache<CacheKey, CacheValue>; pub type Lru = LruCache<CacheKey, CacheValue>;
@ -76,29 +21,22 @@ pub type Lru = LruCache<CacheKey, CacheValue>;
pub type Lfu = LfuCache<CacheKey, CacheValue>; pub type Lfu = LfuCache<CacheKey, CacheValue>;
/// Adapter trait for memory cache backends /// Adapter trait for memory cache backends
pub trait InternalMemoryCacheInitializer: InternalMemoryCache {
fn new() -> Self;
}
pub trait InternalMemoryCache: Sync + Send { pub trait InternalMemoryCache: Sync + Send {
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>>; fn unbounded() -> Self;
fn get(&mut self, key: &CacheKey) -> Option<&CacheValue>;
fn push(&mut self, key: CacheKey, data: CacheValue); fn push(&mut self, key: CacheKey, data: CacheValue);
fn pop(&mut self) -> Option<(CacheKey, CacheValue)>; fn pop(&mut self) -> Option<(CacheKey, CacheValue)>;
} }
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCacheInitializer for Lfu {
#[inline]
fn new() -> Self {
Self::unbounded()
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for Lfu { impl InternalMemoryCache for Lfu {
#[inline] #[inline]
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> { fn unbounded() -> Self {
self.get(key).map(Cow::Borrowed) Self::unbounded()
}
#[inline]
fn get(&mut self, key: &CacheKey) -> Option<&CacheValue> {
self.get(key)
} }
#[inline] #[inline]
@ -112,19 +50,15 @@ impl InternalMemoryCache for Lfu {
} }
} }
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCacheInitializer for Lru {
#[inline]
fn new() -> Self {
Self::unbounded()
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for Lru { impl InternalMemoryCache for Lru {
#[inline] #[inline]
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> { fn unbounded() -> Self {
self.get(key).map(Cow::Borrowed) Self::unbounded()
}
#[inline]
fn get(&mut self, key: &CacheKey) -> Option<&CacheValue> {
self.get(key)
} }
#[inline] #[inline]
@ -138,73 +72,13 @@ impl InternalMemoryCache for Lru {
} }
} }
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for RedisClient {
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
Commands::get(self, key).ok().map(Cow::Owned)
}
fn push(&mut self, key: CacheKey, data: CacheValue) {
if let Err(e) = Commands::set::<_, _, ()>(self, key, data) {
warn!("Failed to push to redis: {}", e);
}
}
fn pop(&mut self) -> Option<(CacheKey, CacheValue)> {
unimplemented!("redis should handle its own memory")
}
}
/// Memory accelerated disk cache. Uses the internal cache implementation in /// Memory accelerated disk cache. Uses the internal cache implementation in
/// memory to speed up reads. /// memory to speed up reads.
pub struct MemoryCache<MemoryCacheImpl, ColdCache> { pub struct MemoryCache<MemoryCacheImpl, ColdCache> {
inner: ColdCache, inner: ColdCache,
cur_mem_size: AtomicU64, cur_mem_size: AtomicU64,
mem_cache: Mutex<MemoryCacheImpl>, mem_cache: Mutex<MemoryCacheImpl>,
master_sender: Sender<CacheEntry>, master_sender: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
}
impl<MemoryCacheImpl, ColdCache> MemoryCache<MemoryCacheImpl, ColdCache>
where
MemoryCacheImpl: 'static + InternalMemoryCacheInitializer,
ColdCache: 'static + Cache,
{
pub fn new(inner: ColdCache, max_mem_size: crate::units::Bytes) -> Arc<Self> {
let (tx, rx) = channel(100);
let new_self = Arc::new(Self {
inner,
cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(MemoryCacheImpl::new()),
master_sender: tx,
});
tokio::spawn(internal_cache_listener(
Arc::clone(&new_self),
max_mem_size,
rx,
));
new_self
}
/// Returns an instance of the cache with the receiver for callback events
/// Really only useful for inspecting the receiver, e.g. for testing
#[cfg(test)]
pub fn new_with_receiver(
inner: ColdCache,
_: crate::units::Bytes,
) -> (Self, Receiver<CacheEntry>) {
let (tx, rx) = channel(100);
(
Self {
inner,
cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(MemoryCacheImpl::new()),
master_sender: tx,
},
rx,
)
}
} }
impl<MemoryCacheImpl, ColdCache> MemoryCache<MemoryCacheImpl, ColdCache> impl<MemoryCacheImpl, ColdCache> MemoryCache<MemoryCacheImpl, ColdCache>
@ -212,66 +86,52 @@ where
MemoryCacheImpl: 'static + InternalMemoryCache, MemoryCacheImpl: 'static + InternalMemoryCache,
ColdCache: 'static + Cache, ColdCache: 'static + Cache,
{ {
pub fn new_with_cache(inner: ColdCache, init_mem_cache: MemoryCacheImpl) -> Self { pub async fn new(inner: ColdCache, max_mem_size: u64) -> Arc<Self> {
Self { let (tx, mut rx) = channel(100);
let new_self = Arc::new(Self {
inner, inner,
cur_mem_size: AtomicU64::new(0), cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(init_mem_cache), mem_cache: Mutex::new(MemoryCacheImpl::unbounded()),
master_sender: channel(1).0, master_sender: tx,
} });
}
}
async fn internal_cache_listener<MemoryCacheImpl, ColdCache>( let new_self_0 = Arc::clone(&new_self);
cache: Arc<MemoryCache<MemoryCacheImpl, ColdCache>>, tokio::spawn(async move {
max_mem_size: crate::units::Bytes, let new_self = new_self_0;
mut rx: Receiver<CacheEntry>, let max_mem_size = max_mem_size / 20 * 19;
) where while let Some((key, bytes, metadata, size)) = rx.recv().await {
MemoryCacheImpl: InternalMemoryCache, // Add to memory cache
ColdCache: Cache, // We can add first because we constrain our memory usage to 95%
{ new_self
let max_mem_size = mem_threshold(&max_mem_size); .cur_mem_size
while let Some(CacheEntry { .fetch_add(size as u64, Ordering::Release);
key, new_self
data, .mem_cache
metadata, .lock()
on_disk_size, .await
}) = rx.recv().await .push(key, (bytes, metadata, size));
{
// Add to memory cache
// We can add first because we constrain our memory usage to 95%
cache
.cur_mem_size
.fetch_add(on_disk_size as u64, Ordering::Release);
cache
.mem_cache
.lock()
.await
.push(key, CacheValue::new(data, metadata, on_disk_size));
// Pop if too large // Pop if too large
while cache.cur_mem_size.load(Ordering::Acquire) >= max_mem_size as u64 { while new_self.cur_mem_size.load(Ordering::Acquire) >= max_mem_size {
let popped = cache.mem_cache.lock().await.pop().map( let popped = new_self
|( .mem_cache
key, .lock()
CacheValue { .await
data, .pop()
metadata, .map(|(key, (bytes, metadata, size))| (key, bytes, metadata, size));
on_disk_size, if let Some((_, _, _, size)) = popped {
}, new_self
)| (key, data, metadata, on_disk_size), .cur_mem_size
); .fetch_sub(size as u64, Ordering::Release);
if let Some((_, _, _, size)) = popped { } else {
cache.cur_mem_size.fetch_sub(size as u64, Ordering::Release); break;
} else { }
break; }
} }
} });
}
}
const fn mem_threshold(bytes: &crate::units::Bytes) -> usize { new_self
bytes.get() / 20 * 19 }
} }
#[async_trait] #[async_trait]
@ -286,16 +146,16 @@ where
key: &CacheKey, key: &CacheKey,
) -> Option<Result<(CacheStream, ImageMetadata), super::CacheError>> { ) -> Option<Result<(CacheStream, ImageMetadata), super::CacheError>> {
match self.mem_cache.lock().now_or_never() { match self.mem_cache.lock().now_or_never() {
Some(mut mem_cache) => { Some(mut mem_cache) => match mem_cache.get(key).map(|(bytes, metadata, _)| {
match mem_cache.get(key).map(Cow::into_owned).map( Ok((InnerStream::Memory(MemStream(bytes.clone())), *metadata))
|CacheValue { data, metadata, .. }| { }) {
Ok((CacheStream::Memory(MemStream(data)), metadata)) Some(v) => Some(v.and_then(|(inner, metadata)| {
}, CacheStream::new(inner, None)
) { .map(|v| (v, metadata))
Some(v) => Some(v), .map_err(|_| CacheError::DecryptionFailure)
None => self.inner.get(key).await, })),
} None => self.inner.get(key).await,
} },
None => self.inner.get(key).await, None => self.inner.get(key).await,
} }
} }
@ -304,419 +164,11 @@ where
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: Bytes, image: BoxedImageStream,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<(), super::CacheError> { ) -> Result<CacheStream, super::CacheError> {
self.inner self.inner
.put_with_on_completed_callback(key, image, metadata, self.master_sender.clone()) .put_with_on_completed_callback(key, image, metadata, self.master_sender.clone())
.await .await
} }
} }
#[cfg(test)]
mod test_util {
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use super::{CacheValue, InternalMemoryCache, InternalMemoryCacheInitializer};
use crate::cache::{
Cache, CacheEntry, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
};
use async_trait::async_trait;
use parking_lot::Mutex;
use tokio::io::BufReader;
use tokio::sync::mpsc::Sender;
use tokio_util::io::ReaderStream;
#[derive(Default)]
pub struct TestDiskCache(
pub Mutex<RefCell<HashMap<CacheKey, Result<(CacheStream, ImageMetadata), CacheError>>>>,
);
#[async_trait]
impl Cache for TestDiskCache {
async fn get(
&self,
key: &CacheKey,
) -> Option<Result<(CacheStream, ImageMetadata), CacheError>> {
self.0.lock().get_mut().remove(key)
}
async fn put(
&self,
key: CacheKey,
image: bytes::Bytes,
metadata: ImageMetadata,
) -> Result<(), CacheError> {
let reader = Box::pin(BufReader::new(tokio_util::io::StreamReader::new(
tokio_stream::once(Ok::<_, std::io::Error>(image)),
)));
let stream = CacheStream::Completed(ReaderStream::new(reader));
self.0.lock().get_mut().insert(key, Ok((stream, metadata)));
Ok(())
}
}
#[async_trait]
impl CallbackCache for TestDiskCache {
async fn put_with_on_completed_callback(
&self,
key: CacheKey,
data: bytes::Bytes,
metadata: ImageMetadata,
on_complete: Sender<CacheEntry>,
) -> Result<(), CacheError> {
self.put(key.clone(), data.clone(), metadata)
.await?;
let on_disk_size = data.len() as u64;
let _ = on_complete
.send(CacheEntry {
key,
data,
metadata,
on_disk_size,
})
.await;
Ok(())
}
}
#[derive(Default)]
pub struct TestMemoryCache(pub BTreeMap<CacheKey, CacheValue>);
impl InternalMemoryCacheInitializer for TestMemoryCache {
fn new() -> Self {
Self::default()
}
}
impl InternalMemoryCache for TestMemoryCache {
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
self.0.get(key).map(Cow::Borrowed)
}
fn push(&mut self, key: CacheKey, data: CacheValue) {
self.0.insert(key, data);
}
fn pop(&mut self) -> Option<(CacheKey, CacheValue)> {
let mut cache = BTreeMap::new();
std::mem::swap(&mut cache, &mut self.0);
let mut iter = cache.into_iter();
let ret = iter.next();
self.0 = iter.collect();
ret
}
}
}
#[cfg(test)]
mod cache_ops {
use std::error::Error;
use bytes::Bytes;
use futures::{FutureExt, StreamExt};
use crate::cache::mem::{CacheValue, InternalMemoryCache};
use crate::cache::{Cache, CacheEntry, CacheKey, CacheStream, ImageMetadata, MemStream};
use super::test_util::{TestDiskCache, TestMemoryCache};
use super::MemoryCache;
#[tokio::test]
async fn get_mem_cached() -> Result<(), Box<dyn Error>> {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
let value = CacheValue::new(bytes.clone(), metadata, 34);
// Populate the cache, need to drop the lock else it's considered locked
// when we actually call the cache
{
let mem_cache = &mut cache.mem_cache.lock().await;
mem_cache.push(key.clone(), value.clone());
}
let (stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
if let CacheStream::Memory(MemStream(ret_stream)) = stream {
assert_eq!(bytes, ret_stream);
} else {
panic!("wrong stream type");
}
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
#[tokio::test]
async fn get_disk_cached() -> Result<(), Box<dyn Error>> {
let (mut cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
{
let cache = &mut cache.inner;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
}
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
// Identical to the get_disk_cached test but we hold a lock on the mem_cache
#[tokio::test]
async fn get_mem_locked() -> Result<(), Box<dyn Error>> {
let (mut cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
{
let cache = &mut cache.inner;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
}
// intentionally not dropped
let _mem_cache = &mut cache.mem_cache.lock().await;
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
#[tokio::test]
async fn get_miss() {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
assert!(cache.get(&key).await.is_none());
assert!(rx.recv().now_or_never().is_none());
}
#[tokio::test]
async fn put_puts_into_disk_and_hears_from_rx() -> Result<(), Box<dyn Error>> {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
let bytes_len = bytes.len() as u64;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
// Because the callback is supposed to let the memory cache insert the
// entry into its cache, we can check that it properly stored it on the
// disk layer by checking if we can successfully fetch it.
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
// Check that we heard back
let cache_entry = rx
.recv()
.now_or_never()
.flatten()
.ok_or("failed to hear back from cache")?;
assert_eq!(
cache_entry,
CacheEntry {
key,
data: bytes,
metadata,
on_disk_size: bytes_len,
}
);
Ok(())
}
}
#[cfg(test)]
mod db_listener {
use std::error::Error;
use std::iter::FromIterator;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use bytes::Bytes;
use tokio::task;
use crate::cache::{Cache, CacheKey, ImageMetadata};
use super::test_util::{TestDiskCache, TestMemoryCache};
use super::{internal_cache_listener, MemoryCache};
#[tokio::test]
async fn put_into_memory() -> Result<(), Box<dyn Error>> {
let (cache, rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(0),
);
let cache = Arc::new(cache);
tokio::spawn(internal_cache_listener(
Arc::clone(&cache),
crate::units::Bytes(20),
rx,
));
// put small image into memory
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
cache.put(key.clone(), bytes.clone(), metadata).await?;
// let the listener run first
for _ in 0..10 {
task::yield_now().await;
}
assert_eq!(
cache.cur_mem_size.load(Ordering::SeqCst),
bytes.len() as u64
);
// Since we didn't populate the cache, fetching must be from memory, so
// this should succeed since the cache listener should push the item
// into cache
assert!(cache.get(&key).await.is_some());
Ok(())
}
#[tokio::test]
async fn pops_items() -> Result<(), Box<dyn Error>> {
let (cache, rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(0),
);
let cache = Arc::new(cache);
tokio::spawn(internal_cache_listener(
Arc::clone(&cache),
crate::units::Bytes(20),
rx,
));
// put small image into memory
let key_0 = CacheKey("a".to_string(), "b".to_string(), false);
let key_1 = CacheKey("c".to_string(), "d".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcde");
cache.put(key_0, bytes.clone(), metadata).await?;
cache.put(key_1, bytes.clone(), metadata).await?;
// let the listener run first
task::yield_now().await;
for _ in 0..10 {
task::yield_now().await;
}
// Items should be in cache now
assert_eq!(
cache.cur_mem_size.load(Ordering::SeqCst),
(bytes.len() * 2) as u64
);
let key_3 = CacheKey("e".to_string(), "f".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_iter(b"0".repeat(16).into_iter());
let bytes_len = bytes.len();
cache.put(key_3, bytes, metadata).await?;
// let the listener run first
task::yield_now().await;
for _ in 0..10 {
task::yield_now().await;
}
// Items should have been evicted, only 16 bytes should be there now
assert_eq!(cache.cur_mem_size.load(Ordering::SeqCst), bytes_len as u64);
Ok(())
}
}
#[cfg(test)]
mod mem_threshold {
use crate::units::Bytes;
use super::mem_threshold;
#[test]
fn small_amount_works() {
assert_eq!(mem_threshold(&Bytes(100)), 95);
}
#[test]
fn large_amount_cannot_overflow() {
assert_eq!(mem_threshold(&Bytes(usize::MAX)), 17_524_406_870_024_074_020);
}
}

138
src/cache/mod.rs vendored
View file

@ -5,46 +5,34 @@ use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_web::http::header::HeaderValue; use actix_web::http::HeaderValue;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::{Bytes, BytesMut};
use chacha20::Key;
use chrono::{DateTime, FixedOffset}; use chrono::{DateTime, FixedOffset};
use fs::ConcurrentFsStream;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use redis::ToRedisArgs;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr}; use serde_repr::{Deserialize_repr, Serialize_repr};
use sodiumoxide::crypto::secretstream::{Header, Key, Pull, Stream as SecretStream};
use thiserror::Error; use thiserror::Error;
use tokio::io::AsyncRead;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tokio_util::io::ReaderStream; use tokio_util::codec::{BytesCodec, FramedRead};
pub use disk::DiskCache; pub use disk::DiskCache;
pub use fs::UpstreamError; pub use fs::UpstreamError;
pub use mem::MemoryCache; pub use mem::MemoryCache;
use self::compat::LegacyImageMetadata;
use self::fs::MetadataFetch;
pub static ENCRYPTION_KEY: OnceCell<Key> = OnceCell::new(); pub static ENCRYPTION_KEY: OnceCell<Key> = OnceCell::new();
mod compat;
mod disk; mod disk;
mod fs; mod fs;
pub mod mem; pub mod mem;
#[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)] #[derive(PartialEq, Eq, Hash, Clone)]
pub struct CacheKey(pub String, pub String, pub bool); pub struct CacheKey(pub String, pub String, pub bool);
impl ToRedisArgs for CacheKey {
fn write_redis_args<W>(&self, out: &mut W)
where
W: ?Sized + redis::RedisWrite,
{
out.write_arg_fmt(self);
}
}
impl Display for CacheKey { impl Display for CacheKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.2 { if self.2 {
@ -72,7 +60,7 @@ impl From<&CacheKey> for PathBuf {
#[derive(Clone)] #[derive(Clone)]
pub struct CachedImage(pub Bytes); pub struct CachedImage(pub Bytes);
#[derive(Copy, Clone, Serialize, Deserialize, Debug, PartialEq, Eq)] #[derive(Copy, Clone, Serialize, Deserialize)]
pub struct ImageMetadata { pub struct ImageMetadata {
pub content_type: Option<ImageContentType>, pub content_type: Option<ImageContentType>,
pub content_length: Option<u32>, pub content_length: Option<u32>,
@ -80,7 +68,7 @@ pub struct ImageMetadata {
} }
// Confirmed by Ply to be these types: https://link.eddie.sh/ZXfk0 // Confirmed by Ply to be these types: https://link.eddie.sh/ZXfk0
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, Debug, PartialEq, Eq)] #[derive(Copy, Clone, Serialize_repr, Deserialize_repr)]
#[repr(u8)] #[repr(u8)]
pub enum ImageContentType { pub enum ImageContentType {
Png = 0, Png = 0,
@ -115,21 +103,12 @@ impl AsRef<str> for ImageContentType {
} }
} }
impl From<LegacyImageMetadata> for ImageMetadata { #[allow(clippy::pub_enum_variant_names)]
fn from(legacy: LegacyImageMetadata) -> Self {
Self {
content_type: legacy.content_type.map(|v| v.0),
content_length: legacy.size,
last_modified: legacy.last_modified.map(|v| v.0),
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub enum ImageRequestError { pub enum ImageRequestError {
ContentType, InvalidContentType,
ContentLength, InvalidContentLength,
LastModified, InvalidLastModified,
} }
impl ImageMetadata { impl ImageMetadata {
@ -145,14 +124,14 @@ impl ImageMetadata {
Err(_) => Err(InvalidContentType), Err(_) => Err(InvalidContentType),
}) })
.transpose() .transpose()
.map_err(|_| ImageRequestError::ContentType)?, .map_err(|_| ImageRequestError::InvalidContentType)?,
content_length: content_length content_length: content_length
.map(|header_val| { .map(|header_val| {
header_val header_val
.to_str() .to_str()
.map_err(|_| ImageRequestError::ContentLength)? .map_err(|_| ImageRequestError::InvalidContentLength)?
.parse() .parse()
.map_err(|_| ImageRequestError::ContentLength) .map_err(|_| ImageRequestError::InvalidContentLength)
}) })
.transpose()?, .transpose()?,
last_modified: last_modified last_modified: last_modified
@ -160,15 +139,17 @@ impl ImageMetadata {
DateTime::parse_from_rfc2822( DateTime::parse_from_rfc2822(
header_val header_val
.to_str() .to_str()
.map_err(|_| ImageRequestError::LastModified)?, .map_err(|_| ImageRequestError::InvalidLastModified)?,
) )
.map_err(|_| ImageRequestError::LastModified) .map_err(|_| ImageRequestError::InvalidLastModified)
}) })
.transpose()?, .transpose()?,
}) })
} }
} }
type BoxedImageStream = Box<dyn Stream<Item = Result<Bytes, CacheError>> + Unpin + Send>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum CacheError { pub enum CacheError {
#[error(transparent)] #[error(transparent)]
@ -189,9 +170,9 @@ pub trait Cache: Send + Sync {
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: Bytes, image: BoxedImageStream,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<(), CacheError>; ) -> Result<CacheStream, CacheError>;
} }
#[async_trait] #[async_trait]
@ -208,9 +189,9 @@ impl<T: Cache> Cache for Arc<T> {
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: Bytes, image: BoxedImageStream,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<(), CacheError> { ) -> Result<CacheStream, CacheError> {
self.as_ref().put(key, image, metadata).await self.as_ref().put(key, image, metadata).await
} }
} }
@ -220,43 +201,74 @@ pub trait CallbackCache: Cache {
async fn put_with_on_completed_callback( async fn put_with_on_completed_callback(
&self, &self,
key: CacheKey, key: CacheKey,
image: Bytes, image: BoxedImageStream,
metadata: ImageMetadata, metadata: ImageMetadata,
on_complete: Sender<CacheEntry>, on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
) -> Result<(), CacheError>; ) -> Result<CacheStream, CacheError>;
} }
#[async_trait] #[async_trait]
impl<T: CallbackCache> CallbackCache for Arc<T> { impl<T: CallbackCache> CallbackCache for Arc<T> {
#[inline] #[inline]
#[cfg(not(tarpaulin_include))]
async fn put_with_on_completed_callback( async fn put_with_on_completed_callback(
&self, &self,
key: CacheKey, key: CacheKey,
image: Bytes, image: BoxedImageStream,
metadata: ImageMetadata, metadata: ImageMetadata,
on_complete: Sender<CacheEntry>, on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
) -> Result<(), CacheError> { ) -> Result<CacheStream, CacheError> {
self.as_ref() self.as_ref()
.put_with_on_completed_callback(key, image, metadata, on_complete) .put_with_on_completed_callback(key, image, metadata, on_complete)
.await .await
} }
} }
#[derive(PartialEq, Eq, Debug)] pub struct CacheStream {
pub struct CacheEntry { inner: InnerStream,
key: CacheKey, decrypt: Option<SecretStream<Pull>>,
data: Bytes,
metadata: ImageMetadata,
on_disk_size: u64,
} }
pub enum CacheStream { impl CacheStream {
pub(self) fn new(inner: InnerStream, header: Option<Header>) -> Result<Self, ()> {
Ok(Self {
inner,
decrypt: header
.and_then(|header| ENCRYPTION_KEY.get().map(|key| SecretStream::init_pull(&header, key)))
.transpose()?,
})
}
}
impl Stream for CacheStream {
type Item = CacheStreamItem;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx).map(|data| {
// False positive (`data`): https://link.eddie.sh/r1fXX
#[allow(clippy::option_if_let_else)]
if let Some(keystream) = self.decrypt.as_mut() {
data.map(|bytes_res| {
bytes_res.and_then(|bytes| {
keystream
.pull(&bytes, None)
.map(|(data, _tag)| Bytes::from(data))
.map_err(|_| UpstreamError)
})
})
} else {
data
}
})
}
}
pub(self) enum InnerStream {
Concurrent(ConcurrentFsStream),
Memory(MemStream), Memory(MemStream),
Completed(ReaderStream<Pin<Box<dyn MetadataFetch + Send + Sync>>>), Completed(FramedRead<Pin<Box<dyn AsyncRead + Send>>, BytesCodec>),
} }
impl From<CachedImage> for CacheStream { impl From<CachedImage> for InnerStream {
fn from(image: CachedImage) -> Self { fn from(image: CachedImage) -> Self {
Self::Memory(MemStream(image.0)) Self::Memory(MemStream(image.0))
} }
@ -264,13 +276,17 @@ impl From<CachedImage> for CacheStream {
type CacheStreamItem = Result<Bytes, UpstreamError>; type CacheStreamItem = Result<Bytes, UpstreamError>;
impl Stream for CacheStream { impl Stream for InnerStream {
type Item = CacheStreamItem; type Item = CacheStreamItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.get_mut() { match self.get_mut() {
Self::Concurrent(stream) => stream.poll_next_unpin(cx),
Self::Memory(stream) => stream.poll_next_unpin(cx), Self::Memory(stream) => stream.poll_next_unpin(cx),
Self::Completed(stream) => stream.poll_next_unpin(cx).map_err(|_| UpstreamError), Self::Completed(stream) => stream
.poll_next_unpin(cx)
.map_ok(BytesMut::freeze)
.map_err(|_| UpstreamError),
} }
} }
} }

View file

@ -1,220 +0,0 @@
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use actix_web::http::header::{HeaderMap, HeaderName, HeaderValue};
use actix_web::web::Data;
use bytes::Bytes;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use reqwest::header::{
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH,
CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS,
};
use reqwest::{Client, Proxy, StatusCode};
use tokio::sync::watch::{channel, Receiver};
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
use crate::cache::{Cache, CacheKey, ImageMetadata};
use crate::config::{DISABLE_CERT_VALIDATION, USE_PROXY};
pub static HTTP_CLIENT: Lazy<CachingClient> = Lazy::new(|| {
let mut inner = Client::builder()
.pool_idle_timeout(Duration::from_secs(180))
.https_only(true)
.http2_prior_knowledge();
if let Some(socket_addr) = USE_PROXY.get() {
info!(
"Using {} as a proxy for upstream requests.",
socket_addr.as_str()
);
inner = inner.proxy(Proxy::all(socket_addr.as_str()).unwrap());
}
if DISABLE_CERT_VALIDATION.load(Ordering::Acquire) {
inner = inner.danger_accept_invalid_certs(true);
}
let inner = inner.build().expect("Client initialization to work");
CachingClient {
inner,
locks: RwLock::new(HashMap::new()),
}
});
#[cfg(not(tarpaulin_include))]
pub static DEFAULT_HEADERS: Lazy<HeaderMap> = Lazy::new(|| {
let mut headers = HeaderMap::with_capacity(8);
headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
headers.insert(
ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("https://mangadex.org"),
);
headers.insert(ACCESS_CONTROL_EXPOSE_HEADERS, HeaderValue::from_static("*"));
headers.insert(
CACHE_CONTROL,
HeaderValue::from_static("public, max-age=1209600"),
);
headers.insert(
HeaderName::from_static("timing-allow-origin"),
HeaderValue::from_static("https://mangadex.org"),
);
headers
});
pub struct CachingClient {
inner: Client,
locks: RwLock<HashMap<String, Receiver<FetchResult>>>,
}
#[derive(Clone, Debug)]
pub enum FetchResult {
ServiceUnavailable,
InternalServerError,
Data(StatusCode, HeaderMap, Bytes),
Processing,
}
impl CachingClient {
pub async fn fetch_and_cache(
&'static self,
url: String,
key: CacheKey,
cache: Data<dyn Cache>,
) -> FetchResult {
let maybe_receiver = {
let lock = self.locks.read();
lock.get(&url).map(Clone::clone)
};
if let Some(mut recv) = maybe_receiver {
loop {
if !matches!(*recv.borrow(), FetchResult::Processing) {
break;
}
if recv.changed().await.is_err() {
break;
}
}
return recv.borrow().clone();
}
let notify = Arc::new(Notify::new());
tokio::spawn(self.fetch_and_cache_impl(cache, url.clone(), key, Arc::clone(&notify)));
notify.notified().await;
let mut recv = self
.locks
.read()
.get(&url)
.expect("receiver to exist since we just made one")
.clone();
loop {
if !matches!(*recv.borrow(), FetchResult::Processing) {
break;
}
if recv.changed().await.is_err() {
break;
}
}
let resp = recv.borrow().clone();
resp
}
async fn fetch_and_cache_impl(
&self,
cache: Data<dyn Cache>,
url: String,
key: CacheKey,
notify: Arc<Notify>,
) {
let (tx, rx) = channel(FetchResult::Processing);
self.locks.write().insert(url.clone(), rx);
notify.notify_one();
let resp = self.inner.get(&url).send().await;
let resp = match resp {
Ok(mut resp) => {
let content_type = resp.headers().get(CONTENT_TYPE);
let is_image = content_type
.map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
.unwrap_or_default();
if resp.status() != StatusCode::OK || !is_image {
warn!("Got non-OK or non-image response code from upstream, proxying and not caching result.");
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type.clone());
}
FetchResult::Data(
resp.status(),
headers,
resp.bytes().await.unwrap_or_default(),
)
} else {
let (content_type, length, last_mod) = {
let headers = resp.headers_mut();
(
headers.remove(CONTENT_TYPE),
headers.remove(CONTENT_LENGTH),
headers.remove(LAST_MODIFIED),
)
};
let body = resp.bytes().await.unwrap();
debug!("Inserting into cache");
let metadata =
ImageMetadata::new(content_type.clone(), length.clone(), last_mod.clone())
.unwrap();
match cache.put(key, body.clone(), metadata).await {
Ok(()) => {
debug!("Done putting into cache");
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type);
}
if let Some(content_length) = length {
headers.insert(CONTENT_LENGTH, content_length);
}
if let Some(last_modified) = last_mod {
headers.insert(LAST_MODIFIED, last_modified);
}
FetchResult::Data(StatusCode::OK, headers, body)
}
Err(e) => {
warn!("Failed to insert into cache: {}", e);
FetchResult::InternalServerError
}
}
}
}
Err(e) => {
error!("Failed to fetch image from server: {}", e);
FetchResult::ServiceUnavailable
}
};
// This shouldn't happen
tx.send(resp).unwrap();
self.locks.write().remove(&url);
}
#[inline]
pub const fn inner(&self) -> &Client {
&self.inner
}
}

View file

@ -1,340 +1,67 @@
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::fs::{File, OpenOptions}; use std::num::{NonZeroU16, NonZeroU64};
use std::hint::unreachable_unchecked; use std::path::PathBuf;
use std::io::{ErrorKind, Write};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::num::NonZeroU16;
use std::path::{Path, PathBuf};
use std::str::FromStr; use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::AtomicBool;
use clap::{crate_authors, crate_description, crate_version, Parser}; use clap::{crate_authors, crate_description, crate_version, Clap};
use log::LevelFilter;
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::level_filters::LevelFilter as TracingLevelFilter;
use url::Url; use url::Url;
use crate::units::{KilobitsPerSecond, Mebibytes, Port};
// Validate tokens is an atomic because it's faster than locking on rwlock. // Validate tokens is an atomic because it's faster than locking on rwlock.
pub static VALIDATE_TOKENS: AtomicBool = AtomicBool::new(false); pub static VALIDATE_TOKENS: AtomicBool = AtomicBool::new(false);
// We use an atomic here because it's better for us to not pass the config
// everywhere.
pub static SEND_SERVER_VERSION: AtomicBool = AtomicBool::new(false);
pub static OFFLINE_MODE: AtomicBool = AtomicBool::new(false); pub static OFFLINE_MODE: AtomicBool = AtomicBool::new(false);
pub static USE_PROXY: OnceCell<Url> = OnceCell::new();
pub static DISABLE_CERT_VALIDATION: AtomicBool = AtomicBool::new(false);
#[derive(Error, Debug)] #[derive(Clap, Clone)]
pub enum ConfigError {
#[error("No config found. One has been created for you to modify.")]
NotInitialized,
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_yaml::Error),
}
pub fn load_config() -> Result<Config, ConfigError> {
// Load cli args first
let cli_args: CliArgs = CliArgs::parse();
// Load yaml file next
let config_file: Result<YamlArgs, _> = {
let config_path = cli_args
.config_path
.as_deref()
.unwrap_or_else(|| Path::new("./settings.yaml"));
match File::open(config_path) {
Ok(file) => serde_yaml::from_reader(file),
Err(e) if e.kind() == ErrorKind::NotFound => {
let mut file = OpenOptions::new()
.write(true)
.create_new(true)
.open(config_path)
.unwrap();
let default_config = include_str!("../settings.sample.yaml");
file.write_all(default_config.as_bytes()).unwrap();
return Err(ConfigError::NotInitialized);
}
Err(e) => return Err(e.into()),
}
};
// generate config
let config = Config::from_cli_and_file(cli_args, config_file?);
// initialize globals
OFFLINE_MODE.store(
config
.unstable_options
.contains(&UnstableOptions::OfflineMode),
Ordering::Release,
);
if let Some(socket) = config.proxy.clone() {
USE_PROXY
.set(socket)
.expect("USE_PROXY to be set only by this function");
}
DISABLE_CERT_VALIDATION.store(
config
.unstable_options
.contains(&UnstableOptions::DisableCertValidation),
Ordering::Release,
);
Ok(config)
}
#[derive(Debug)]
/// Represents a fully parsed config, from a variety of sources.
pub struct Config {
pub cache_type: CacheType,
pub cache_path: PathBuf,
pub shutdown_timeout: NonZeroU16,
pub log_level: TracingLevelFilter,
pub client_secret: ClientSecret,
pub port: Port,
pub bind_address: SocketAddr,
pub external_address: Option<SocketAddr>,
pub ephemeral_disk_encryption: bool,
pub network_speed: KilobitsPerSecond,
pub disk_quota: Mebibytes,
pub memory_quota: Mebibytes,
pub unstable_options: Vec<UnstableOptions>,
pub override_upstream: Option<Url>,
pub enable_metrics: bool,
pub geoip_license_key: Option<ClientSecret>,
pub proxy: Option<Url>,
pub redis_url: Option<Url>,
}
impl Config {
fn from_cli_and_file(cli_args: CliArgs, file_args: YamlArgs) -> Self {
let file_extended_options = file_args.extended_options.unwrap_or_default();
let log_level = match (cli_args.quiet, cli_args.verbose) {
(n, _) if n > 2 => TracingLevelFilter::OFF,
(2, _) => TracingLevelFilter::ERROR,
(1, _) => TracingLevelFilter::WARN,
// Use log level from file if no flags were provided to CLI
(0, 0) => {
file_extended_options
.logging_level
.map_or(TracingLevelFilter::INFO, |filter| match filter {
LevelFilter::Off => TracingLevelFilter::OFF,
LevelFilter::Error => TracingLevelFilter::ERROR,
LevelFilter::Warn => TracingLevelFilter::WARN,
LevelFilter::Info => TracingLevelFilter::INFO,
LevelFilter::Debug => TracingLevelFilter::DEBUG,
LevelFilter::Trace => TracingLevelFilter::TRACE,
})
}
(_, 1) => TracingLevelFilter::DEBUG,
(_, n) if n > 1 => TracingLevelFilter::TRACE,
// compiler can't figure it out
_ => unsafe { unreachable_unchecked() },
};
let bind_port = cli_args
.port
.unwrap_or(file_args.server_settings.port)
.get();
// This needs to be outside because rust isn't smart enough yet to
// realize a disjointed borrow of a moved value is ok. This will be
// fixed in Rust 2021.
let external_port = file_args
.server_settings
.external_port
.map_or(bind_port, Port::get);
Self {
cache_type: cli_args
.cache_type
.or(file_extended_options.cache_type)
.unwrap_or_default(),
cache_path: cli_args
.cache_path
.or(file_extended_options.cache_path)
.unwrap_or_else(|| PathBuf::from_str("./cache").unwrap()),
shutdown_timeout: file_args
.server_settings
.graceful_shutdown_wait_seconds
.unwrap_or(unsafe { NonZeroU16::new_unchecked(60) }),
log_level,
// secret should never be in CLI
client_secret: if let Ok(v) = std::env::var("CLIENT_SECRET") {
ClientSecret(v)
} else {
file_args.server_settings.secret
},
port: cli_args.port.unwrap_or(file_args.server_settings.port),
bind_address: SocketAddr::new(
file_args
.server_settings
.hostname
.unwrap_or_else(|| IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))),
bind_port,
),
external_address: file_args
.server_settings
.external_ip
.map(|ip_addr| SocketAddr::new(ip_addr, external_port)),
ephemeral_disk_encryption: cli_args.ephemeral_disk_encryption
|| file_extended_options
.ephemeral_disk_encryption
.unwrap_or_default(),
network_speed: cli_args
.network_speed
.unwrap_or(file_args.server_settings.external_max_kilobits_per_second),
disk_quota: cli_args
.disk_quota
.unwrap_or(file_args.max_cache_size_in_mebibytes),
memory_quota: cli_args
.memory_quota
.or(file_extended_options.memory_quota)
.unwrap_or_default(),
enable_metrics: file_extended_options.enable_metrics.unwrap_or_default(),
// Unstable options (and related) should never be in yaml config
unstable_options: cli_args.unstable_options,
override_upstream: cli_args.override_upstream,
geoip_license_key: file_args.metric_settings.and_then(|args| {
if args.enable_geoip.unwrap_or_default() {
args.geoip_license_key
} else {
None
}
}),
proxy: cli_args.proxy,
redis_url: file_extended_options.redis_url,
}
}
}
// this intentionally does not implement display
#[derive(Deserialize, Serialize, Clone)]
pub struct ClientSecret(String);
impl ClientSecret {
pub fn as_str(&self) -> &str {
self.0.as_ref()
}
}
impl std::fmt::Debug for ClientSecret {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "[client secret]")
}
}
#[derive(Deserialize, Copy, Clone, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CacheType {
OnDisk,
Lru,
Lfu,
Redis,
}
impl FromStr for CacheType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"on_disk" => Ok(Self::OnDisk),
"lru" => Ok(Self::Lru),
"lfu" => Ok(Self::Lfu),
"redis" => Ok(Self::Redis),
_ => Err(format!("Unknown option: {}", s)),
}
}
}
impl Default for CacheType {
fn default() -> Self {
Self::OnDisk
}
}
#[derive(Deserialize)]
struct YamlArgs {
// Naming is legacy
max_cache_size_in_mebibytes: Mebibytes,
server_settings: YamlServerSettings,
metric_settings: Option<YamlMetricSettings>,
// This implementation's custom options
extended_options: Option<YamlExtendedOptions>,
}
// Naming is legacy
#[derive(Deserialize)]
struct YamlServerSettings {
secret: ClientSecret,
#[serde(default)]
port: Port,
external_max_kilobits_per_second: KilobitsPerSecond,
external_port: Option<Port>,
graceful_shutdown_wait_seconds: Option<NonZeroU16>,
hostname: Option<IpAddr>,
external_ip: Option<IpAddr>,
}
#[derive(Deserialize)]
struct YamlMetricSettings {
enable_geoip: Option<bool>,
geoip_license_key: Option<ClientSecret>,
}
#[derive(Deserialize, Default)]
struct YamlExtendedOptions {
memory_quota: Option<Mebibytes>,
cache_type: Option<CacheType>,
ephemeral_disk_encryption: Option<bool>,
enable_metrics: Option<bool>,
logging_level: Option<LevelFilter>,
cache_path: Option<PathBuf>,
redis_url: Option<Url>,
}
#[derive(Parser, Clone)]
#[clap(version = crate_version!(), author = crate_authors!(), about = crate_description!())] #[clap(version = crate_version!(), author = crate_authors!(), about = crate_description!())]
struct CliArgs { pub struct CliArgs {
/// The port to listen on. /// The port to listen on.
#[clap(short, long)] #[clap(short, long, default_value = "42069", env = "PORT")]
pub port: Option<Port>, pub port: NonZeroU16,
/// How large, in mebibytes, the in-memory cache should be. Note that this /// How large, in bytes, the in-memory cache should be. Note that this does
/// does not include runtime memory usage. /// not include runtime memory usage.
#[clap(long)] #[clap(long, env = "MEM_CACHE_QUOTA_BYTES", conflicts_with = "low-memory")]
pub memory_quota: Option<Mebibytes>, pub memory_quota: Option<NonZeroU64>,
/// How large, in mebibytes, the on-disk cache should be. Note that actual /// How large, in bytes, the on-disk cache should be. Note that actual
/// values may be larger for metadata information. /// values may be larger for metadata information.
#[clap(long)] #[clap(long, env = "DISK_CACHE_QUOTA_BYTES")]
pub disk_quota: Option<Mebibytes>, pub disk_quota: u64,
/// Sets the location of the disk cache. /// Sets the location of the disk cache.
#[clap(long)] #[clap(long, default_value = "./cache", env = "DISK_CACHE_PATH")]
pub cache_path: Option<PathBuf>, pub cache_path: PathBuf,
/// The network speed to advertise to Mangadex@Home control server. /// The network speed to advertise to Mangadex@Home control server.
#[clap(long)] #[clap(long, env = "MAX_NETWORK_SPEED")]
pub network_speed: Option<KilobitsPerSecond>, pub network_speed: NonZeroU64,
/// Whether or not to provide the Server HTTP header to clients. This is
/// useful for debugging, but is generally not recommended for security
/// reasons.
#[clap(long, env = "ENABLE_SERVER_STRING", takes_value = false)]
pub enable_server_string: bool,
/// Changes the caching behavior to avoid buffering images in memory, and
/// instead use the filesystem as the buffer backing. This is useful for
/// clients in low (< 1GB) RAM environments.
#[clap(
short,
long,
conflicts_with("memory-quota"),
env = "LOW_MEMORY_MODE",
takes_value = false
)]
pub low_memory: bool,
/// Changes verbosity. Default verbosity is INFO, while increasing counts of /// Changes verbosity. Default verbosity is INFO, while increasing counts of
/// verbose flags increases the verbosity to DEBUG and TRACE, respectively. /// verbose flags increases the verbosity to DEBUG and TRACE, respectively.
#[clap(short, long, parse(from_occurrences), conflicts_with = "quiet")] #[clap(short, long, parse(from_occurrences))]
pub verbose: usize, pub verbose: usize,
/// Changes verbosity. Default verbosity is INFO, while increasing counts of /// Changes verbosity. Default verbosity is INFO, while increasing counts of
/// quiet flags decreases the verbosity to WARN, ERROR, and no logs, /// quiet flags decreases the verbosity to WARN, ERROR, and no logs,
/// respectively. /// respectively.
#[clap(short, long, parse(from_occurrences), conflicts_with = "verbose")] #[clap(short, long, parse(from_occurrences), conflicts_with = "verbose")]
pub quiet: usize, pub quiet: usize,
/// Unstable options. Intentionally not documented.
#[clap(short = 'Z', long)] #[clap(short = 'Z', long)]
pub unstable_options: Vec<UnstableOptions>, pub unstable_options: Vec<UnstableOptions>,
/// Override the image server with the one provided. Do not set this unless
/// you know what you're doing.
#[clap(long)] #[clap(long)]
pub override_upstream: Option<Url>, pub override_upstream: Option<Url>,
/// Enables ephemeral disk encryption. Items written to disk are first /// Enables ephemeral disk encryption. Items written to disk are first
@ -342,17 +69,6 @@ struct CliArgs {
/// performance, privacy, and usability with this flag enabled. /// performance, privacy, and usability with this flag enabled.
#[clap(short, long)] #[clap(short, long)]
pub ephemeral_disk_encryption: bool, pub ephemeral_disk_encryption: bool,
/// The path to the config file. Default value is `./settings.yaml`.
#[clap(short, long)]
pub config_path: Option<PathBuf>,
/// Whether to use an in-memory cache in addition to the disk cache. Default
/// value is "on_disk", other options are "lfu", "lru", and "redis".
#[clap(short = 't', long)]
pub cache_type: Option<CacheType>,
/// Whether or not to use a proxy for upstream requests. This affects all
/// requests except for the shutdown request.
#[clap(short = 'P', long)]
pub proxy: Option<Url>,
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
@ -361,6 +77,10 @@ pub enum UnstableOptions {
/// you know what you're dealing with. /// you know what you're dealing with.
OverrideUpstream, OverrideUpstream,
/// Use an LFU implementation for the in-memory cache instead of the default
/// LRU implementation.
UseLfu,
/// Disables token validation. Don't use this unless you know the /// Disables token validation. Don't use this unless you know the
/// ramifications of this command. /// ramifications of this command.
DisableTokenValidation, DisableTokenValidation,
@ -370,10 +90,6 @@ pub enum UnstableOptions {
/// Serves HTTP in plaintext /// Serves HTTP in plaintext
DisableTls, DisableTls,
/// Disable certificate validation. Only useful for debugging with a MITM
/// proxy
DisableCertValidation,
} }
impl FromStr for UnstableOptions { impl FromStr for UnstableOptions {
@ -382,10 +98,10 @@ impl FromStr for UnstableOptions {
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
match s { match s {
"override-upstream" => Ok(Self::OverrideUpstream), "override-upstream" => Ok(Self::OverrideUpstream),
"use-lfu" => Ok(Self::UseLfu),
"disable-token-validation" => Ok(Self::DisableTokenValidation), "disable-token-validation" => Ok(Self::DisableTokenValidation),
"offline-mode" => Ok(Self::OfflineMode), "offline-mode" => Ok(Self::OfflineMode),
"disable-tls" => Ok(Self::DisableTls), "disable-tls" => Ok(Self::DisableTls),
"disable-cert-validation" => Ok(Self::DisableCertValidation),
_ => Err(format!("Unknown unstable option '{}'", s)), _ => Err(format!("Unknown unstable option '{}'", s)),
} }
} }
@ -395,85 +111,10 @@ impl Display for UnstableOptions {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self::OverrideUpstream => write!(f, "override-upstream"), Self::OverrideUpstream => write!(f, "override-upstream"),
Self::UseLfu => write!(f, "use-lfu"),
Self::DisableTokenValidation => write!(f, "disable-token-validation"), Self::DisableTokenValidation => write!(f, "disable-token-validation"),
Self::OfflineMode => write!(f, "offline-mode"), Self::OfflineMode => write!(f, "offline-mode"),
Self::DisableTls => write!(f, "disable-tls"), Self::DisableTls => write!(f, "disable-tls"),
Self::DisableCertValidation => write!(f, "disable-cert-validation"),
} }
} }
} }
#[cfg(test)]
mod sample_yaml {
use crate::config::YamlArgs;
#[test]
fn parses() {
assert!(serde_yaml::from_str::<YamlArgs>(include_str!("../settings.sample.yaml")).is_ok());
}
}
#[cfg(test)]
mod config {
use std::path::PathBuf;
use log::LevelFilter;
use tracing::level_filters::LevelFilter as TracingLevelFilter;
use crate::config::{CacheType, ClientSecret, Config, YamlExtendedOptions, YamlServerSettings};
use crate::units::{KilobitsPerSecond, Mebibytes, Port};
use super::{CliArgs, YamlArgs};
#[test]
fn cli_has_priority() {
let cli_config = CliArgs {
port: Port::new(1234),
memory_quota: Some(Mebibytes::new(10)),
disk_quota: Some(Mebibytes::new(10)),
cache_path: Some(PathBuf::from("a")),
network_speed: KilobitsPerSecond::new(10),
verbose: 1,
quiet: 0,
unstable_options: vec![],
override_upstream: None,
ephemeral_disk_encryption: true,
config_path: None,
cache_type: Some(CacheType::Lfu),
proxy: None,
};
let yaml_args = YamlArgs {
max_cache_size_in_mebibytes: Mebibytes::new(50),
server_settings: YamlServerSettings {
secret: ClientSecret(String::new()),
port: Port::new(4321).expect("to work?"),
external_max_kilobits_per_second: KilobitsPerSecond::new(50).expect("to work?"),
external_port: None,
graceful_shutdown_wait_seconds: None,
hostname: None,
external_ip: None,
},
metric_settings: None,
extended_options: Some(YamlExtendedOptions {
memory_quota: Some(Mebibytes::new(50)),
cache_type: Some(CacheType::Lru),
ephemeral_disk_encryption: Some(false),
enable_metrics: None,
logging_level: Some(LevelFilter::Error),
cache_path: Some(PathBuf::from("b")),
redis_url: None,
}),
};
let config = Config::from_cli_and_file(cli_config, yaml_args);
assert_eq!(Some(config.port), Port::new(1234));
assert_eq!(config.memory_quota, Mebibytes::new(10));
assert_eq!(config.disk_quota, Mebibytes::new(10));
assert_eq!(config.cache_path, PathBuf::from("a"));
assert_eq!(Some(config.network_speed), KilobitsPerSecond::new(10));
assert_eq!(config.log_level, TracingLevelFilter::DEBUG);
assert_eq!(config.ephemeral_disk_encryption, true);
assert_eq!(config.cache_type, CacheType::Lfu);
}
}

View file

@ -2,52 +2,50 @@
// We're end users, so these is ok // We're end users, so these is ok
#![allow(clippy::module_name_repetitions)] #![allow(clippy::module_name_repetitions)]
use std::env::VarError; use std::env::{self, VarError};
use std::error::Error; use std::error::Error;
use std::fmt::Display; use std::fmt::Display;
use std::net::SocketAddr; use std::hint::unreachable_unchecked;
use std::num::ParseIntError; use std::num::{NonZeroU64, ParseIntError};
use std::str::FromStr; use std::process;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use actix_web::dev::Service;
use actix_web::rt::{spawn, time, System}; use actix_web::rt::{spawn, time, System};
use actix_web::web::{self, Data}; use actix_web::web::{self, Data};
use actix_web::{App, HttpResponse, HttpServer}; use actix_web::{App, HttpResponse, HttpServer};
use cache::{Cache, DiskCache}; use cache::{Cache, DiskCache};
use chacha20::Key; use clap::Clap;
use config::Config; use config::CliArgs;
use maxminddb::geoip2; use log::{debug, error, info, warn, LevelFilter};
use parking_lot::RwLock; use parking_lot::RwLock;
use redis::Client as RedisClient; use rustls::{NoClientAuth, ServerConfig};
use simple_logger::SimpleLogger;
use rustls::server::NoClientAuth; use sodiumoxide::crypto::secretstream::gen_key;
use rustls::ServerConfig;
use sodiumoxide::crypto::stream::xchacha20::gen_key;
use state::{RwLockServerState, ServerState}; use state::{RwLockServerState, ServerState};
use stop::send_stop; use stop::send_stop;
use thiserror::Error; use thiserror::Error;
use tracing::{debug, error, info, warn};
use crate::cache::mem::{Lfu, Lru}; use crate::cache::mem::{Lfu, Lru};
use crate::cache::{MemoryCache, ENCRYPTION_KEY}; use crate::cache::{MemoryCache, ENCRYPTION_KEY};
use crate::config::{CacheType, UnstableOptions, OFFLINE_MODE}; use crate::config::{UnstableOptions, OFFLINE_MODE};
use crate::metrics::{record_country_visit, GEOIP_DATABASE};
use crate::state::DynamicServerCert; use crate::state::DynamicServerCert;
mod cache; mod cache;
mod client;
mod config; mod config;
mod metrics; mod metrics;
mod ping; mod ping;
mod routes; mod routes;
mod state; mod state;
mod stop; mod stop;
mod units;
const CLIENT_API_VERSION: usize = 31; #[macro_export]
macro_rules! client_api_version {
() => {
"31"
};
}
#[derive(Error, Debug)] #[derive(Error, Debug)]
enum ServerError { enum ServerError {
@ -67,76 +65,77 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Config loading // Config loading
// //
let config = match config::load_config() { let cli_args = CliArgs::parse();
Ok(c) => c, let port = cli_args.port;
Err(e) => { let memory_max_size = cli_args
eprintln!("{}", e); .memory_quota
return Err(Box::new(e) as Box<_>); .map(NonZeroU64::get)
} .unwrap_or_default();
}; let disk_quota = cli_args.disk_quota;
let cache_path = cli_args.cache_path.clone();
let memory_quota = config.memory_quota; let low_mem_mode = cli_args.low_memory;
let disk_quota = config.disk_quota; let use_lfu = cli_args.unstable_options.contains(&UnstableOptions::UseLfu);
let cache_type = config.cache_type; let disable_tls = cli_args
let cache_path = config.cache_path.clone();
let disable_tls = config
.unstable_options .unstable_options
.contains(&UnstableOptions::DisableTls); .contains(&UnstableOptions::DisableTls);
let bind_address = config.bind_address; OFFLINE_MODE.store(
let redis_url = config.redis_url.clone(); cli_args
.unstable_options
.contains(&UnstableOptions::OfflineMode),
Ordering::Release,
);
// //
// Logging and warnings // Logging and warnings
// //
tracing_subscriber::fmt() let log_level = match (cli_args.quiet, cli_args.verbose) {
.with_max_level(config.log_level) (n, _) if n > 2 => LevelFilter::Off,
.init(); (2, _) => LevelFilter::Error,
(1, _) => LevelFilter::Warn,
(0, 0) => LevelFilter::Info,
(_, 1) => LevelFilter::Debug,
(_, n) if n > 1 => LevelFilter::Trace,
// compiler can't figure it out
_ => unsafe { unreachable_unchecked() },
};
if let Err(e) = print_preamble_and_warnings(&config) { SimpleLogger::new().with_level(log_level).init()?;
if let Err(e) = print_preamble_and_warnings(&cli_args) {
error!("{}", e); error!("{}", e);
return Err(e); return Err(e);
} }
debug!("{:?}", &config); let client_secret = if let Ok(v) = env::var("CLIENT_SECRET") {
v
} else {
error!("Client secret not found in ENV. Please set CLIENT_SECRET.");
process::exit(1);
};
let client_secret_1 = client_secret.clone();
let client_secret = config.client_secret.clone(); if cli_args.ephemeral_disk_encryption {
let client_secret_1 = config.client_secret.clone();
if config.ephemeral_disk_encryption {
info!("Running with at-rest encryption!"); info!("Running with at-rest encryption!");
ENCRYPTION_KEY ENCRYPTION_KEY.set(gen_key()).unwrap();
.set(*Key::from_slice(gen_key().as_ref()))
.unwrap();
} }
if config.enable_metrics { metrics::init();
metrics::init();
}
if let Some(key) = config.geoip_license_key.clone() {
if let Err(e) = metrics::load_geo_ip_data(key).await {
error!("Failed to initialize geo ip db: {}", e);
}
}
// HTTP Server init // HTTP Server init
// Try bind to provided port first
let port_reservation = std::net::TcpListener::bind(bind_address);
if let Err(e) = port_reservation {
error!("Failed to bind to port!");
return Err(e.into());
};
let server = if OFFLINE_MODE.load(Ordering::Acquire) { let server = if OFFLINE_MODE.load(Ordering::Acquire) {
ServerState::init_offline() ServerState::init_offline()
} else { } else {
ServerState::init(&client_secret, &config).await? ServerState::init(&client_secret, &cli_args).await?
}; };
let data_0 = Arc::new(RwLockServerState(RwLock::new(server))); let data_0 = Arc::new(RwLockServerState(RwLock::new(server)));
let data_1 = Arc::clone(&data_0); let data_1 = Arc::clone(&data_0);
// What's nice is that Rustls only supports TLS 1.2 and 1.3.
let mut tls_config = ServerConfig::new(NoClientAuth::new());
tls_config.cert_resolver = Arc::new(DynamicServerCert);
// //
// At this point, the server is ready to start, and starts the necessary // At this point, the server is ready to start, and starts the necessary
// threads. // threads.
@ -156,7 +155,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
send_stop(&client_secret).await; send_stop(&client_secret).await;
} else { } else {
warn!("Got second Ctrl-C, forcefully exiting"); warn!("Got second Ctrl-C, forcefully exiting");
system.stop(); system.stop()
} }
}); });
} }
@ -172,25 +171,18 @@ async fn main() -> Result<(), Box<dyn Error>> {
loop { loop {
interval.tick().await; interval.tick().await;
debug!("Sending ping!"); debug!("Sending ping!");
ping::update_server_state(&client_secret_1, &config, &mut data).await; ping::update_server_state(&client_secret_1, &cli_args, &mut data).await;
} }
}); });
} }
let memory_max_size = memory_quota.into(); let cache = DiskCache::new(disk_quota, cache_path.clone()).await;
let cache = DiskCache::new(disk_quota.into(), cache_path.clone()).await; let cache: Arc<dyn Cache> = if low_mem_mode {
let cache: Arc<dyn Cache> = match cache_type { cache
CacheType::OnDisk => cache, } else if use_lfu {
CacheType::Lru => MemoryCache::<Lfu, _>::new(cache, memory_max_size), MemoryCache::<Lfu, _>::new(cache, memory_max_size).await
CacheType::Lfu => MemoryCache::<Lru, _>::new(cache, memory_max_size), } else {
CacheType::Redis => { MemoryCache::<Lru, _>::new(cache, memory_max_size).await
let url = redis_url.unwrap_or_else(|| {
url::Url::parse("redis://127.0.0.1/").expect("default redis url to be parsable")
});
info!("Trying to connect to redis instance at {}", url);
let mem_cache = RedisClient::open(url)?;
Arc::new(MemoryCache::new_with_cache(cache, mem_cache))
}
}; };
let cache_0 = Arc::clone(&cache); let cache_0 = Arc::clone(&cache);
@ -198,23 +190,6 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Start HTTPS server // Start HTTPS server
let server = HttpServer::new(move || { let server = HttpServer::new(move || {
App::new() App::new()
.wrap_fn(|req, srv| {
if let Some(reader) = GEOIP_DATABASE.get() {
let maybe_country = req
.connection_info()
.realip_remote_addr()
.map(SocketAddr::from_str)
.and_then(Result::ok)
.as_ref()
.map(SocketAddr::ip)
.map(|ip| reader.lookup::<geoip2::Country>(ip))
.and_then(Result::ok);
record_country_visit(maybe_country);
}
srv.call(req)
})
.service(routes::index) .service(routes::index)
.service(routes::token_data) .service(routes::token_data)
.service(routes::token_data_saver) .service(routes::token_data_saver)
@ -233,25 +208,18 @@ async fn main() -> Result<(), Box<dyn Error>> {
}) })
.shutdown_timeout(60); .shutdown_timeout(60);
// drop port reservation, might have a TOCTOU but it's not a big deal; this
// is just a best effort.
std::mem::drop(port_reservation);
if disable_tls { if disable_tls {
server.bind(bind_address)?.run().await?; server.bind(format!("0.0.0.0:{}", port))?.run().await?;
} else { } else {
// Rustls only supports TLS 1.2 and 1.3. server
let tls_config = ServerConfig::builder() .bind_rustls(format!("0.0.0.0:{}", port), tls_config)?
.with_safe_defaults() .run()
.with_client_cert_verifier(NoClientAuth::new()) .await?;
.with_cert_resolver(Arc::new(DynamicServerCert));
server.bind_rustls(bind_address, tls_config)?.run().await?;
} }
// Waiting for us to finish sending stop message // Waiting for us to finish sending stop message
while running.load(Ordering::SeqCst) { while running.load(Ordering::SeqCst) {
tokio::time::sleep(Duration::from_millis(250)).await; std::thread::sleep(Duration::from_millis(250));
} }
Ok(()) Ok(())
@ -262,7 +230,6 @@ enum InvalidCombination {
MissingUnstableOption(&'static str, UnstableOptions), MissingUnstableOption(&'static str, UnstableOptions),
} }
#[cfg(not(tarpaulin_include))]
impl Display for InvalidCombination { impl Display for InvalidCombination {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
@ -279,38 +246,32 @@ impl Display for InvalidCombination {
impl Error for InvalidCombination {} impl Error for InvalidCombination {}
#[cfg(not(tarpaulin_include))] fn print_preamble_and_warnings(args: &CliArgs) -> Result<(), Box<dyn Error>> {
#[allow(clippy::cognitive_complexity)] println!(concat!(
fn print_preamble_and_warnings(args: &Config) -> Result<(), Box<dyn Error>> { env!("CARGO_PKG_NAME"),
let build_string = option_env!("VERGEN_GIT_SHA_SHORT") " ",
.map(|git_sha| format!(" ({})", git_sha)) env!("CARGO_PKG_VERSION"),
.unwrap_or_default(); " (",
env!("VERGEN_GIT_SHA_SHORT"),
println!( ")",
concat!( " Copyright (C) 2021 ",
env!("CARGO_PKG_NAME"), env!("CARGO_PKG_AUTHORS"),
" ", "\n\n",
env!("CARGO_PKG_VERSION"), env!("CARGO_PKG_NAME"),
"{} Copyright (C) 2021 ", " is free software: you can redistribute it and/or modify\n\
env!("CARGO_PKG_AUTHORS"), it under the terms of the GNU General Public License as published by\n\
"\n\n", the Free Software Foundation, either version 3 of the License, or\n\
env!("CARGO_PKG_NAME"), (at your option) any later version.\n\n",
" is free software: you can redistribute it and/or modify\n\ env!("CARGO_PKG_NAME"),
it under the terms of the GNU General Public License as published by\n\ " is distributed in the hope that it will be useful,\n\
the Free Software Foundation, either version 3 of the License, or\n\ but WITHOUT ANY WARRANTY; without even the implied warranty of\n\
(at your option) any later version.\n\n", MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n\
env!("CARGO_PKG_NAME"), GNU General Public License for more details.\n\n\
" is distributed in the hope that it will be useful,\n\ You should have received a copy of the GNU General Public License\n\
but WITHOUT ANY WARRANTY; without even the implied warranty of\n\ along with ",
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n\ env!("CARGO_PKG_NAME"),
GNU General Public License for more details.\n\n\ ". If not, see <https://www.gnu.org/licenses/>.\n"
You should have received a copy of the GNU General Public License\n\ ));
along with ",
env!("CARGO_PKG_NAME"),
". If not, see <https://www.gnu.org/licenses/>.\n"
),
build_string
);
if !args.unstable_options.is_empty() { if !args.unstable_options.is_empty() {
warn!("Unstable options are enabled. These options should not be used in production!"); warn!("Unstable options are enabled. These options should not be used in production!");
@ -327,13 +288,6 @@ fn print_preamble_and_warnings(args: &Config) -> Result<(), Box<dyn Error>> {
warn!("Serving insecure traffic! You better be running this for development only."); warn!("Serving insecure traffic! You better be running this for development only.");
} }
if args
.unstable_options
.contains(&UnstableOptions::DisableCertValidation)
{
error!("Cert validation disabled! You REALLY only better be debugging.");
}
if args.override_upstream.is_some() if args.override_upstream.is_some()
&& !args && !args
.unstable_options .unstable_options

View file

@ -1,31 +1,5 @@
#![cfg(not(tarpaulin_include))] use once_cell::sync::Lazy;
use prometheus::{register_int_counter, IntCounter};
use std::fs::metadata;
use std::hint::unreachable_unchecked;
use std::time::SystemTime;
use chrono::Duration;
use flate2::read::GzDecoder;
use maxminddb::geoip2::Country;
use once_cell::sync::{Lazy, OnceCell};
use prometheus::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
use tar::Archive;
use thiserror::Error;
use tracing::{debug, field::debug, info, warn};
use crate::client::HTTP_CLIENT;
use crate::config::ClientSecret;
pub static GEOIP_DATABASE: OnceCell<maxminddb::Reader<Vec<u8>>> = OnceCell::new();
static COUNTRY_VISIT_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"country_visits_total",
"The number of visits from a country",
&["country"]
)
.unwrap()
});
macro_rules! init_counters { macro_rules! init_counters {
($(($counter:ident, $ty:ty, $name:literal, $desc:literal),)*) => { ($(($counter:ident, $ty:ty, $name:literal, $desc:literal),)*) => {
@ -37,11 +11,7 @@ macro_rules! init_counters {
#[allow(clippy::shadow_unrelated)] #[allow(clippy::shadow_unrelated)]
pub fn init() { pub fn init() {
// These need to be called at least once, otherwise the macro never
// called and thus the metrics don't get logged
$(let _a = $counter.get();)* $(let _a = $counter.get();)*
init_other();
} }
}; };
} }
@ -50,13 +20,13 @@ init_counters!(
( (
CACHE_HIT_COUNTER, CACHE_HIT_COUNTER,
IntCounter, IntCounter,
"cache_hit_total", "cache_hit",
"The number of cache hits." "The number of cache hits."
), ),
( (
CACHE_MISS_COUNTER, CACHE_MISS_COUNTER,
IntCounter, IntCounter,
"cache_miss_total", "cache_miss",
"The number of cache misses." "The number of cache misses."
), ),
( (
@ -68,118 +38,19 @@ init_counters!(
( (
REQUESTS_DATA_COUNTER, REQUESTS_DATA_COUNTER,
IntCounter, IntCounter,
"requests_data_total", "requests_data",
"The number of requests served from the /data endpoint." "The number of requests served from the /data endpoint."
), ),
( (
REQUESTS_DATA_SAVER_COUNTER, REQUESTS_DATA_SAVER_COUNTER,
IntCounter, IntCounter,
"requests_data_saver_total", "requests_data_saver",
"The number of requests served from the /data-saver endpoint." "The number of requests served from the /data-saver endpoint."
), ),
( (
REQUESTS_OTHER_COUNTER, REQUESTS_OTHER_COUNTER,
IntCounter, IntCounter,
"requests_other_total", "requests_other",
"The total number of request not served by primary endpoints." "The total number of request not served by primary endpoints."
), ),
); );
// initialization for any other counters that aren't simple int counters
fn init_other() {
let _a = COUNTRY_VISIT_COUNTER.local();
}
#[derive(Error, Debug)]
pub enum DbLoadError {
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
MaxMindDb(#[from] maxminddb::MaxMindDBError),
}
pub async fn load_geo_ip_data(license_key: ClientSecret) -> Result<(), DbLoadError> {
const DB_PATH: &str = "./GeoLite2-Country.mmdb";
// Check date of db
let db_date_created = metadata(DB_PATH)
.ok()
.and_then(|metadata| {
if let Ok(time) = metadata.created() {
Some(time)
} else {
debug("fs didn't report birth time, fall back to last modified instead");
metadata.modified().ok()
}
})
.unwrap_or(SystemTime::UNIX_EPOCH);
let duration = if let Ok(time) = SystemTime::now().duration_since(db_date_created) {
Duration::from_std(time).expect("duration to fit")
} else {
warn!("Clock may have gone backwards?");
Duration::max_value()
};
// DB expired, fetch a new one
if duration > Duration::weeks(1) {
fetch_db(license_key).await?;
} else {
info!("Geo IP database isn't old enough, not updating.");
}
// Result literally cannot panic here, buuuuuut if it does we'll panic
GEOIP_DATABASE
.set(maxminddb::Reader::open_readfile(DB_PATH)?)
.map_err(|_| ()) // Need to map err here or can't expect
.expect("to set the geo ip db singleton");
Ok(())
}
async fn fetch_db(license_key: ClientSecret) -> Result<(), DbLoadError> {
let resp = HTTP_CLIENT
.inner()
.get(format!("https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key={}&suffix=tar.gz", license_key.as_str()))
.send()
.await?
.bytes()
.await?;
let mut decoder = Archive::new(GzDecoder::new(resp.as_ref()));
let mut decoded_paths: Vec<_> = decoder
.entries()?
.filter_map(Result::ok)
.filter_map(|mut entry| {
let path = entry.path().ok()?.to_path_buf();
let file_name = path.file_name()?;
if file_name != "GeoLite2-Country.mmdb" {
return None;
}
entry.unpack(file_name).ok()?;
Some(path)
})
.collect();
assert_eq!(decoded_paths.len(), 1);
let path = match decoded_paths.pop() {
Some(path) => path,
None => unsafe { unreachable_unchecked() },
};
debug!("Extracted {}", path.as_path().to_string_lossy());
Ok(())
}
pub fn record_country_visit(country: Option<Country>) {
let iso_code = country
.and_then(|country| country.country.and_then(|c| c.iso_code))
.unwrap_or("unknown");
COUNTRY_VISIT_COUNTER
.get_metric_with_label_values(&[iso_code])
.unwrap()
.inc();
}

View file

@ -1,71 +1,65 @@
use std::net::{IpAddr, SocketAddr}; use std::num::{NonZeroU16, NonZeroU64};
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::{io::BufReader, sync::Arc}; use std::{io::BufReader, sync::Arc};
use rustls::sign::{CertifiedKey, RsaSigningKey, SigningKey}; use log::{debug, error, info, warn};
use rustls::{Certificate, PrivateKey}; use rustls::internal::pemfile::{certs, rsa_private_keys};
use rustls_pemfile::{certs, rsa_private_keys}; use rustls::sign::{RSASigningKey, SigningKey};
use rustls::Certificate;
use serde::de::{MapAccess, Visitor}; use serde::de::{MapAccess, Visitor};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_repr::Deserialize_repr; use serde_repr::Deserialize_repr;
use sodiumoxide::crypto::box_::PrecomputedKey; use sodiumoxide::crypto::box_::PrecomputedKey;
use tracing::{debug, error, info, warn};
use url::Url; use url::Url;
use crate::client::HTTP_CLIENT; use crate::config::{CliArgs, VALIDATE_TOKENS};
use crate::config::{ClientSecret, Config};
use crate::state::{ use crate::state::{
RwLockServerState, CERTIFIED_KEY, PREVIOUSLY_COMPROMISED, PREVIOUSLY_PAUSED, RwLockServerState, PREVIOUSLY_COMPROMISED, PREVIOUSLY_PAUSED, TLS_CERTS,
TLS_PREVIOUSLY_CREATED, TLS_PREVIOUSLY_CREATED, TLS_SIGNING_KEY,
}; };
use crate::units::{Bytes, BytesPerSecond, Port}; use crate::{client_api_version, config::UnstableOptions};
use crate::CLIENT_API_VERSION;
pub const CONTROL_CENTER_PING_URL: &str = "https://api.mangadex.network/ping"; pub const CONTROL_CENTER_PING_URL: &str = "https://api.mangadex.network/ping";
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
pub struct Request<'a> { pub struct Request<'a> {
secret: &'a ClientSecret, secret: &'a str,
port: Port, port: NonZeroU16,
disk_space: Bytes, disk_space: u64,
network_speed: BytesPerSecond, network_speed: NonZeroU64,
build_version: usize, build_version: u64,
tls_created_at: Option<String>, tls_created_at: Option<String>,
ip_address: Option<IpAddr>,
} }
impl<'a> Request<'a> { impl<'a> Request<'a> {
fn from_config_and_state(secret: &'a ClientSecret, config: &Config) -> Self { fn from_config_and_state(secret: &'a str, config: &CliArgs) -> Self {
Self { Self {
secret, secret,
port: config port: config.port,
.external_address disk_space: config.disk_quota,
.and_then(|v| Port::new(v.port())) network_speed: config.network_speed,
.unwrap_or(config.port), build_version: client_api_version!()
disk_space: config.disk_quota.into(), .parse()
network_speed: config.network_speed.into(), .expect("to parse the build version"),
build_version: CLIENT_API_VERSION,
tls_created_at: TLS_PREVIOUSLY_CREATED tls_created_at: TLS_PREVIOUSLY_CREATED
.get() .get()
.map(|v| v.load().as_ref().clone()), .map(|v| v.load().as_ref().clone()),
ip_address: config.external_address.as_ref().map(SocketAddr::ip),
} }
} }
} }
impl<'a> From<(&'a ClientSecret, &Config)> for Request<'a> { #[allow(clippy::fallible_impl_from)]
fn from((secret, config): (&'a ClientSecret, &Config)) -> Self { impl<'a> From<(&'a str, &CliArgs)> for Request<'a> {
fn from((secret, config): (&'a str, &CliArgs)) -> Self {
Self { Self {
secret, secret,
port: config port: config.port,
.external_address disk_space: config.disk_quota,
.and_then(|v| Port::new(v.port())) network_speed: config.network_speed,
.unwrap_or(config.port), build_version: client_api_version!()
disk_space: config.disk_quota.into(), .parse()
network_speed: config.network_speed.into(), .expect("to parse the build version"),
build_version: CLIENT_API_VERSION,
tls_created_at: None, tls_created_at: None,
ip_address: config.external_address.as_ref().map(SocketAddr::ip),
} }
} }
} }
@ -73,7 +67,7 @@ impl<'a> From<(&'a ClientSecret, &Config)> for Request<'a> {
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
#[serde(untagged)] #[serde(untagged)]
pub enum Response { pub enum Response {
Ok(Box<OkResponse>), Ok(OkResponse),
Error(ErrorResponse), Error(ErrorResponse),
} }
@ -85,6 +79,8 @@ pub struct OkResponse {
pub token_key: Option<String>, pub token_key: Option<String>,
pub compromised: bool, pub compromised: bool,
pub paused: bool, pub paused: bool,
#[serde(default)]
pub force_tokens: bool,
pub tls: Option<Tls>, pub tls: Option<Tls>,
} }
@ -104,7 +100,7 @@ pub enum ErrorCode {
pub struct Tls { pub struct Tls {
pub created_at: String, pub created_at: String,
pub priv_key: Arc<RsaSigningKey>, pub priv_key: Arc<Box<dyn SigningKey>>,
pub certs: Vec<Certificate>, pub certs: Vec<Certificate>,
} }
@ -137,12 +133,11 @@ impl<'de> Deserialize<'de> for Tls {
priv_key = rsa_private_keys(&mut BufReader::new(value.as_bytes())) priv_key = rsa_private_keys(&mut BufReader::new(value.as_bytes()))
.ok() .ok()
.and_then(|mut v| { .and_then(|mut v| {
v.pop() v.pop().and_then(|key| RSASigningKey::new(&key).ok())
.and_then(|key| RsaSigningKey::new(&PrivateKey(key)).ok()) })
});
} }
"certificate" => { "certificate" => {
certificates = certs(&mut BufReader::new(value.as_bytes())).ok(); certificates = certs(&mut BufReader::new(value.as_bytes())).ok()
} }
_ => (), // Ignore extra fields _ => (), // Ignore extra fields
} }
@ -151,8 +146,8 @@ impl<'de> Deserialize<'de> for Tls {
match (created_at, priv_key, certificates) { match (created_at, priv_key, certificates) {
(Some(created_at), Some(priv_key), Some(certificates)) => Ok(Tls { (Some(created_at), Some(priv_key), Some(certificates)) => Ok(Tls {
created_at, created_at,
priv_key: Arc::new(priv_key), priv_key: Arc::new(Box::new(priv_key)),
certs: certificates.into_iter().map(Certificate).collect(), certs: certificates,
}), }),
_ => Err(serde::de::Error::custom("Could not deserialize tls info")), _ => Err(serde::de::Error::custom("Could not deserialize tls info")),
} }
@ -163,7 +158,6 @@ impl<'de> Deserialize<'de> for Tls {
} }
} }
#[cfg(not(tarpaulin_include))]
impl std::fmt::Debug for Tls { impl std::fmt::Debug for Tls {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Tls") f.debug_struct("Tls")
@ -172,19 +166,10 @@ impl std::fmt::Debug for Tls {
} }
} }
pub async fn update_server_state( pub async fn update_server_state(secret: &str, cli: &CliArgs, data: &mut Arc<RwLockServerState>) {
secret: &ClientSecret,
cli: &Config,
data: &mut Arc<RwLockServerState>,
) {
let req = Request::from_config_and_state(secret, cli); let req = Request::from_config_and_state(secret, cli);
debug!("Sending ping request: {:?}", req); let client = reqwest::Client::new();
let resp = HTTP_CLIENT let resp = client.post(CONTROL_CENTER_PING_URL).json(&req).send().await;
.inner()
.post(CONTROL_CENTER_PING_URL)
.json(&req)
.send()
.await;
match resp { match resp {
Ok(resp) => match resp.json::<Response>().await { Ok(resp) => match resp.json::<Response>().await {
Ok(Response::Ok(resp)) => { Ok(Response::Ok(resp)) => {
@ -199,13 +184,27 @@ pub async fn update_server_state(
} }
if let Some(key) = resp.token_key { if let Some(key) = resp.token_key {
base64::decode(&key) if let Some(key) = base64::decode(&key)
.ok() .ok()
.and_then(|k| PrecomputedKey::from_slice(&k)) .and_then(|k| PrecomputedKey::from_slice(&k))
.map_or_else( {
|| error!("Failed to parse token key: got {}", key), write_guard.precomputed_key = key;
|key| write_guard.precomputed_key = key, } else {
); error!("Failed to parse token key: got {}", key);
}
}
if !cli
.unstable_options
.contains(&UnstableOptions::DisableTokenValidation)
&& VALIDATE_TOKENS.load(Ordering::Acquire) != resp.force_tokens
{
if resp.force_tokens {
info!("Client received command to enforce token validity.");
} else {
info!("Client received command to no longer enforce token validity");
}
VALIDATE_TOKENS.store(resp.force_tokens, Ordering::Release);
} }
if let Some(tls) = resp.tls { if let Some(tls) = resp.tls {
@ -213,12 +212,8 @@ pub async fn update_server_state(
.get() .get()
.unwrap() .unwrap()
.swap(Arc::new(tls.created_at)); .swap(Arc::new(tls.created_at));
CERTIFIED_KEY.store(Some(Arc::new(CertifiedKey { TLS_SIGNING_KEY.get().unwrap().swap(tls.priv_key);
cert: tls.certs.clone(), TLS_CERTS.get().unwrap().swap(Arc::new(tls.certs));
key: Arc::clone(&tls.priv_key) as Arc<dyn SigningKey>,
ocsp: None,
sct_list: None,
})));
} }
let previously_compromised = PREVIOUSLY_COMPROMISED.load(Ordering::Acquire); let previously_compromised = PREVIOUSLY_COMPROMISED.load(Ordering::Acquire);
@ -257,7 +252,7 @@ pub async fn update_server_state(
}, },
Err(e) => match e { Err(e) => match e {
e if e.is_timeout() => { e if e.is_timeout() => {
error!("Response timed out to control server. Is MangaDex down?"); error!("Response timed out to control server. Is MangaDex down?")
} }
e => warn!("Failed to send request: {}", e), e => warn!("Failed to send request: {}", e),
}, },

View file

@ -1,41 +1,62 @@
use std::hint::unreachable_unchecked;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::Duration;
use actix_web::body::BoxBody;
use actix_web::error::ErrorNotFound; use actix_web::error::ErrorNotFound;
use actix_web::http::header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE, LAST_MODIFIED}; use actix_web::http::header::{
use actix_web::web::{Data, Path}; ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH,
CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS,
};
use actix_web::web::Path;
use actix_web::HttpResponseBuilder; use actix_web::HttpResponseBuilder;
use actix_web::{get, HttpRequest, HttpResponse, Responder}; use actix_web::{get, web::Data, HttpRequest, HttpResponse, Responder};
use base64::DecodeError; use base64::DecodeError;
use bytes::Bytes; use bytes::Bytes;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use futures::Stream; use futures::{Stream, TryStreamExt};
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use prometheus::{Encoder, TextEncoder}; use prometheus::{Encoder, TextEncoder};
use reqwest::{Client, StatusCode};
use ring::signature::ECDSA_P256_SHA256_ASN1;
use serde::Deserialize; use serde::Deserialize;
use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES}; use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES};
use thiserror::Error; use thiserror::Error;
use tracing::{debug, error, info, trace};
use crate::cache::{Cache, CacheKey, ImageMetadata, UpstreamError}; use crate::cache::{Cache, CacheKey, ImageMetadata, UpstreamError};
use crate::client::{FetchResult, DEFAULT_HEADERS, HTTP_CLIENT}; use crate::client_api_version;
use crate::config::{OFFLINE_MODE, VALIDATE_TOKENS}; use crate::config::{OFFLINE_MODE, SEND_SERVER_VERSION, VALIDATE_TOKENS};
use crate::metrics::{ use crate::metrics::{
CACHE_HIT_COUNTER, CACHE_MISS_COUNTER, REQUESTS_DATA_COUNTER, REQUESTS_DATA_SAVER_COUNTER, CACHE_HIT_COUNTER, CACHE_MISS_COUNTER, REQUESTS_DATA_COUNTER, REQUESTS_DATA_SAVER_COUNTER,
REQUESTS_OTHER_COUNTER, REQUESTS_TOTAL_COUNTER, REQUESTS_OTHER_COUNTER, REQUESTS_TOTAL_COUNTER,
}; };
use crate::state::RwLockServerState; use crate::state::RwLockServerState;
const BASE64_CONFIG: base64::Config = base64::Config::new(base64::CharacterSet::UrlSafe, false); pub const BASE64_CONFIG: base64::Config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
pub enum ServerResponse { static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
Client::builder()
.pool_idle_timeout(Duration::from_secs(180))
.https_only(true)
.http2_prior_knowledge()
.build()
.expect("Client initialization to work")
});
const SERVER_ID_STRING: &str = concat!(
env!("CARGO_CRATE_NAME"),
" ",
env!("CARGO_PKG_VERSION"),
" (",
client_api_version!(),
") - Conforming to spec revision b82043289",
);
enum ServerResponse {
TokenValidationError(TokenValidationError), TokenValidationError(TokenValidationError),
HttpResponse(HttpResponse), HttpResponse(HttpResponse),
} }
impl Responder for ServerResponse { impl Responder for ServerResponse {
type Body = BoxBody;
#[inline] #[inline]
fn respond_to(self, req: &HttpRequest) -> HttpResponse { fn respond_to(self, req: &HttpRequest) -> HttpResponse {
match self { match self {
@ -48,12 +69,12 @@ impl Responder for ServerResponse {
} }
} }
#[allow(clippy::unused_async)]
#[get("/")] #[get("/")]
async fn index() -> impl Responder { async fn index() -> impl Responder {
HttpResponse::Ok().body(include_str!("index.html")) HttpResponse::Ok().body(include_str!("index.html"))
} }
#[allow(clippy::future_not_send)]
#[get("/{token}/data/{chapter_hash}/{file_name}")] #[get("/{token}/data/{chapter_hash}/{file_name}")]
async fn token_data( async fn token_data(
state: Data<RwLockServerState>, state: Data<RwLockServerState>,
@ -70,6 +91,7 @@ async fn token_data(
fetch_image(state, cache, chapter_hash, file_name, false).await fetch_image(state, cache, chapter_hash, file_name, false).await
} }
#[allow(clippy::future_not_send)]
#[get("/{token}/data-saver/{chapter_hash}/{file_name}")] #[get("/{token}/data-saver/{chapter_hash}/{file_name}")]
async fn token_data_saver( async fn token_data_saver(
state: Data<RwLockServerState>, state: Data<RwLockServerState>,
@ -105,43 +127,42 @@ pub async fn default(state: Data<RwLockServerState>, req: HttpRequest) -> impl R
info!("Got unknown path, just proxying: {}", path); info!("Got unknown path, just proxying: {}", path);
let mut resp = match HTTP_CLIENT.inner().get(path).send().await { let resp = match HTTP_CLIENT.get(path).send().await {
Ok(resp) => resp, Ok(resp) => resp,
Err(e) => { Err(e) => {
error!("{}", e); error!("{}", e);
return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish()); return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish());
} }
}; };
let content_type = resp.headers_mut().remove(CONTENT_TYPE); let content_type = resp.headers().get(CONTENT_TYPE);
let mut resp_builder = HttpResponseBuilder::new(resp.status()); let mut resp_builder = HttpResponseBuilder::new(resp.status());
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type { if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type); resp_builder.insert_header((CONTENT_TYPE, content_type));
} }
// push_headers(&mut resp_builder); push_headers(&mut resp_builder);
let mut resp = resp_builder.body(resp.bytes().await.unwrap_or_default()); ServerResponse::HttpResponse(resp_builder.body(resp.bytes().await.unwrap_or_default()))
*resp.headers_mut() = headers;
ServerResponse::HttpResponse(resp)
} }
#[allow(clippy::unused_async)] #[allow(clippy::future_not_send)]
#[get("/prometheus")] #[get("/metrics")]
pub async fn metrics() -> impl Responder { pub async fn metrics() -> impl Responder {
let metric_families = prometheus::gather(); let metric_families = prometheus::gather();
let mut buffer = Vec::new(); let mut buffer = Vec::new();
TextEncoder::new() TextEncoder::new()
.encode(&metric_families, &mut buffer) .encode(&metric_families, &mut buffer)
.expect("Should never have an io error writing to a vec"); .unwrap();
String::from_utf8(buffer).expect("Text encoder should render valid utf-8") String::from_utf8(buffer).unwrap()
} }
#[derive(Error, Debug, PartialEq, Eq)] #[derive(Error, Debug)]
pub enum TokenValidationError { enum TokenValidationError {
#[error("Failed to decode base64 token.")] #[error("Failed to decode base64 token.")]
DecodeError(#[from] DecodeError), DecodeError(#[from] DecodeError),
#[error("Nonce was too short.")] #[error("Nonce was too short.")]
IncompleteNonce, IncompleteNonce,
#[error("Invalid nonce.")]
InvalidNonce,
#[error("Decryption failed")] #[error("Decryption failed")]
DecryptionFailure, DecryptionFailure,
#[error("The token format was invalid.")] #[error("The token format was invalid.")]
@ -150,16 +171,14 @@ pub enum TokenValidationError {
TokenExpired, TokenExpired,
#[error("Invalid chapter hash.")] #[error("Invalid chapter hash.")]
InvalidChapterHash, InvalidChapterHash,
#[error("Invalid v32 format")]
InvalidV32Format,
} }
impl Responder for TokenValidationError { impl Responder for TokenValidationError {
type Body = BoxBody;
#[inline] #[inline]
fn respond_to(self, _: &HttpRequest) -> HttpResponse { fn respond_to(self, _: &HttpRequest) -> HttpResponse {
let mut resp = HttpResponse::Forbidden().finish(); push_headers(&mut HttpResponse::Forbidden()).finish()
*resp.headers_mut() = DEFAULT_HEADERS.clone();
resp
} }
} }
@ -181,11 +200,7 @@ fn validate_token(
let (nonce, encrypted) = data.split_at(NONCEBYTES); let (nonce, encrypted) = data.split_at(NONCEBYTES);
let nonce = match Nonce::from_slice(nonce) { let nonce = Nonce::from_slice(nonce).ok_or(TokenValidationError::InvalidNonce)?;
Some(nonce) => nonce,
// We split at NONCEBYTES, so this should never happen.
None => unsafe { unreachable_unchecked() },
};
let decrypted = open_precomputed(encrypted, &nonce, precomputed_key) let decrypted = open_precomputed(encrypted, &nonce, precomputed_key)
.map_err(|_| TokenValidationError::DecryptionFailure)?; .map_err(|_| TokenValidationError::DecryptionFailure)?;
@ -205,6 +220,51 @@ fn validate_token(
Ok(()) Ok(())
} }
fn validate_token_v32(pub_key: &[u8], token: String) -> Result<(), TokenValidationError> {
#[derive(Deserialize)]
struct Token<'a> {
expires: DateTime<Utc>,
client_id: &'a str,
}
let (token_base64, sig_base64) = token
.split_once('~')
.ok_or_else(|| TokenValidationError::InvalidV32Format)?;
let token = base64::decode_config(token_base64, BASE64_CONFIG)?;
let sig = base64::decode_config(sig_base64, BASE64_CONFIG)?;
ring::signature::UnparsedPublicKey::new(&ECDSA_P256_SHA256_ASN1, pub_key)
.verify(&token, &sig)
.map_err(|_| TokenValidationError::DecryptionFailure)?;
// At this point, token has a valid signature, now to check token fields
let token: Token =
serde_json::from_slice(&token).map_err(|_| TokenValidationError::InvalidToken)?;
if token.expires < Utc::now() {
return Err(TokenValidationError::TokenExpired);
}
Ok(())
}
#[inline]
fn push_headers(builder: &mut HttpResponseBuilder) -> &mut HttpResponseBuilder {
builder
.insert_header((X_CONTENT_TYPE_OPTIONS, "nosniff"))
.insert_header((ACCESS_CONTROL_ALLOW_ORIGIN, "https://mangadex.org"))
.insert_header((ACCESS_CONTROL_EXPOSE_HEADERS, "*"))
.insert_header((CACHE_CONTROL, "public, max-age=1209600"))
.insert_header(("Timing-Allow-Origin", "https://mangadex.org"));
if SEND_SERVER_VERSION.load(Ordering::Acquire) {
builder.insert_header(("Server", SERVER_ID_STRING));
}
builder
}
#[allow(clippy::future_not_send)] #[allow(clippy::future_not_send)]
async fn fetch_image( async fn fetch_image(
state: Data<RwLockServerState>, state: Data<RwLockServerState>,
@ -223,7 +283,7 @@ async fn fetch_image(
Some(Err(_)) => { Some(Err(_)) => {
return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish()); return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish());
} }
None => (), _ => (),
} }
CACHE_MISS_COUNTER.inc(); CACHE_MISS_COUNTER.inc();
@ -235,217 +295,112 @@ async fn fetch_image(
); );
} }
let url = if is_data_saver { // It's important to not get a write lock before this request, else we're
format!( // holding the read lock until the await resolves.
"{}/data-saver/{}/{}",
state.0.read().image_server,
&key.0,
&key.1,
)
} else {
format!("{}/data/{}/{}", state.0.read().image_server, &key.0, &key.1)
};
match HTTP_CLIENT.fetch_and_cache(url, key, cache).await { let resp = if is_data_saver {
FetchResult::ServiceUnavailable => { HTTP_CLIENT
ServerResponse::HttpResponse(HttpResponse::ServiceUnavailable().finish()) .get(format!(
"{}/data-saver/{}/{}",
state.0.read().image_server,
&key.0,
&key.1
))
.send()
} else {
HTTP_CLIENT
.get(format!(
"{}/data/{}/{}",
state.0.read().image_server,
&key.0,
&key.1
))
.send()
}
.await;
match resp {
Ok(mut resp) => {
let content_type = resp.headers().get(CONTENT_TYPE);
let is_image = content_type
.map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
.unwrap_or_default();
if resp.status() != StatusCode::OK || !is_image {
warn!(
"Got non-OK or non-image response code from upstream, proxying and not caching result.",
);
let mut resp_builder = HttpResponseBuilder::new(resp.status());
if let Some(content_type) = content_type {
resp_builder.insert_header((CONTENT_TYPE, content_type));
}
push_headers(&mut resp_builder);
return ServerResponse::HttpResponse(
resp_builder.body(resp.bytes().await.unwrap_or_default()),
);
}
let (content_type, length, last_mod) = {
let headers = resp.headers_mut();
(
headers.remove(CONTENT_TYPE),
headers.remove(CONTENT_LENGTH),
headers.remove(LAST_MODIFIED),
)
};
let body = resp.bytes_stream().map_err(|e| e.into());
debug!("Inserting into cache");
let metadata = ImageMetadata::new(content_type, length, last_mod).unwrap();
let stream = {
match cache.put(key, Box::new(body), metadata).await {
Ok(stream) => stream,
Err(e) => {
warn!("Failed to insert into cache: {}", e);
return ServerResponse::HttpResponse(
HttpResponse::InternalServerError().finish(),
);
}
}
};
debug!("Done putting into cache");
construct_response(stream, &metadata)
} }
FetchResult::InternalServerError => { Err(e) => {
ServerResponse::HttpResponse(HttpResponse::InternalServerError().finish()) error!("Failed to fetch image from server: {}", e);
ServerResponse::HttpResponse(
push_headers(&mut HttpResponse::ServiceUnavailable()).finish(),
)
} }
FetchResult::Data(status, headers, data) => {
let mut resp = HttpResponseBuilder::new(status);
let mut resp = resp.body(data);
*resp.headers_mut() = headers;
ServerResponse::HttpResponse(resp)
}
FetchResult::Processing => panic!("Race condition found with fetch result"),
} }
} }
#[inline] fn construct_response(
pub fn construct_response(
data: impl Stream<Item = Result<Bytes, UpstreamError>> + Unpin + 'static, data: impl Stream<Item = Result<Bytes, UpstreamError>> + Unpin + 'static,
metadata: &ImageMetadata, metadata: &ImageMetadata,
) -> ServerResponse { ) -> ServerResponse {
trace!("Constructing response"); trace!("Constructing response");
let mut resp = HttpResponse::Ok(); let mut resp = HttpResponse::Ok();
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = metadata.content_type { if let Some(content_type) = metadata.content_type {
headers.insert( resp.append_header((CONTENT_TYPE, content_type.as_ref()));
CONTENT_TYPE,
HeaderValue::from_str(content_type.as_ref()).unwrap(),
);
} }
if let Some(content_length) = metadata.content_length { if let Some(content_length) = metadata.content_length {
headers.insert(CONTENT_LENGTH, HeaderValue::from(content_length)); resp.append_header((CONTENT_LENGTH, content_length));
} }
if let Some(last_modified) = metadata.last_modified { if let Some(last_modified) = metadata.last_modified {
headers.insert( resp.append_header((LAST_MODIFIED, last_modified.to_rfc2822()));
LAST_MODIFIED,
HeaderValue::from_str(&last_modified.to_rfc2822()).unwrap(),
);
} }
let mut ret = resp.streaming(data); ServerResponse::HttpResponse(push_headers(&mut resp).streaming(data))
*ret.headers_mut() = headers;
ServerResponse::HttpResponse(ret)
}
#[cfg(test)]
mod token_validation {
use super::{BASE64_CONFIG, DecodeError, PrecomputedKey, TokenValidationError, Utc, validate_token};
use sodiumoxide::crypto::box_::precompute;
use sodiumoxide::crypto::box_::seal_precomputed;
use sodiumoxide::crypto::box_::{gen_keypair, gen_nonce, PRECOMPUTEDKEYBYTES};
#[test]
fn invalid_base64() {
let res = validate_token(
&PrecomputedKey::from_slice(&b"1".repeat(PRECOMPUTEDKEYBYTES))
.expect("valid test token"),
"a".to_string(),
"b",
);
assert_eq!(
res,
Err(TokenValidationError::DecodeError(
DecodeError::InvalidLength
))
);
}
#[test]
fn not_long_enough_for_nonce() {
let res = validate_token(
&PrecomputedKey::from_slice(&b"1".repeat(PRECOMPUTEDKEYBYTES))
.expect("valid test token"),
"aGVsbG8gaW50ZXJuZXR-Cg==".to_string(),
"b",
);
assert_eq!(res, Err(TokenValidationError::IncompleteNonce));
}
#[test]
fn invalid_precomputed_key() {
let precomputed_1 = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let precomputed_2 = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
// Seal with precomputed_2, open with precomputed_1
let data = seal_precomputed(b"hello world", &nonce, &precomputed_2);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed_1, data, "b");
assert_eq!(res, Err(TokenValidationError::DecryptionFailure));
}
#[test]
fn invalid_token_data() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let data = seal_precomputed(b"hello world", &nonce, &precomputed);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert_eq!(res, Err(TokenValidationError::InvalidToken));
}
#[test]
fn token_must_have_valid_expiration() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() - chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert_eq!(res, Err(TokenValidationError::TokenExpired));
}
#[test]
fn token_must_have_valid_chapter_hash() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() + chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "");
assert_eq!(res, Err(TokenValidationError::InvalidChapterHash));
}
#[test]
fn valid_token_returns_ok() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() + chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert!(res.is_ok());
}
} }

View file

@ -1,19 +1,17 @@
use std::str::FromStr; use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use crate::client::HTTP_CLIENT; use crate::config::{CliArgs, UnstableOptions, OFFLINE_MODE, SEND_SERVER_VERSION, VALIDATE_TOKENS};
use crate::config::{ClientSecret, Config, OFFLINE_MODE};
use crate::ping::{Request, Response, CONTROL_CENTER_PING_URL}; use crate::ping::{Request, Response, CONTROL_CENTER_PING_URL};
use arc_swap::{ArcSwap, ArcSwapOption}; use arc_swap::ArcSwap;
use log::{error, info, warn};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use parking_lot::RwLock; use parking_lot::RwLock;
use rustls::server::{ClientHello, ResolvesServerCert}; use rustls::sign::{CertifiedKey, SigningKey};
use rustls::sign::{CertifiedKey, RsaSigningKey, SigningKey};
use rustls::Certificate; use rustls::Certificate;
use rustls::{ClientHello, ResolvesServerCert};
use sodiumoxide::crypto::box_::{PrecomputedKey, PRECOMPUTEDKEYBYTES}; use sodiumoxide::crypto::box_::{PrecomputedKey, PRECOMPUTEDKEYBYTES};
use thiserror::Error; use thiserror::Error;
use tracing::{error, info, warn};
use url::Url; use url::Url;
pub struct ServerState { pub struct ServerState {
@ -27,10 +25,8 @@ pub static PREVIOUSLY_PAUSED: AtomicBool = AtomicBool::new(false);
pub static PREVIOUSLY_COMPROMISED: AtomicBool = AtomicBool::new(false); pub static PREVIOUSLY_COMPROMISED: AtomicBool = AtomicBool::new(false);
pub static TLS_PREVIOUSLY_CREATED: OnceCell<ArcSwap<String>> = OnceCell::new(); pub static TLS_PREVIOUSLY_CREATED: OnceCell<ArcSwap<String>> = OnceCell::new();
static TLS_SIGNING_KEY: OnceCell<ArcSwap<RsaSigningKey>> = OnceCell::new(); pub static TLS_SIGNING_KEY: OnceCell<ArcSwap<Box<dyn SigningKey>>> = OnceCell::new();
static TLS_CERTS: OnceCell<ArcSwap<Vec<Certificate>>> = OnceCell::new(); pub static TLS_CERTS: OnceCell<ArcSwap<Vec<Certificate>>> = OnceCell::new();
pub static CERTIFIED_KEY: ArcSwapOption<CertifiedKey> = ArcSwapOption::const_empty();
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ServerInitError { pub enum ServerInitError {
@ -49,14 +45,18 @@ pub enum ServerInitError {
} }
impl ServerState { impl ServerState {
pub async fn init(secret: &ClientSecret, config: &Config) -> Result<Self, ServerInitError> { pub async fn init(secret: &str, config: &CliArgs) -> Result<Self, ServerInitError> {
let resp = HTTP_CLIENT let resp = reqwest::Client::new()
.inner()
.post(CONTROL_CENTER_PING_URL) .post(CONTROL_CENTER_PING_URL)
.json(&Request::from((secret, config))) .json(&Request::from((secret, config)))
.send() .send()
.await; .await;
if config.enable_server_string {
warn!("Client will send Server header in responses. This is not recommended!");
SEND_SERVER_VERSION.store(true, Ordering::Release);
}
match resp { match resp {
Ok(resp) => match resp.json::<Response>().await { Ok(resp) => match resp.json::<Response>().await {
Ok(Response::Ok(mut resp)) => { Ok(Response::Ok(mut resp)) => {
@ -64,16 +64,15 @@ impl ServerState {
.token_key .token_key
.ok_or(ServerInitError::MissingTokenKey) .ok_or(ServerInitError::MissingTokenKey)
.and_then(|key| { .and_then(|key| {
base64::decode(&key) if let Some(key) = base64::decode(&key)
.ok() .ok()
.and_then(|k| PrecomputedKey::from_slice(&k)) .and_then(|k| PrecomputedKey::from_slice(&k))
.map_or_else( {
|| { Ok(key)
error!("Failed to parse token key: got {}", key); } else {
Err(ServerInitError::KeyParseError(key)) error!("Failed to parse token key: got {}", key);
}, Err(ServerInitError::KeyParseError(key))
Ok, }
)
})?; })?;
PREVIOUSLY_COMPROMISED.store(resp.compromised, Ordering::Release); PREVIOUSLY_COMPROMISED.store(resp.compromised, Ordering::Release);
@ -89,19 +88,26 @@ impl ServerState {
if let Some(ref override_url) = config.override_upstream { if let Some(ref override_url) = config.override_upstream {
resp.image_server = override_url.clone(); resp.image_server = override_url.clone();
warn!("Upstream URL overridden to: {}", resp.image_server); warn!("Upstream URL overridden to: {}", resp.image_server);
} else {
} }
info!("This client's URL has been set to {}", resp.url); info!("This client's URL has been set to {}", resp.url);
if config
.unstable_options
.contains(&UnstableOptions::DisableTokenValidation)
{
warn!("Token validation is explicitly disabled!");
} else {
if resp.force_tokens {
info!("This client will validate tokens.");
} else {
info!("This client will not validate tokens.");
}
VALIDATE_TOKENS.store(resp.force_tokens, Ordering::Release);
}
let tls = resp.tls.unwrap(); let tls = resp.tls.unwrap();
CERTIFIED_KEY.store(Some(Arc::new(CertifiedKey {
cert: tls.certs.clone(),
key: Arc::clone(&tls.priv_key) as Arc<dyn SigningKey>,
ocsp: None,
sct_list: None,
})));
std::mem::drop( std::mem::drop(
TLS_PREVIOUSLY_CREATED.set(ArcSwap::from_pointee(tls.created_at)), TLS_PREVIOUSLY_CREATED.set(ArcSwap::from_pointee(tls.created_at)),
); );
@ -143,10 +149,9 @@ impl ServerState {
pub fn init_offline() -> Self { pub fn init_offline() -> Self {
assert!(OFFLINE_MODE.load(Ordering::Acquire)); assert!(OFFLINE_MODE.load(Ordering::Acquire));
Self { Self {
precomputed_key: PrecomputedKey::from_slice(&[41; PRECOMPUTEDKEYBYTES]) precomputed_key: PrecomputedKey::from_slice(&[41; PRECOMPUTEDKEYBYTES]).unwrap(),
.expect("expect offline config to work"), image_server: Url::from_file_path("/dev/null").unwrap(),
image_server: Url::from_file_path("/dev/null").expect("expect offline config to work"), url: Url::from_str("http://localhost").unwrap(),
url: Url::from_str("http://localhost").expect("expect offline config to work"),
url_overridden: false, url_overridden: false,
} }
} }
@ -157,9 +162,14 @@ pub struct RwLockServerState(pub RwLock<ServerState>);
pub struct DynamicServerCert; pub struct DynamicServerCert;
impl ResolvesServerCert for DynamicServerCert { impl ResolvesServerCert for DynamicServerCert {
fn resolve(&self, _: ClientHello) -> Option<Arc<CertifiedKey>> { fn resolve(&self, _: ClientHello) -> Option<CertifiedKey> {
// TODO: wait for actix-web to use a new version of rustls so we can // TODO: wait for actix-web to use a new version of rustls so we can
// remove cloning the certs all the time // remove cloning the certs all the time
CERTIFIED_KEY.load_full() Some(CertifiedKey {
cert: TLS_CERTS.get().unwrap().load().as_ref().clone(),
key: TLS_SIGNING_KEY.get().unwrap().load_full(),
ocsp: None,
sct_list: None,
})
} }
} }

View file

@ -1,24 +1,20 @@
#![cfg(not(tarpaulin_include))] use log::{info, warn};
use reqwest::StatusCode; use reqwest::StatusCode;
use serde::Serialize; use serde::Serialize;
use tracing::{info, warn};
use crate::client::HTTP_CLIENT; const CONTROL_CENTER_STOP_URL: &str = "https://api.mangadex.network/ping";
use crate::config::ClientSecret;
const CONTROL_CENTER_STOP_URL: &str = "https://api.mangadex.network/stop";
#[derive(Serialize)] #[derive(Serialize)]
struct StopRequest<'a> { struct StopRequest<'a> {
secret: &'a ClientSecret, secret: &'a str,
} }
pub async fn send_stop(secret: &ClientSecret) { pub async fn send_stop(secret: &str) {
match HTTP_CLIENT let request = StopRequest { secret };
.inner() let client = reqwest::Client::new();
match client
.post(CONTROL_CENTER_STOP_URL) .post(CONTROL_CENTER_STOP_URL)
.json(&StopRequest { secret }) .json(&request)
.send() .send()
.await .await
{ {
@ -32,17 +28,3 @@ pub async fn send_stop(secret: &ClientSecret) {
Err(e) => warn!("Got error while sending stop message: {}", e), Err(e) => warn!("Got error while sending stop message: {}", e),
} }
} }
#[cfg(test)]
mod stop {
use super::CONTROL_CENTER_STOP_URL;
#[test]
fn stop_url_does_not_have_ping_in_url() {
// This looks like a dumb test, yes, but it ensures that clients don't
// get marked compromised because apparently just sending a json obj
// with just the secret is acceptable to the ping endpoint, which messes
// up non-trivial client configs.
assert!(!CONTROL_CENTER_STOP_URL.contains("ping"))
}
}

View file

@ -1,99 +0,0 @@
use std::fmt::Display;
use std::num::{NonZeroU16, NonZeroU64, ParseIntError};
use std::str::FromStr;
use serde::{Deserialize, Serialize};
/// Wrapper type for a port number.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct Port(NonZeroU16);
impl Port {
pub const fn get(self) -> u16 {
self.0.get()
}
pub fn new(amt: u16) -> Option<Self> {
NonZeroU16::new(amt).map(Self)
}
}
impl Default for Port {
fn default() -> Self {
Self(unsafe { NonZeroU16::new_unchecked(443) })
}
}
impl FromStr for Port {
type Err = <NonZeroU16 as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
NonZeroU16::from_str(s).map(Self)
}
}
impl Display for Port {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
#[derive(Copy, Clone, Deserialize, Default, Debug, Hash, Eq, PartialEq)]
pub struct Mebibytes(usize);
impl Mebibytes {
#[cfg(test)]
pub fn new(size: usize) -> Self {
Self(size)
}
}
impl FromStr for Mebibytes {
type Err = ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.parse::<usize>().map(Self)
}
}
#[derive(Serialize, Debug)]
pub struct Bytes(pub usize);
impl Bytes {
pub const fn get(&self) -> usize {
self.0
}
}
impl From<Mebibytes> for Bytes {
fn from(mib: Mebibytes) -> Self {
Self(mib.0 << 20)
}
}
#[derive(Copy, Clone, Deserialize, Debug, Hash, Eq, PartialEq)]
pub struct KilobitsPerSecond(NonZeroU64);
impl KilobitsPerSecond {
#[cfg(test)]
pub fn new(size: u64) -> Option<Self> {
NonZeroU64::new(size).map(Self)
}
}
impl FromStr for KilobitsPerSecond {
type Err = ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.parse::<NonZeroU64>().map(Self)
}
}
#[derive(Copy, Clone, Serialize, Debug, Hash, Eq, PartialEq)]
pub struct BytesPerSecond(NonZeroU64);
impl From<KilobitsPerSecond> for BytesPerSecond {
fn from(kbps: KilobitsPerSecond) -> Self {
Self(unsafe { NonZeroU64::new_unchecked(kbps.0.get() * 125) })
}
}