Compare commits

...

120 commits

Author SHA1 Message Date
1152f775b9
lint tests 2022-03-26 16:30:50 -07:00
e404d144a2
clippy 2022-03-26 16:28:58 -07:00
bc6ed7d07e
Update actix 2022-03-26 16:21:27 -07:00
ff0944f58c
Update dependencies 2022-03-26 16:20:00 -07:00
81604a7e94
Clippy 2022-01-02 13:25:00 -08:00
f6a9caf653
Update Cargo.lock 2022-01-02 13:18:08 -08:00
5c6f02b9a5
Update minor version 2022-01-02 12:48:03 -08:00
557f141ed2
Update lru 2022-01-02 12:42:22 -08:00
55f6279dce
Update to actix beta 18 2022-01-02 12:34:00 -08:00
0bf76eab6b
Update deps 2022-01-02 12:12:35 -08:00
a838a94ce9
Update to modern code 2022-01-02 12:12:23 -08:00
63eba4dc37
add db listener test 2021-07-23 20:27:52 -04:00
4544061845
clarify iso code impl 2021-07-23 17:41:17 -04:00
42cbd81375
convert some unwraps to expects 2021-07-22 13:46:40 -04:00
07bd39e69d
Add unit tests for token validation 2021-07-22 13:37:43 -04:00
8f5799211c
Try reserve port during startup 2021-07-22 13:37:32 -04:00
0300135a6a
Add exclusions for code coverage 2021-07-20 16:47:04 -04:00
7bbbf44328
Hololive-ify sample config file 2021-07-18 23:57:38 -04:00
2878ddf0dc
Update dependencies 2021-07-18 23:45:24 -04:00
e4af231829
Remove tracing-futures 2021-07-18 22:05:36 -04:00
b04fac9b01
Added docker-compose file 2021-07-18 21:36:54 -04:00
5aa72e9821
Add parser directive to Dockerfile 2021-07-18 21:33:56 -04:00
6d6bf7371b
Add dockerfile 2021-07-18 21:27:29 -04:00
6b1c913b5d
Remove unneeded imports 2021-07-18 18:35:22 -04:00
acd37297fd
Remove legacy token validation field 2021-07-18 18:32:19 -04:00
bd306455bc
Use ReaderStream 2021-07-18 11:37:39 -04:00
afa2cf55fa
Fix some future not send lints 2021-07-17 13:32:43 -04:00
e95afd3e32
Bump to 0.5.3 2021-07-17 13:05:13 -04:00
fbcf9566c1
Update readme 2021-07-17 13:04:39 -04:00
d42b80d7e1
Remove encryption option warning 2021-07-17 12:52:41 -04:00
931da0c3ff
Add redis support 2021-07-17 12:52:02 -04:00
5fdcfa5071
Add newline to Cargo.toml 2021-07-16 21:15:47 -04:00
93ff76aa89
create folder if not found 2021-07-16 20:03:59 -04:00
f8f4098fae
Finish mem tests 2021-07-16 16:40:07 -04:00
bfcf131b33
Add stop url test 2021-07-16 15:26:18 -04:00
5da486d43d
Add put test for mem cache 2021-07-16 15:18:10 -04:00
51546eb387
Add memory cache get tests 2021-07-16 14:06:28 -04:00
712257429a
Remove Compression wrapper 2021-07-16 12:32:24 -04:00
5e7a82a610
Add partial test for mem cache 2021-07-16 01:13:51 -04:00
afb02db7b7
Remove unnecessary async keyword 2021-07-16 01:13:31 -04:00
b41ae8cb79
Add sync restriction on CacheStream 2021-07-16 01:13:01 -04:00
54c8fe1cb3
Turn DB messages into struct from tuple 2021-07-15 21:49:19 -04:00
8556f37904
Extract internal cache listener to function 2021-07-15 21:49:19 -04:00
fc930285f0
Bump to 0.5.2 2021-07-15 19:14:54 -04:00
041760f9e9
clippy 2021-07-15 19:13:31 -04:00
87271c85a7
Fix deleting legacy names 2021-07-15 19:03:39 -04:00
3e4260f6e1
rename /metric endpoint to /prometheus 2021-07-15 15:54:26 -04:00
dc99437aec
Fix legacy path lookup 2021-07-15 13:47:55 -04:00
3dbf2f8bb0
Read legacy path on disk 2021-07-15 13:25:46 -04:00
f7b037e8e1
Update readme 2021-07-15 12:58:40 -04:00
a552523a3a
increment version to 0.5.1 2021-07-15 12:39:19 -04:00
5935e4220b
Fix stop url 2021-07-15 12:37:55 -04:00
fa9ab93c77
Add proxy support 2021-07-15 12:29:55 -04:00
833a0c0468
Add automatic migration to new db location 2021-07-15 11:17:54 -04:00
b71253d8dc
Remove debug statement 2021-07-15 10:48:25 -04:00
84941e2cb4
Fix sending bytes instead of mebibytes 2021-07-15 03:01:15 -04:00
940af6508c
Default metadata path is metadata.db now 2021-07-15 02:53:00 -04:00
d3434e8408
Fix includes for publishing 2021-07-15 02:45:05 -04:00
3786827f20
Add sqlx json 2021-07-15 02:16:33 -04:00
261427a735
Bump version to 0.5.0 2021-07-15 02:14:41 -04:00
c4fa53fa40
Add geo ip logging support 2021-07-15 02:14:04 -04:00
1c00c993bf
Rename metrics to conventions 2021-07-15 02:13:31 -04:00
b2650da556
Documented more CLI options 2021-07-15 02:12:20 -04:00
6b3c6ce03a
Added geo ip dependencies 2021-07-15 02:10:30 -04:00
032db4e1dd
Add special thanks to readme 2021-07-15 02:09:57 -04:00
355fd936ab
Add debug message to outgoing ping request 2021-07-15 01:19:21 -04:00
6415c3dee6
Respect external ip config 2021-07-15 00:54:31 -04:00
acf6dc1cb1
Add geoip config reading 2021-07-14 22:32:05 -04:00
d4d22ec674
Reduce tokio features 2021-07-14 21:56:46 -04:00
353ee72713
Add unit tests 2021-07-14 21:56:29 -04:00
b1797dafd2
Fix double write bug 2021-07-14 19:11:46 -04:00
9209b822a9
Simplify DiskWriter poll_flush 2021-07-14 14:20:31 -04:00
5338ff81a5
Move impl block to correct location 2021-07-14 14:00:02 -04:00
53015e116f
Renamed EncryptedDiskReader to EncryptedReader 2021-07-14 13:32:26 -04:00
7ce974b4f9
Make EncryptedDiskReader generic over R 2021-07-14 13:32:00 -04:00
973ece3604
MetadataFuture tests, fix UB 2021-07-14 13:28:09 -04:00
0c78b379f1
Remove comments w.r.t. potential db optimizations
DB queries already are performed on their own thread, so spawning a
thread to do db work is unncessary.
2021-07-14 11:21:28 -04:00
94375b185f
More debugging 2021-07-13 23:12:29 -04:00
6ac8582183
Encrypted files work in debug mode 2021-07-13 20:38:01 -04:00
656543b539
more partial work into encryption 2021-07-13 16:39:32 -04:00
2ace8d3d66
Partial rewrite of encrypted writer 2021-07-13 13:16:44 -04:00
160f369a72
Migrate to tracing crate 2021-07-12 23:23:51 -04:00
f8ee49ffd7
Add extended options for sample config 2021-07-12 22:35:06 -04:00
9f76a7a1b3
Add compression middleware 2021-07-12 16:39:06 -04:00
2f271a220a
remove tarpaulin-report 2021-07-12 16:35:17 -04:00
868278991a
Add disk tests 2021-07-12 15:59:52 -04:00
e8bea39100
Remove serialize impl from legacy structs 2021-07-12 13:43:47 -04:00
20e349fd79
Add sqlx envs 2021-07-12 01:44:01 -04:00
80eeacd884
try fix ci 2021-07-12 01:35:32 -04:00
580d05d00c
add sqlx check 2021-07-12 01:34:37 -04:00
acd8f234ab
ignore cfg tarpaulin for now 2021-07-12 01:08:19 -04:00
4c135cd72d
Add coverage action 2021-07-12 01:05:18 -04:00
acc3ab2186
use hex instead of u8s in test 2021-07-12 00:49:06 -04:00
8daa6bdc27
add tests for Md5Hash conversions 2021-07-12 00:48:00 -04:00
69587b9ade
Clippy lints 2021-07-12 00:12:15 -04:00
8f3430fb77
Add support for reading old db image ids 2021-07-11 23:33:22 -04:00
ec9473fa78
clippy lint 2021-07-11 23:25:17 -04:00
3764af0ed5
Optimize header creation 2021-07-11 14:23:15 -04:00
099c795cca
Add potential perf gains in the future 2021-07-11 14:22:59 -04:00
92a66e60dc
Seek file from beginning on encrypted header 2021-07-11 13:25:02 -04:00
5143dff888
Add logging 2021-07-11 13:21:57 -04:00
7546948196
nightly clippy lints 2021-07-11 13:19:37 -04:00
5f4be9809a
Add support for legacy files 2021-07-11 02:33:51 -04:00
8040c49a1e
Add warning for encryption 2021-07-11 00:15:43 -04:00
9871fc3774
bump to 0.4 2021-07-10 19:07:55 -04:00
f64d03493e
use static default headers 2021-07-10 19:04:27 -04:00
93249397f1
Simply codebase 2021-07-10 18:53:28 -04:00
154679967b
writing optimizations 2021-07-10 14:22:29 -04:00
b90edd72a6
Add clippy to CI 2021-07-09 21:25:08 -04:00
3ec4d1c125
remove anchor from github workflows 2021-07-09 20:59:23 -04:00
52ca595029
fix workflow ignores 2021-07-09 20:58:07 -04:00
e0bd29751a
Update paths to ignore 2021-07-09 20:56:35 -04:00
7bd9189ebd
Update readme 2021-07-09 20:51:45 -04:00
c98a6d59f4
Have build script create .env if not found 2021-07-09 20:43:33 -04:00
60bec5592a
Have build script modify .env file 2021-07-09 20:37:21 -04:00
de5816a44a
add env value to CI 2021-07-09 20:22:58 -04:00
6fa301a4a5
rename workflow 2021-07-09 20:12:58 -04:00
7b5738da8d
add standard CI 2021-07-09 20:11:43 -04:00
5ab04a9e9c
add security CI 2021-07-09 20:02:22 -04:00
e65a7ba9ef
Update dependencies 2021-07-09 19:53:58 -04:00
28 changed files with 4222 additions and 1644 deletions

10
.dockerignore Normal file

@ -0,0 +1,10 @@
# Ignore everything
*
# Only include necessary paths (This should be synchronized with `Cargo.toml`)
!db_queries/
!src/
!settings.sample.yaml
!sqlx-data.json
!Cargo.toml
!Cargo.lock

53
.github/workflows/build_and_test.yml vendored Normal file

@@ -0,0 +1,53 @@
name: Build and test

on:
  push:
    branches: [ master ]
    paths-ignore:
      - "docs/**"
      - "settings.sample.yaml"
      - "README.md"
      - "LICENSE"
  pull_request:
    branches: [ master ]
    paths-ignore:
      - "docs/**"
      - "settings.sample.yaml"
      - "README.md"
      - "LICENSE"

env:
  CARGO_TERM_COLOR: always
  DATABASE_URL: sqlite:./cache/metadata.sqlite
  SQLX_OFFLINE: true

jobs:
  clippy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v1
      - run: rustup component add clippy
      - uses: actions-rs/clippy-check@v1
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          args: --all-features

  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Build
        run: cargo build --verbose
      - name: Run tests
        run: cargo test --verbose

  sqlx-check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Install sqlx-cli
        run: cargo install sqlx-cli
      - name: Initialize database
        run: mkdir -p cache && sqlite3 cache/metadata.sqlite < db_queries/init.sql
      - name: Check sqlx statements
        run: cargo sqlx prepare --check

22
.github/workflows/coverage.yml vendored Normal file

@@ -0,0 +1,22 @@
name: coverage

on: [push]

jobs:
  test:
    name: coverage
    runs-on: ubuntu-latest
    container:
      image: xd009642/tarpaulin:develop-nightly
      options: --security-opt seccomp=unconfined
    steps:
      - name: Checkout repository
        uses: actions/checkout@v2
      - name: Generate code coverage
        run: |
          cargo +nightly tarpaulin --verbose --all-features --workspace --timeout 120 --avoid-cfg-tarpaulin --out Xml

      - name: Upload to codecov.io
        uses: codecov/codecov-action@v1
        with:
          fail_ci_if_error: true

14
.github/workflows/security_audit.yml vendored Normal file

@@ -0,0 +1,14 @@
name: Security audit
on:
  push:
    paths:
      - '**/Cargo.toml'
      - '**/Cargo.lock'
jobs:
  security_audit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v1
      - uses: actions-rs/audit-check@v1
        with:
          token: ${{ secrets.GITHUB_TOKEN }}

4
.gitignore vendored

@@ -4,4 +4,6 @@
 flamegraph*.svg
 perf.data*
 dhat.out.*
 settings.yaml
+tarpaulin-report.html
+GeoLite2-Country.mmdb

1595
Cargo.lock generated

File diff suppressed because it is too large

Cargo.toml

@@ -1,50 +1,69 @@
 [package]
 name = "mangadex-home"
-version = "0.3.0"
+version = "0.5.4"
 license = "GPL-3.0-or-later"
 authors = ["Edward Shen <code@eddie.sh>"]
 edition = "2018"
-include = ["src/**/*", "db_queries", "LICENSE", "README.md"]
+include = [
+  "src/**/*",
+  "db_queries",
+  "LICENSE",
+  "README.md",
+  "sqlx-data.json",
+  "settings.sample.yaml"
+]
 description = "A MangaDex@Home implementation in Rust."
 repository = "https://github.com/edward-shen/mangadex-home-rs"
 [profile.release]
 lto = true
 codegen-units = 1
-# debug = 1
+debug = 1
 [dependencies]
-actix-web = { version = "4.0.0-beta.4", features = [ "rustls" ] }
+# Pin because we're using unstable versions
+actix-web = { version = "4", features = [ "rustls" ] }
 arc-swap = "1"
 async-trait = "0.1"
 base64 = "0.13"
 bincode = "1"
-bytes = "1"
+bytes = { version = "1", features = [ "serde" ] }
+chacha20 = "0.7"
 chrono = { version = "0.4", features = [ "serde" ] }
-clap = { version = "3.0.0-beta.2", features = [ "wrap_help" ] }
+clap = { version = "3", features = [ "wrap_help", "derive", "cargo" ] }
 ctrlc = "3"
 dotenv = "0.15"
-flate2 = { version = "1", features = [ "tokio" ] }
 futures = "0.3"
 once_cell = "1"
 log = { version = "0.4", features = [ "serde" ] }
 lfu_cache = "1"
-lru = "0.6"
+lru = "0.7"
+maxminddb = "0.20"
+md-5 = "0.9"
 parking_lot = "0.11"
 prometheus = { version = "0.12", features = [ "process" ] }
+redis = "0.21"
 reqwest = { version = "0.11", default_features = false, features = [ "json", "stream", "rustls-tls" ] }
-rustls = "0.19"
+rustls = "0.20"
+rustls-pemfile = "0.2"
 serde = "1"
 serde_json = "1"
 serde_repr = "0.1"
 serde_yaml = "0.8"
-simple_logger = "1"
 sodiumoxide = "0.2"
-sqlx = { version = "0.5", features = [ "runtime-actix-rustls", "sqlite", "time", "chrono", "macros" ] }
+sqlx = { version = "0.5", features = [ "runtime-actix-rustls", "sqlite", "time", "chrono", "macros", "offline" ] }
+tar = "0.4"
 thiserror = "1"
-tokio = { version = "1", features = [ "full", "parking_lot" ] }
+tokio = { version = "1", features = [ "rt-multi-thread", "macros", "fs", "time", "sync", "parking_lot" ] }
 tokio-stream = { version = "0.1", features = [ "sync" ] }
 tokio-util = { version = "0.6", features = [ "codec" ] }
+tracing = "0.1"
+tracing-subscriber = { version = "0.2", features = [ "parking_lot" ] }
 url = { version = "2", features = [ "serde" ] }
 [build-dependencies]
 vergen = "5"
+[dev-dependencies]
+tempfile = "3"

11
Dockerfile Normal file

@ -0,0 +1,11 @@
# syntax=docker/dockerfile:1
FROM rust:alpine as builder
COPY . .
RUN apk add --no-cache file make musl-dev \
&& cargo install --path . \
&& strip /usr/local/cargo/bin/mangadex-home
FROM alpine:latest
COPY --from=builder /usr/local/cargo/bin/mangadex-home /usr/local/bin/mangadex-home
CMD ["mangadex-home"]

107
README.md

@@ -2,75 +2,86 @@ A Rust implementation of a MangaDex@Home client.
 This client contains the following features:
-- Multi-threaded
-- HTTP/2 support
-- No support for TLS 1.1 or 1.0
+- Easy migration from the official client
+- Fully compliant with MangaDex@Home specifications
+- Multi-threaded, high performance, and low overhead client
+- HTTP/2 support for API users, HTTP/2 only for upstream connections
+- Secure and privacy oriented features:
+  - Only supports TLS 1.2 or newer; HTTP is not enabled by default
+  - Options for no logging and no metrics
+  - Support for on-disk XChaCha20 encryption with ephemeral key generation
+- Supports an internal LFU, LRU, or a redis instance for in-memory caching
 ## Building
-Since we use SQLx there are a few things you'll need to do. First, you'll need
-to run the init cache script, which initializes the db cache at
-`./cache/metadata.sqlite`. Then you'll need to add the location of that to a
-`.env` file:
 ```sh
-# In the project root
-./init_cache.sh
-echo "DATABASE_URL=sqlite:./cache/metadata.sqlite" >> .env
 cargo build
+cargo test
 ```
-## Cache implementation
-This client implements a multi-tier in-memory and on-disk LRU cache constrained
-by quotas. In essence, it acts as an unified LRU, where in-memory items are
-evicted and pushed into the on-disk LRU and fetching a item from the on-disk LRU
-promotes it to the in-memory LRU.
-Note that the capacity of each LRU is dynamic, depending on the maximum byte
-capacity that you permit each cache to be. A large item may evict multiple
-smaller items to fit within this constraint, for example.
-Note that these quotas are closer to a rough estimate, and is not guaranteed to
-be strictly below these values, so it's recommended to under set your config
-values to make sure you don't exceed the actual quota.
+You may need to set a client secret, see Configuration for more information.
+# Migration
+Migration from the official client was made to be as painless as possible. There
+are caveats though:
+- If you ever want to return to using the official client, you will need to
+  clear your cache.
+- As this is an unofficial client implementation, the only support you can
+  probably get is from me.
+Otherwise, the steps to migration is easy:
+1. Place the binary in the same folder as your `images` folder and
+   `settings.yaml`.
+2. Rename `images` to `cache`.
+# Client implementation
+This client follows a secure-first approach. As such, your statistics may report
+a _ever-so-slightly_ higher-than-average failure rate. Specifically, this client
+choses to:
+- Not support TLS 1.1 or 1.0, which would be a primary source of
+  incompatibility.
+- Not provide a server identification string in the header of served requests.
+- HTTPS by enabled by default, HTTP is provided (and unsupported).
+That being said, this client should be backwards compatibility with the official
+client data and config. That means you should be able to replace the binary and
+preserve all your settings and cache.
 ## Installation
 Either build it from source or run `cargo install mangadex-home`.
+## Running
+Run `mangadex-home`, and make sure the advertised port is open on your firewall.
+Do note that some configuration fields are required. See the next section for
+details.
 ## Configuration
-Most configuration options can be either provided on the command line, sourced
-from a `.env` file, or sourced directly from the environment. Do not that the
-client secret is an exception. You must provide the client secret from the
-environment or from the `.env` file, as providing client secrets in a shell is a
-operation security risk.
-The following options are required:
-- Client Secret
-- Memory cache quota
-- Disk cache quota
-- Advertised network speed
-The following are optional as a default value will be set for you:
-- Port
-- Disk cache path
-### Advanced configuration
-This implementation prefers to act more secure by default. As a result, some
-features that the official specification requires are not enabled by default.
-If you don't know why these features are disabled by default, then don't enable
-these, as they may generally weaken the security stance of the client for more
-compatibility.
-- Sending Server version string
+Most configuration options can be either provided on the command line or sourced
+from a file named `settings.yaml` from the directory you ran the command from,
+which will be created on first run.
+Note that the client secret (`CLIENT_SECRET`) is the only configuration option
+that can only can be provided from the environment, an `.env` file, or the
+`settings.yaml` file. In other words, you _cannot_ provide this value from the
+command line.
+## Special thanks
+This project could not have been completed without the assistance of the
+following:
+#### Development Assistance (Alphabetical Order)
+- carbotaniuman#6974
+- LFlair#1337
+- Plykiya#1738
+- Tristan 9#6752
+- The Rust Discord community
+#### Beta testers
+- NigelVH#7162
+---
+If using the geo IP logging feature, then this product includes GeoLite2 data
+created by MaxMind, available from https://www.maxmind.com.
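
The removed "Cache implementation" notes above describe a two-tier, byte-quota LRU: in-memory evictions spill to the on-disk tier, an on-disk hit promotes the entry back into memory, and a single large item may evict several smaller ones because the quotas count bytes rather than entries. A minimal, std-only sketch of that bookkeeping (illustrative only, with hypothetical names, not this repository's actual types):

```rust
use std::collections::VecDeque;

// Hypothetical model of a two-tier, byte-quota LRU.
// Front of each deque = least recently used; entries are (key, size in bytes).
struct TieredLru {
    mem_quota: u64,
    mem_used: u64,
    mem: VecDeque<(String, u64)>,
    disk: VecDeque<(String, u64)>,
}

impl TieredLru {
    fn get(&mut self, key: &str) -> bool {
        // Memory hit: just refresh recency.
        if let Some(i) = self.mem.iter().position(|(k, _)| k == key) {
            let entry = self.mem.remove(i).unwrap();
            self.mem.push_back(entry);
            return true;
        }
        // Disk hit: promote the entry into the in-memory tier.
        if let Some(i) = self.disk.iter().position(|(k, _)| k == key) {
            let entry = self.disk.remove(i).unwrap();
            self.insert_mem(entry);
            return true;
        }
        false
    }

    fn insert_mem(&mut self, entry: (String, u64)) {
        self.mem_used += entry.1;
        self.mem.push_back(entry);
        // Quotas are byte-based, so one large item can push out several
        // smaller ones; the quota is a soft target, not a hard ceiling.
        while self.mem_used > self.mem_quota {
            match self.mem.pop_front() {
                Some((key, size)) => {
                    self.mem_used -= size;
                    self.disk.push_back((key, size));
                }
                None => break,
            }
        }
    }
}

fn main() {
    let mut cache = TieredLru {
        mem_quota: 100,
        mem_used: 0,
        mem: VecDeque::new(),
        disk: VecDeque::new(),
    };
    cache.insert_mem(("a".into(), 60));
    cache.insert_mem(("b".into(), 60)); // "a" spills to the on-disk tier
    assert!(cache.get("a")); // disk hit promotes "a" back into memory
}
```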

build.rs

@@ -1,5 +1,4 @@
 use std::error::Error;
-use std::process::Command;
 use vergen::{vergen, Config, ShaKind};
@@ -9,17 +8,5 @@ fn main() -> Result<(), Box<dyn Error>> {
     *config.git_mut().sha_kind_mut() = ShaKind::Short;
     vergen(config)?;
-    // Initialize SQL stuff
-    let project_root = std::env::var("CARGO_MANIFEST_DIR").unwrap();
-    Command::new("mkdir")
-        .args(["cache", "--parents"])
-        .current_dir(&project_root)
-        .output()?;
-    Command::new("sqlite3")
-        .args(["cache/metadata.sqlite", include_str!("db_queries/init.sql")])
-        .current_dir(&project_root)
-        .output()?;
     Ok(())
 }

db_queries/insert_image.sql Normal file

@ -0,0 +1 @@
insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing

9
docker-compose.yml Normal file

@@ -0,0 +1,9 @@
version: "3.9"
services:
  mangadex-home:
    build: .
    ports:
      - "443:443"
    volumes:
      - ./cache:/cache
      - ./settings.yaml:/settings.yaml

settings.sample.yaml

@ -1,28 +1,41 @@
--- ---
# ⢸⣿⣿⣿⣿⠃⠄⢀⣴⡾⠃⠄⠄⠄⠄⠄⠈⠺⠟⠛⠛⠛⠛⠻⢿⣿⣿⣿⣿⣶⣤⡀⠄ # ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
# ⢸⣿⣿⣿⡟⢀⣴⣿⡿⠁⠄⠄⠄⠄⠄⠄⠄⠄⠄⠄⠄⠄⠄⠄⣸⣿⣿⣿⣿⣿⣿⣿⣷ #⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⢼⣈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠀
# ⢸⣿⣿⠟⣴⣿⡿⡟⡼⢹⣷⢲⡶⣖⣾⣶⢄⠄⠄⠄⠄⠄⢀⣼⣿⢿⣿⣿⣿⣿⣿⣿⣿ #⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⠀⠀⠀⠀⠀⣀⣀⣀⣀⡀⠀⠀⠀⢸⡧⣀⠀⣄⡠⠒⠉⠀⠀⠀⢀⡈⢑⢦⠀⠀⠀⠀⠀⠀⠀⠀
# ⢸⣿⢫⣾⣿⡟⣾⡸⢠⡿⢳⡿⠍⣼⣿⢏⣿⣷⢄⡀⠄⢠⣾⢻⣿⣸⣿⣿⣿⣿⣿⣿⣿ #⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⣴⡗⠒⠉⠉⠉⠉⠉⠀⠀⠀⠀⠈⠉⠉⠉⠑⠣⡀⠉⠚⡷⡆⠀⣀⣀⣀⠤⡺⠈⠫⠗⠢⡄⠀⠀⠀⠀⠀
# ⡿⣡⣿⣿⡟⡼⡁⠁⣰⠂⡾⠉⢨⣿⠃⣿⡿⠍⣾⣟⢤⣿⢇⣿⢇⣿⣿⢿⣿⣿⣿⣿⣿ #⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠔⠊⠁⢰⠃⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢰⠢⣧⠘⣎⠐⠠⠟⠋⠀⠀⠀⠀⢄⢸⠀⠀⠀⠀⠀
# ⣱⣿⣿⡟⡐⣰⣧⡷⣿⣴⣧⣤⣼⣯⢸⡿⠁⣰⠟⢀⣼⠏⣲⠏⢸⣿⡟⣿⣿⣿⣿⣿⣿ #⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠴⠋⠀⠀⠀⠀⡜⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⡆⡏⠀⠸⡄⢀⠀⠀⠪⠴⠒⠒⠚⠘⢤⠀⠀⠀⠀
# ⣿⣿⡟⠁⠄⠟⣁⠄⢡⣿⣿⣿⣿⣿⣿⣦⣼⢟⢀⡼⠃⡹⠃⡀⢸⡿⢸⣿⣿⣿⣿⣿⡟ #⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠔⠁⠀⠀⠀⠀⠀⠀⡇⠀⠀⢣⡀⢀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⢱⠃⠀⠀⡇⡇⠉⠖⡤⠤⠤⠤⠴⢒⠺⠀⠀⠀⠀
# ⣿⣿⠃⠄⢀⣾⠋⠓⢰⣿⣿⣿⣿⣿⣿⠿⣿⣿⣾⣅⢔⣕⡇⡇⡼⢁⣿⣿⣿⣿⣿⣿⢣ #⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠋⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠱⡼⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⣆⡎⠀⠀⢀⡇⠙⠦⣌⣹⣏⠉⠉⠁⣀⠠⣂⠀⠀⠀
# ⣿⡟⠄⠄⣾⣇⠷⣢⣿⣿⣿⣿⣿⣿⣿⣭⣀⡈⠙⢿⣿⣿⡇⡧⢁⣾⣿⣿⣿⣿⣿⢏⣾ #⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢧⠀⠀⠀⣇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⢻⣿⣿⣿⡟⠓⠤⣀⡀⠉⠓⠭⠭⠔⠒⠉⠈⡆⠀⠀
# ⣿⡇⠄⣼⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠟⢻⠇⠄⠄⢿⣿⡇⢡⣾⣿⣿⣿⣿⣿⣏⣼⣿ #⠀⠀⠀⠀⠀⠀⢀⣀⣀⣀⣀⣀⡸⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⣠⠴⢻⠀⠀⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣧⢸⡈⠉⠉⢣⠀⠀⠀⠉⠑⠢⢄⣀⣀⡤⠖⠋⠳⡀⠀
# ⣿⣷⢰⣿⣿⣾⣿⣿⣿⣿⣿⣿⣿⣿⣿⢰⣧⣀⡄⢀⠘⡿⣰⣿⣿⣿⣿⣿⣿⠟⣼⣿⣿ #⠀⠀⠀⠀⠀⡔⠁⠀⢇⠀⠀⠀⠈⠳⡀⠀⠀⢀⠂⠀⠀⠀⠀⢀⠃⢠⠏⠀⠀⠈⡆⢠⢷⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⣿⡎⡇⠀⡠⠈⡇⠀⠀⠀⠀⠀⠀⠈⣦⠃⠀⠀⠈⢦⠀
# ⢹⣿⢸⣿⣿⠟⠻⢿⣿⣿⣿⣿⣿⣿⣿⣶⣭⣉⣤⣿⢈⣼⣿⣿⣿⣿⣿⣿⠏⣾⣹⣿⣿ #⠀⠀⠀⠀⡰⠃⠀⠀⠘⡄⠀⠀⠀⠀⢱⠀⠀⡜⠀⠀⠀⠀⠀⢸⠐⡮⢀⠀⠀⠀⢱⡸⡄⢧⠀⠀⠀⡀⠀⠀⠀⢸⣇⡿⢳⡧⠊⠀⠀⡇⡇⠀⠆⠀⠀⠀⠀⢱⡀⡠⠂⠀⠈⡇
# ⢸⠇⡜⣿⡟⠄⠄⠄⠈⠙⣿⣿⣿⣿⣿⣿⣿⣿⠟⣱⣻⣿⣿⣿⣿⣿⠟⠁⢳⠃⣿⣿⣿ #⠀⠀⠀⢰⠁⠀⠀⡀⠀⠘⡄⠀⠀⠀⢸⠀⢠⠃⠀⠀⢀⠀⠀⣼⢠⠃⠀⠁⠢⢄⠀⠳⣇⠀⠣⡀⠀⢣⡀⠀⠀⢸⢹⡧⢺⠀⠀⠀⠀⡷⢹⠀⢠⠀⠀⠀⠀⠈⡏⠳⡄⠀⠀⢳
# ⠄⣰⡗⠹⣿⣄⠄⠄⠄⢀⣿⣿⣿⣿⣿⣿⠟⣅⣥⣿⣿⣿⣿⠿⠋⠄⠄⣾⡌⢠⣿⡿⠃ #⠀⠀⠀⢸⠀⠀⠀⠈⠢⢄⠈⣢⠔⠒⠙⠒⢼⠀⠀⠀⢸⠀⢀⠿⣸⠀⠀⠀⠀⠀⠉⠢⢌⠀⠀⠈⠉⠒⠯⠉⠒⠚⠚⣠⠾⠶⠿⠷⢶⣧⡈⢆⢸⠀⠀⠀⠀⠀⢣⠀⢱⠀⠀⡎
# ⠜⠋⢠⣷⢻⣿⣿⣶⣾⣿⣿⣿⣿⠿⣛⣥⣾⣿⠿⠟⠛⠉⠄⠄ #⠀⠀⠀⢸⠀⠀⠸⠀⠀⠀⢹⠁⠀⠀⠀⡀⡞⠀⠀⠀⢸⠀⢸⠀⢿⠀⣠⣴⠾⠛⠛⠓⠢⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠁⡀⠘⣿⣶⣄⠈⢻⣆⢻⠀⡆⠀⠀⠀⢸⠠⣸⠂⢠⠃
#⠀⠀⠀⠘⡄⠀⠀⢡⠀⠀⡼⠀⡠⠐⠁⠀⡇⠀⠀⠀⠈⡆⢸⠀⢨⡾⠋⠀⠀⢻⣿⣿⣷⣦⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⡿⢻⣿⣿⠻⣇⠀⢻⣾⠀⡇⠀⠀⠀⠈⡞⠁⣠⠋⠀
#⠀⠀⠀⠀⠱⡀⠀⠀⠑⢄⡑⣅⠀⠀⠀⠀⡇⠀⠀⠀⠀⠘⣼⠀⣿⠁⠀⢠⡷⢾⣿⣿⡟⠛⡇⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠈⠛⠁⠀⢸⠀⠈⢹⡸⠀⠀⠀⠀⠀⡧⠚⡇⠀⠀
#⠀⠀⠀⠀⠀⠈⠢⢄⣀⠀⠈⠉⢑⣶⠴⠒⡇⠀⠀⠀⠀⠀⡟⠧⡇⠀⠀⠸⡁⠀⠙⠋⠀⠀⡞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⠦⣤⣤⡴⠃⠀⠀⠼⠣⡀⠀⡇⠀⠀⡷⢄⣇⠀⠀
#⠀⠀⠀⠀⠀⠀⢀⠞⠀⡏⠉⢉⣵⣳⠀⠀⡇⠀⠀⠀⠀⠀⢱⠀⠁⠀⠀⠀⠑⠤⠤⡠⠤⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠠⠡⢁⠀⠀⢱⠀⡇⠀⢠⡇⠀⢻⡀⠀
#⠀⠀⠀⠀⠀⢠⠎⠀⢸⣇⡔⠉⠀⢹⡀⠀⡇⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠐⡀⢀⠀⠄⡀⠀⠀⠀⠀⠀⠀⢀⣀⣀⣀⣀⣀⣀⣤⠈⠠⠡⠁⠂⠌⠀⢸⠀⡗⠀⢸⠇⠀⢀⡇⠀
#⠀⠀⠀⠀⢠⠃⠀⠀⡎⡏⠀⠀⠀⠀⡇⠀⡇⠀⡆⠀⠀⠀⠘⡄⠀⠀⠈⠌⠠⠂⠌⠐⠀⢀⠎⠉⠒⠉⠉⠉⠉⠙⠛⠧⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⡼⢀⠇⠀⢸⣀⠴⠋⢱⠀
#⠀⠀⠀⢠⠃⠀⠀⢰⠙⢣⡀⠀⠀⣇⡇⠀⢧⠀⡇⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⡿⠀⠀⠀⠀⠀⢀⡼⠃⡜⠀⠀⡏⢱⠀⠐⠈⡇
#⠀⠀⢠⢃⠀⠀⢠⠇⠀⢸⡉⠓⠶⢿⠃⠀⢸⠀⡇⠀⠀⠀⡄⢹⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡴⠁⠀⠀⠀⢀⣴⡟⠁⡰⠁⠀⢰⢧⠈⡆⠀⠇⢇
#⠀⠀⡜⡄⠀⢀⡎⠀⠀⠀⡇⠀⠀⢸⠀⠀⠈⡇⣿⠀⠀⠀⢧⠈⡗⠢⢤⣀⠀⠀⠀⠀⠀⠀⠙⢄⡀⠀⠀⠀⠀⠀⣀⡤⠚⠁⢀⣠⡤⠞⠋⢀⠇⡴⠁⠀⠀⠾⣼⠀⢱⠀⢸⢸
#⠀⠀⡇⡇⠀⡜⠀⠀⠀⠀⡇⠀⠀⣾⠀⠀⠀⢹⡏⡆⠀⠀⢸⢆⠸⡄⠀⠀⠉⢑⣦⣤⡀⠀⠀⠀⠉⠑⠒⣒⣋⣉⣡⣤⠒⠊⠉⡇⠀⠀⠀⣾⣊⠀⠀⠀⠈⢠⢻⠀⢸⣀⣿⡜
#⠀⠀⣷⢇⢸⠁⠀⠀⠀⠀⡇⠀⢰⢹⠀⠀⠀⠀⢿⠹⡀⠀⠸⡀⠳⣵⡀⡠⠚⠉⠙⢿⣿⣷⣦⣀⠀⠀⠀⣱⣿⣿⠀⠈⠉⠲⣄⢧⣠⠒⢌⡇⡠⣃⣀⡠⠔⠁⠀⡇⢸⡟⢸⠇
#⠀⠀⢻⠘⣼⠀⠀⠀⠀⢰⠁⣠⠃⢸⠀⠀⠀⠀⠘⠀⠳⡀⠀⡇⠀⢀⠟⠦⠀⡀⠀⢸⣛⣻⣿⣿⣿⣶⣭⣿⣿⣻⡆⠀⠀⠀⠈⢦⠸⣽⢝⠿⡫⡁⢸⡇⠀⠀⠀⢣⠘⠁⠘⠀
#⠀⠀⠘⠆⠸⠄⠀⠀⢠⠏⡰⠁⠀⡞⠀⠀⠀⠀⠀⠀⠀⠙⢄⣸⣶⣷⣶⣶⣶⣤⣤⣼⣿⣽⣯⣿⣿⣿⣷⣾⣿⣿⣿⣾⣤⣴⣶⣾⣷⣇⠺⠤⠕⠈⢉⠇⠀⠀⠀⠘⡄
# #
# MangaDex@Home configuration file # MangaDex@Home configuration file
# We are pleased to have you here #
# May fate stay the night with you! # Thanks for contributing to MangaDex@Home, friend!
# Beat up a pineapple, and don't forget your AsaCoco!
# #
# Default values are commented out. # Default values are commented out.
# The size in mebibytes of the cache # The size in mebibytes of the cache You can use megabytes instead in a pinch,
# You can use megabytes instead in a pinch,
# but just know the two are **NOT** the same. # but just know the two are **NOT** the same.
max_cache_size_in_mebibytes: 0 max_cache_size_in_mebibytes: 0
@ -34,7 +47,7 @@ server_settings:
# port: 443 # port: 443
# This controls the value the server receives for your upload speed. # This controls the value the server receives for your upload speed.
external_max_kilobits_per_second: 0 external_max_kilobits_per_second: 1
# #
# Advanced settings # Advanced settings
@ -55,3 +68,47 @@ server_settings:
# the backend will infer it from where it was sent from, which may fail in the # the backend will infer it from where it was sent from, which may fail in the
# presence of multiple IPs. # presence of multiple IPs.
# external_ip: ~ # external_ip: ~
# Settings for geo IP analytics
metric_settings:
# Whether to enable geo IP analytics
# enable_geoip: false
# geoip_license_key: none
# These settings are unique to the Rust client, and may be ignored or behave
# differently from the official client.
extended_options:
# Which cache type to use. By default, this is `on_disk`, but one can select
# `lfu`, `lru`, or `redis` to use a LFU, LRU, or redis instance in addition
# to the on-disk cache to improve lookup times. Generally speaking, using one
# is almost always better, but by how much depends on how much memory you let
# the node use, how large is your node, and which caching implementation you
# use.
# cache_type: on_disk
# The redis url to connect with. Does nothing if the cache type isn't`redis`.
# redis_url: "redis://127.0.0.1/"
# The amount of memory the client should use when using an in-memory cache.
# This does nothing if only the on-disk cache is used.
# memory_quota: 0
# Whether or not to expose a prometheus endpoint at /metrics. This is a
# completely open endpoint, so best practice is to make sure this is only
# readable from the internal network.
# enable_metrics: false
# If you'd like to specify a different path location for the cache, you can do
# so here.
# cache_path: "./cache"
# What logging level to use. Valid options are "error", "warn", "info",
# "debug", "trace", and "off", which disables logging.
# logging_level: info
# Enables disk encryption where the key is stored in memory. In other words,
# when the MD@H program is stopped, all cached files are irrecoverable.
# Practically speaking, this isn't all too useful (and definitely hurts
# performance), but for peace of mind, this may be useful.
# ephemeral_disk_encryption: false

75
sqlx-data.json Normal file

@ -0,0 +1,75 @@
{
"db": "SQLite",
"24b536161a0ed44d0595052ad069c023631ffcdeadb15a01ee294717f87cdd42": {
"query": "update Images set accessed = ? where id = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 2
},
"nullable": []
}
},
"2a8aa6dd2c59241a451cd73f23547d0e94930e35654692839b5d11bb8b87703e": {
"query": "insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing",
"describe": {
"columns": [],
"parameters": {
"Right": 3
},
"nullable": []
}
},
"311721fec7824c2fc3ecf53f714949a49245c11a6b622efdb04fdac24be41ba3": {
"query": "SELECT IFNULL(SUM(size), 0) AS size FROM Images",
"describe": {
"columns": [
{
"name": "size",
"ordinal": 0,
"type_info": "Int"
}
],
"parameters": {
"Right": 0
},
"nullable": [
true
]
}
},
"44234188e873a467ecf2c60dfb4731011e0b7afc4472339ed2ae33aee8b0c9dd": {
"query": "select id, size from Images order by accessed asc limit 1000",
"describe": {
"columns": [
{
"name": "id",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "size",
"ordinal": 1,
"type_info": "Int64"
}
],
"parameters": {
"Right": 0
},
"nullable": [
false,
false
]
}
},
"a60501a30fd75b2a2a59f089e850343af075436a5c543a267ecb4fa841593ce9": {
"query": "create table if not exists Images(\n id varchar primary key not null,\n size integer not null,\n accessed timestamp not null default CURRENT_TIMESTAMP\n);\ncreate index if not exists Images_accessed on Images(accessed);",
"describe": {
"columns": [],
"parameters": {
"Right": 0
},
"nullable": []
}
}
}

152
src/cache/compat.rs vendored Normal file

@ -0,0 +1,152 @@
//! These structs have alternative deserialize and serializations
//! implementations to assist reading from the official client file format.
use std::str::FromStr;
use chrono::{DateTime, FixedOffset};
use serde::de::{Unexpected, Visitor};
use serde::{Deserialize, Serialize};
use super::ImageContentType;
#[derive(Copy, Clone, Deserialize)]
pub struct LegacyImageMetadata {
pub(crate) content_type: Option<LegacyImageContentType>,
pub(crate) size: Option<u32>,
pub(crate) last_modified: Option<LegacyDateTime>,
}
#[derive(Copy, Clone, Serialize)]
pub struct LegacyDateTime(pub DateTime<FixedOffset>);
impl<'de> Deserialize<'de> for LegacyDateTime {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct LegacyDateTimeVisitor;
impl<'de> Visitor<'de> for LegacyDateTimeVisitor {
type Value = LegacyDateTime;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a valid image type")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
DateTime::parse_from_rfc2822(v)
.map(LegacyDateTime)
.map_err(|_| E::invalid_value(Unexpected::Str(v), &"a valid image type"))
}
}
deserializer.deserialize_str(LegacyDateTimeVisitor)
}
}
#[derive(Copy, Clone)]
pub struct LegacyImageContentType(pub ImageContentType);
impl<'de> Deserialize<'de> for LegacyImageContentType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct LegacyImageContentTypeVisitor;
impl<'de> Visitor<'de> for LegacyImageContentTypeVisitor {
type Value = LegacyImageContentType;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a valid image type")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
ImageContentType::from_str(v)
.map(LegacyImageContentType)
.map_err(|_| E::invalid_value(Unexpected::Str(v), &"a valid image type"))
}
}
deserializer.deserialize_str(LegacyImageContentTypeVisitor)
}
}
#[cfg(test)]
mod parse {
use std::error::Error;
use chrono::DateTime;
use crate::cache::ImageContentType;
use super::LegacyImageMetadata;
#[test]
fn from_valid_legacy_format() -> Result<(), Box<dyn Error>> {
let legacy_header = r#"{"content_type":"image/jpeg","last_modified":"Sat, 10 Apr 2021 10:55:22 GMT","size":117888}"#;
let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
assert_eq!(
metadata.content_type.map(|v| v.0),
Some(ImageContentType::Jpeg)
);
assert_eq!(metadata.size, Some(117_888));
assert_eq!(
metadata.last_modified.map(|v| v.0),
Some(DateTime::parse_from_rfc2822(
"Sat, 10 Apr 2021 10:55:22 GMT"
)?)
);
Ok(())
}
#[test]
fn empty_metadata() -> Result<(), Box<dyn Error>> {
let legacy_header = "{}";
let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
assert!(metadata.content_type.is_none());
assert!(metadata.size.is_none());
assert!(metadata.last_modified.is_none());
Ok(())
}
#[test]
fn invalid_image_mime_value() {
let legacy_header = r#"{"content_type":"image/not-a-real-image"}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn invalid_date_time() {
let legacy_header = r#"{"last_modified":"idk last tuesday?"}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn invalid_size() {
let legacy_header = r#"{"size":-1}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn wrong_image_type() {
let legacy_header = r#"{"content_type":25}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn wrong_date_time_type() {
let legacy_header = r#"{"last_modified":false}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
}

590
src/cache/disk.rs vendored

@ -1,31 +1,39 @@
//! Low memory caching stuff //! Low memory caching stuff
use std::path::PathBuf; use std::convert::TryFrom;
use std::hint::unreachable_unchecked;
use std::os::unix::prelude::OsStrExt;
use std::path::{Path, PathBuf};
use std::str::FromStr; use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use log::{error, warn, LevelFilter}; use log::LevelFilter;
use md5::digest::generic_array::GenericArray;
use md5::{Digest, Md5};
use sodiumoxide::hex;
use sqlx::sqlite::SqliteConnectOptions; use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, SqlitePool}; use sqlx::{ConnectOptions, Sqlite, SqlitePool, Transaction};
use tokio::fs::remove_file; use tokio::fs::{create_dir_all, remove_file, rename, File};
use tokio::join;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, instrument, warn};
use crate::units::Bytes; use crate::units::Bytes;
use super::{ use super::{Cache, CacheEntry, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata};
BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
};
#[derive(Debug)]
pub struct DiskCache { pub struct DiskCache {
disk_path: PathBuf, disk_path: PathBuf,
disk_cur_size: AtomicU64, disk_cur_size: AtomicU64,
db_update_channel_sender: Sender<DbMessage>, db_update_channel_sender: Sender<DbMessage>,
} }
#[derive(Debug)]
enum DbMessage { enum DbMessage {
Get(Arc<PathBuf>), Get(Arc<PathBuf>),
Put(Arc<PathBuf>, u64), Put(Arc<PathBuf>, u64),
@ -36,28 +44,46 @@ impl DiskCache {
/// This internally spawns a task that will wait for filesystem /// This internally spawns a task that will wait for filesystem
/// notifications when a file has been written. /// notifications when a file has been written.
pub async fn new(disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> { pub async fn new(disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
let (db_tx, db_rx) = channel(128); if let Err(e) = create_dir_all(&disk_path).await {
error!("Failed to create cache folder: {}", e);
}
let cache_path = disk_path.to_string_lossy();
// Migrate old to new path
if rename(
format!("{}/metadata.sqlite", cache_path),
format!("{}/metadata.db", cache_path),
)
.await
.is_ok()
{
info!("Found old metadata file, migrating to new location.");
}
let db_pool = { let db_pool = {
let db_url = format!("sqlite:{}/metadata.sqlite", disk_path.to_string_lossy()); let db_url = format!("sqlite:{}/metadata.db", cache_path);
let mut options = SqliteConnectOptions::from_str(&db_url) let mut options = SqliteConnectOptions::from_str(&db_url)
.unwrap() .unwrap()
.create_if_missing(true); .create_if_missing(true);
options.log_statements(LevelFilter::Trace); options.log_statements(LevelFilter::Trace);
let db = SqlitePool::connect_with(options).await.unwrap(); SqlitePool::connect_with(options).await.unwrap()
// Run db init
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut db.acquire().await.unwrap())
.await
.unwrap();
db
}; };
Self::from_db_pool(db_pool, disk_max_size, disk_path).await
}
async fn from_db_pool(pool: SqlitePool, disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
let (db_tx, db_rx) = channel(128);
// Run db init
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut pool.acquire().await.unwrap())
.await
.unwrap();
// This is intentional. // This is intentional.
#[allow(clippy::cast_sign_loss)] #[allow(clippy::cast_sign_loss)]
let disk_cur_size = { let disk_cur_size = {
let mut conn = db_pool.acquire().await.unwrap(); let mut conn = pool.acquire().await.unwrap();
sqlx::query!("SELECT IFNULL(SUM(size), 0) AS size FROM Images") sqlx::query!("SELECT IFNULL(SUM(size), 0) AS size FROM Images")
.fetch_one(&mut conn) .fetch_one(&mut conn)
.await .await
@ -75,12 +101,25 @@ impl DiskCache {
tokio::spawn(db_listener( tokio::spawn(db_listener(
Arc::clone(&new_self), Arc::clone(&new_self),
db_rx, db_rx,
db_pool, pool,
disk_max_size.get() as u64 / 20 * 19, disk_max_size.get() as u64 / 20 * 19,
)); ));
new_self new_self
} }
#[cfg(test)]
fn in_memory() -> (Self, Receiver<DbMessage>) {
let (db_tx, db_rx) = channel(128);
(
Self {
disk_path: PathBuf::new(),
disk_cur_size: AtomicU64::new(0),
db_update_channel_sender: db_tx,
},
db_rx,
)
}
} }
/// Spawn a new task that will listen for updates to the db, pruning if the size /// Spawn a new task that will listen for updates to the db, pruning if the size
@ -91,9 +130,10 @@ async fn db_listener(
db_pool: SqlitePool, db_pool: SqlitePool,
max_on_disk_size: u64, max_on_disk_size: u64,
) { ) {
// This is in a receiver stream to process up to 128 simultaneous db updates
// in one transaction
let mut recv_stream = ReceiverStream::new(db_rx).ready_chunks(128); let mut recv_stream = ReceiverStream::new(db_rx).ready_chunks(128);
while let Some(messages) = recv_stream.next().await { while let Some(messages) = recv_stream.next().await {
let now = chrono::Utc::now();
let mut transaction = match db_pool.begin().await { let mut transaction = match db_pool.begin().await {
Ok(transaction) => transaction, Ok(transaction) => transaction,
Err(e) => { Err(e) => {
@ -101,38 +141,12 @@ async fn db_listener(
continue; continue;
} }
}; };
for message in messages { for message in messages {
match message { match message {
DbMessage::Get(entry) => { DbMessage::Get(entry) => handle_db_get(&entry, &mut transaction).await,
let key = entry.as_os_str().to_str();
let query =
sqlx::query!("update Images set accessed = ? where id = ?", now, key)
.execute(&mut transaction)
.await;
if let Err(e) = query {
warn!("Failed to update timestamp in db for {:?}: {}", key, e);
}
}
DbMessage::Put(entry, size) => { DbMessage::Put(entry, size) => {
let key = entry.as_os_str().to_str(); handle_db_put(&entry, size, &cache, &mut transaction).await;
{
// This is intentional.
#[allow(clippy::cast_possible_wrap)]
let size = size as i64;
let query = sqlx::query!(
"insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing",
key,
size,
now,
)
.execute(&mut transaction)
.await;
if let Err(e) = query {
warn!("Failed to add {:?} to db: {}", key, e);
}
}
cache.disk_cur_size.fetch_add(size, Ordering::Release);
} }
} }
} }
@ -146,21 +160,10 @@ async fn db_listener(
let on_disk_size = (cache.disk_cur_size.load(Ordering::Acquire) + 4095) / 4096 * 4096; let on_disk_size = (cache.disk_cur_size.load(Ordering::Acquire) + 4095) / 4096 * 4096;
if on_disk_size >= max_on_disk_size { if on_disk_size >= max_on_disk_size {
let mut conn = match db_pool.acquire().await {
Ok(conn) => conn,
Err(e) => {
error!(
"Failed to get a DB connection and cannot prune disk cache: {}",
e
);
continue;
}
};
let items = { let items = {
let request = let request =
sqlx::query!("select id, size from Images order by accessed asc limit 1000") sqlx::query!("select id, size from Images order by accessed asc limit 1000")
.fetch_all(&mut conn) .fetch_all(&db_pool)
.await; .await;
match request { match request {
Ok(items) => items, Ok(items) => items,
@ -177,8 +180,9 @@ async fn db_listener(
let mut size_freed = 0; let mut size_freed = 0;
#[allow(clippy::cast_sign_loss)] #[allow(clippy::cast_sign_loss)]
for item in items { for item in items {
debug!("deleting file due to exceeding cache size");
size_freed += item.size as u64; size_freed += item.size as u64;
tokio::spawn(remove_file(item.id)); tokio::spawn(remove_file_handler(item.id));
} }
cache.disk_cur_size.fetch_sub(size_freed, Ordering::Release); cache.disk_cur_size.fetch_sub(size_freed, Ordering::Release);
@ -186,6 +190,126 @@ async fn db_listener(
} }
} }
/// Returns if a file was successfully deleted.
async fn remove_file_handler(key: String) -> bool {
let error = if let Err(e) = remove_file(&key).await {
e
} else {
return true;
};
if error.kind() != std::io::ErrorKind::NotFound {
warn!("Failed to delete file `{}` from cache: {}", &key, error);
return false;
}
if let Ok(bytes) = hex::decode(&key) {
if bytes.len() != 16 {
warn!("Failed to delete file `{}`; invalid hash size.", &key);
return false;
}
let hash = Md5Hash(*GenericArray::from_slice(&bytes));
let path: PathBuf = hash.into();
if let Err(e) = remove_file(&path).await {
warn!(
"Failed to delete file `{}` from cache: {}",
path.to_string_lossy(),
e
);
false
} else {
true
}
} else {
warn!("Failed to delete file `{}`; not a md5hash.", &key);
false
}
}
#[instrument(level = "debug", skip(transaction))]
async fn handle_db_get(entry: &Path, transaction: &mut Transaction<'_, Sqlite>) {
let key = entry.as_os_str().to_str();
let now = chrono::Utc::now();
let query = sqlx::query!("update Images set accessed = ? where id = ?", now, key)
.execute(transaction)
.await;
if let Err(e) = query {
warn!("Failed to update timestamp in db for {:?}: {}", key, e);
}
}
#[instrument(level = "debug", skip(transaction, cache))]
async fn handle_db_put(
entry: &Path,
size: u64,
cache: &DiskCache,
transaction: &mut Transaction<'_, Sqlite>,
) {
let key = entry.as_os_str().to_str();
let now = chrono::Utc::now();
// This is intentional.
#[allow(clippy::cast_possible_wrap)]
let casted_size = size as i64;
let query = sqlx::query_file!("./db_queries/insert_image.sql", key, casted_size, now)
.execute(transaction)
.await;
if let Err(e) = query {
warn!("Failed to add to db: {}", e);
}
cache.disk_cur_size.fetch_add(size, Ordering::Release);
}
/// Represents a Md5 hash that can be converted to and from a path. This is used
/// for compatibility with the official client, where the image id and on-disk
/// path is determined by file path.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Md5Hash(GenericArray<u8, <Md5 as md5::Digest>::OutputSize>);
impl Md5Hash {
fn to_hex_string(self) -> String {
format!("{:x}", self.0)
}
}
impl TryFrom<&Path> for Md5Hash {
type Error = ();
fn try_from(path: &Path) -> Result<Self, Self::Error> {
let mut iter = path.iter();
let file_name = iter.next_back().ok_or(())?;
let chapter_hash = iter.next_back().ok_or(())?;
let is_data_saver = iter.next_back().ok_or(())? == "saver";
let mut hasher = Md5::new();
if is_data_saver {
hasher.update("saver");
}
hasher.update(chapter_hash.as_bytes());
hasher.update(".");
hasher.update(file_name.as_bytes());
Ok(Self(hasher.finalize()))
}
}
impl From<Md5Hash> for PathBuf {
fn from(hash: Md5Hash) -> Self {
let hex_value = hash.to_hex_string();
let path = hex_value[0..3]
.chars()
.rev()
.map(|char| Self::from(char.to_string()))
.reduce(|first, second| first.join(second));
match path {
Some(p) => p.join(hex_value),
None => unsafe { unreachable_unchecked() }, // literally not possible
}
}
}
#[async_trait] #[async_trait]
impl Cache for DiskCache { impl Cache for DiskCache {
async fn get( async fn get(
@ -197,12 +321,33 @@ impl Cache for DiskCache {
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(key))); let path = Arc::new(self.disk_path.clone().join(PathBuf::from(key)));
let path_0 = Arc::clone(&path); let path_0 = Arc::clone(&path);
tokio::spawn(async move { channel.send(DbMessage::Get(path_0)).await }); let legacy_path = Md5Hash::try_from(path_0.as_path())
.map(PathBuf::from)
.map(|path| self.disk_path.clone().join(path))
.map(Arc::new);
super::fs::read_file(&path).await.map(|res| { // Get file and path of first existing location path
let (inner, maybe_header, metadata) = res?; let (file, path) = if let Ok(legacy_path) = legacy_path {
CacheStream::new(inner, maybe_header) let maybe_files = join!(
.map(|stream| (stream, metadata)) File::open(legacy_path.as_path()),
File::open(path.as_path()),
);
match maybe_files {
(Ok(f), _) => Some((f, legacy_path)),
(_, Ok(f)) => Some((f, path)),
_ => return None,
}
} else {
File::open(path.as_path())
.await
.ok()
.map(|file| (file, path))
}?;
tokio::spawn(async move { channel.send(DbMessage::Get(path)).await });
super::fs::read_file(file).await.map(|res| {
res.map(|(stream, _, metadata)| (stream, metadata))
.map_err(|_| CacheError::DecryptionFailure) .map_err(|_| CacheError::DecryptionFailure)
}) })
} }
@ -210,9 +355,9 @@ impl Cache for DiskCache {
async fn put( async fn put(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: bytes::Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
) -> Result<CacheStream, CacheError> { ) -> Result<(), CacheError> {
let channel = self.db_update_channel_sender.clone(); let channel = self.db_update_channel_sender.clone();
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key))); let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@ -225,9 +370,6 @@ impl Cache for DiskCache {
super::fs::write_file(&path, key, image, metadata, db_callback, None) super::fs::write_file(&path, key, image, metadata, db_callback, None)
.await .await
.map_err(CacheError::from) .map_err(CacheError::from)
.and_then(|(inner, maybe_header)| {
CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
})
} }
} }
@ -236,10 +378,10 @@ impl CallbackCache for DiskCache {
async fn put_with_on_completed_callback( async fn put_with_on_completed_callback(
&self, &self,
key: CacheKey, key: CacheKey,
image: BoxedImageStream, image: bytes::Bytes,
metadata: ImageMetadata, metadata: ImageMetadata,
on_complete: Sender<(CacheKey, bytes::Bytes, ImageMetadata, u64)>, on_complete: Sender<CacheEntry>,
) -> Result<CacheStream, CacheError> { ) -> Result<(), CacheError> {
let channel = self.db_update_channel_sender.clone(); let channel = self.db_update_channel_sender.clone();
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key))); let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@ -253,8 +395,298 @@ impl CallbackCache for DiskCache {
super::fs::write_file(&path, key, image, metadata, db_callback, Some(on_complete)) super::fs::write_file(&path, key, image, metadata, db_callback, Some(on_complete))
.await .await
.map_err(CacheError::from) .map_err(CacheError::from)
.and_then(|(inner, maybe_header)| { }
CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure) }
})
#[cfg(test)]
mod db_listener {
use super::{db_listener, DbMessage};
use crate::DiskCache;
use futures::TryStreamExt;
use sqlx::{Row, SqlitePool};
use std::error::Error;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc::channel;
#[tokio::test]
async fn can_handle_multiple_events() -> Result<(), Box<dyn Error>> {
let (mut cache, rx) = DiskCache::in_memory();
let (mut tx, _) = channel(1);
// Swap the tx with the new one, else the receiver will never end
std::mem::swap(&mut cache.db_update_channel_sender, &mut tx);
assert_eq!(tx.capacity(), 128);
let cache = Arc::new(cache);
let db = SqlitePool::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&db)
.await?;
// Populate the queue with messages
for c in 'a'..='z' {
tx.send(DbMessage::Put(Arc::new(PathBuf::from(c.to_string())), 10))
.await?;
tx.send(DbMessage::Get(Arc::new(PathBuf::from(c.to_string()))))
.await?;
}
// Explicitly close the channel so that the listener terminates
std::mem::drop(tx);
db_listener(cache, rx, db.clone(), u64::MAX).await;
let count = Arc::new(AtomicUsize::new(0));
sqlx::query("select * from Images")
.fetch(&db)
.try_for_each_concurrent(None, |row| {
let count = Arc::clone(&count);
async move {
assert_eq!(row.get::<i32, _>("size"), 10);
count.fetch_add(1, Ordering::Release);
Ok(())
}
})
.await?;
assert_eq!(count.load(Ordering::Acquire), 26);
Ok(())
}
}
#[cfg(test)]
mod remove_file_handler {
use std::error::Error;
use tempfile::tempdir;
use tokio::fs::{create_dir_all, remove_dir_all};
use super::{remove_file_handler, File};
#[tokio::test]
async fn should_not_panic_on_invalid_path() {
assert!(!remove_file_handler("/this/is/a/non-existent/path/".to_string()).await);
}
#[tokio::test]
async fn should_not_panic_on_invalid_hash() {
assert!(!remove_file_handler("68b329da9893e34099c7d8ad5cb9c940".to_string()).await);
}
#[tokio::test]
async fn should_not_panic_on_malicious_hashes() {
assert!(!remove_file_handler("68b329da9893e34".to_string()).await);
assert!(
!remove_file_handler("68b329da9893e34099c7d8ad5cb9c940aaaaaaaaaaaaaaaaaa".to_string())
.await
);
}
#[tokio::test]
async fn should_delete_existing_file() -> Result<(), Box<dyn Error>> {
let temp_dir = tempdir()?;
let mut dir_path = temp_dir.path().to_path_buf();
dir_path.push("abc123.png");
// create a file, it can be empty
File::create(&dir_path).await?;
assert!(remove_file_handler(dir_path.to_string_lossy().into_owned()).await);
Ok(())
}
#[tokio::test]
async fn should_delete_existing_hash() -> Result<(), Box<dyn Error>> {
create_dir_all("b/8/6").await?;
File::create("b/8/6/68b329da9893e34099c7d8ad5cb9c900").await?;
assert!(remove_file_handler("68b329da9893e34099c7d8ad5cb9c900".to_string()).await);
remove_dir_all("b").await?;
Ok(())
}
}
#[cfg(test)]
mod disk_cache {
use std::error::Error;
use std::path::PathBuf;
use std::sync::atomic::Ordering;
use chrono::Utc;
use sqlx::SqlitePool;
use crate::units::Bytes;
use super::DiskCache;
#[tokio::test]
async fn db_is_initialized() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
let _cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
let res = sqlx::query("select * from Images").execute(&conn).await;
assert!(res.is_ok());
Ok(())
}
#[tokio::test]
async fn db_initializes_empty() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 0);
Ok(())
}
#[tokio::test]
async fn db_can_load_from_existing() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&conn)
.await?;
let now = Utc::now();
sqlx::query_file!("./db_queries/insert_image.sql", "a", 4, now)
.execute(&conn)
.await?;
let now = Utc::now();
sqlx::query_file!("./db_queries/insert_image.sql", "b", 15, now)
.execute(&conn)
.await?;
let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 19);
Ok(())
}
}
#[cfg(test)]
mod db {
use chrono::{DateTime, Utc};
use sqlx::{Connection, Row, SqliteConnection};
use std::error::Error;
use super::{handle_db_get, handle_db_put, DiskCache, FromStr, Ordering, PathBuf, StreamExt};
#[tokio::test]
#[cfg_attr(miri, ignore)]
async fn get() -> Result<(), Box<dyn Error>> {
let (cache, _) = DiskCache::in_memory();
let path = PathBuf::from_str("a/b/c")?;
let mut conn = SqliteConnection::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut conn)
.await?;
// Add an entry
let mut transaction = conn.begin().await?;
handle_db_put(&path, 10, &cache, &mut transaction).await;
transaction.commit().await?;
let time_fence = Utc::now();
let mut transaction = conn.begin().await?;
handle_db_get(&path, &mut transaction).await;
transaction.commit().await?;
let mut rows: Vec<_> = sqlx::query("select * from Images")
.fetch(&mut conn)
.collect()
.await;
assert_eq!(rows.len(), 1);
let entry = rows.pop().unwrap()?;
assert!(time_fence < entry.get::<'_, DateTime<Utc>, _>("accessed"));
Ok(())
}
#[tokio::test]
#[cfg_attr(miri, ignore)]
async fn put() -> Result<(), Box<dyn Error>> {
let (cache, _) = DiskCache::in_memory();
let path = PathBuf::from_str("a/b/c")?;
let mut conn = SqliteConnection::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut conn)
.await?;
let mut transaction = conn.begin().await?;
let transaction_time = Utc::now();
handle_db_put(&path, 10, &cache, &mut transaction).await;
transaction.commit().await?;
let mut rows: Vec<_> = sqlx::query("select * from Images")
.fetch(&mut conn)
.collect()
.await;
assert_eq!(rows.len(), 1);
let entry = rows.pop().unwrap()?;
assert_eq!(entry.get::<'_, &str, _>("id"), "a/b/c");
assert_eq!(entry.get::<'_, i64, _>("size"), 10);
let accessed: DateTime<Utc> = entry.get("accessed");
assert!(transaction_time < accessed);
assert!(accessed < Utc::now());
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 10);
Ok(())
}
}
#[cfg(test)]
mod md5_hash {
use super::{Digest, GenericArray, Md5, Md5Hash, Path, PathBuf, TryFrom};
#[test]
fn to_cache_path() {
let hash = Md5Hash(
*GenericArray::<_, <Md5 as md5::Digest>::OutputSize>::from_slice(&[
0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd,
0xef, 0xab,
]),
);
assert_eq!(
PathBuf::from(hash).to_str(),
Some("c/b/a/abcdefabcdefabcdefabcdefabcdefab")
)
}
#[test]
fn from_data_path() {
let mut expected_hasher = Md5::new();
expected_hasher.update("foo.bar.png");
assert_eq!(
Md5Hash::try_from(Path::new("data/foo/bar.png")),
Ok(Md5Hash(expected_hasher.finalize()))
);
}
#[test]
fn from_data_saver_path() {
let mut expected_hasher = Md5::new();
expected_hasher.update("saverfoo.bar.png");
assert_eq!(
Md5Hash::try_from(Path::new("saver/foo/bar.png")),
Ok(Md5Hash(expected_hasher.finalize()))
);
}
#[test]
fn can_handle_long_paths() {
assert_eq!(
Md5Hash::try_from(Path::new("a/b/c/d/e/f/g/saver/foo/bar.png")),
Md5Hash::try_from(Path::new("saver/foo/bar.png")),
);
}
#[test]
fn from_invalid_paths() {
assert!(Md5Hash::try_from(Path::new("foo/bar.png")).is_err());
assert!(Md5Hash::try_from(Path::new("bar.png")).is_err());
assert!(Md5Hash::try_from(Path::new("")).is_err());
} }
} }

886
src/cache/fs.rs vendored

File diff suppressed because it is too large

692
src/cache/mem.rs vendored

@ -1,19 +1,74 @@
use std::borrow::Cow;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use super::{ use super::{Cache, CacheEntry, CacheKey, CacheStream, CallbackCache, ImageMetadata, MemStream};
BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
InnerStream, MemStream,
};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use futures::FutureExt; use futures::FutureExt;
use lfu_cache::LfuCache; use lfu_cache::LfuCache;
use lru::LruCache; use lru::LruCache;
use tokio::sync::mpsc::{channel, Sender}; use redis::{
Client as RedisClient, Commands, FromRedisValue, RedisError, RedisResult, ToRedisArgs,
};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::warn;
type CacheValue = (Bytes, ImageMetadata, u64); #[derive(Clone, Serialize, Deserialize)]
pub struct CacheValue {
data: Bytes,
metadata: ImageMetadata,
on_disk_size: u64,
}
impl CacheValue {
#[inline]
fn new(data: Bytes, metadata: ImageMetadata, on_disk_size: u64) -> Self {
Self {
data,
metadata,
on_disk_size,
}
}
}
impl FromRedisValue for CacheValue {
fn from_redis_value(v: &redis::Value) -> RedisResult<Self> {
use bincode::ErrorKind;
if let redis::Value::Data(data) = v {
bincode::deserialize(data).map_err(|err| match *err {
ErrorKind::Io(e) => RedisError::from(e),
ErrorKind::Custom(e) => RedisError::from((
redis::ErrorKind::ResponseError,
"bincode deserialize failed",
e,
)),
e => RedisError::from((
redis::ErrorKind::ResponseError,
"bincode deserialized failed",
e.to_string(),
)),
})
} else {
Err(RedisError::from((
redis::ErrorKind::TypeError,
"Got non data type from redis db",
)))
}
}
}
impl ToRedisArgs for CacheValue {
fn write_redis_args<W>(&self, out: &mut W)
where
W: ?Sized + redis::RedisWrite,
{
out.write_arg(&bincode::serialize(self).expect("serialization to work"));
}
}
/// Use LRU as the eviction strategy
pub type Lru = LruCache<CacheKey, CacheValue>;
@@ -21,22 +76,29 @@ pub type Lru = LruCache<CacheKey, CacheValue>;
pub type Lfu = LfuCache<CacheKey, CacheValue>;
/// Adapter trait for memory cache backends
pub trait InternalMemoryCacheInitializer: InternalMemoryCache {
fn new() -> Self;
}
pub trait InternalMemoryCache: Sync + Send {
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>>;
fn push(&mut self, key: CacheKey, data: CacheValue);
fn pop(&mut self) -> Option<(CacheKey, CacheValue)>;
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCacheInitializer for Lfu {
#[inline]
fn new() -> Self {
Self::unbounded()
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for Lfu {
#[inline]
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
self.get(key).map(Cow::Borrowed)
}
#[inline]
@@ -50,15 +112,19 @@ impl InternalMemoryCache for Lfu {
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCacheInitializer for Lru {
#[inline]
fn new() -> Self {
Self::unbounded()
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for Lru {
#[inline]
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
self.get(key).map(Cow::Borrowed)
}
#[inline]
@@ -72,13 +138,73 @@ impl InternalMemoryCache for Lru {
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for RedisClient {
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
Commands::get(self, key).ok().map(Cow::Owned)
}
fn push(&mut self, key: CacheKey, data: CacheValue) {
if let Err(e) = Commands::set::<_, _, ()>(self, key, data) {
warn!("Failed to push to redis: {}", e);
}
}
fn pop(&mut self) -> Option<(CacheKey, CacheValue)> {
unimplemented!("redis should handle its own memory")
}
}
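Wiring this backend in mirrors the CacheType::Redis arm added to main.rs later in this diff. A sketch of that wiring, assuming DiskCache, Arc, and RedisResult are in scope (paths and the redis URL are illustrative):

    // Illustrative: a Redis client as the in-memory tier in front of the on-disk cache.
    // pop() is intentionally unimplemented above because redis does its own eviction,
    // and new_with_cache wires a dummy sender since no listener task is spawned.
    fn redis_backed(disk: DiskCache) -> RedisResult<Arc<MemoryCache<RedisClient, DiskCache>>> {
        let client = RedisClient::open("redis://127.0.0.1/")?;
        Ok(Arc::new(MemoryCache::new_with_cache(disk, client)))
    }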
/// Memory accelerated disk cache. Uses the internal cache implementation in
/// memory to speed up reads.
pub struct MemoryCache<MemoryCacheImpl, ColdCache> {
inner: ColdCache,
cur_mem_size: AtomicU64,
mem_cache: Mutex<MemoryCacheImpl>,
master_sender: Sender<CacheEntry>,
}
impl<MemoryCacheImpl, ColdCache> MemoryCache<MemoryCacheImpl, ColdCache>
where
MemoryCacheImpl: 'static + InternalMemoryCacheInitializer,
ColdCache: 'static + Cache,
{
pub fn new(inner: ColdCache, max_mem_size: crate::units::Bytes) -> Arc<Self> {
let (tx, rx) = channel(100);
let new_self = Arc::new(Self {
inner,
cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(MemoryCacheImpl::new()),
master_sender: tx,
});
tokio::spawn(internal_cache_listener(
Arc::clone(&new_self),
max_mem_size,
rx,
));
new_self
}
/// Returns an instance of the cache with the receiver for callback events
/// Really only useful for inspecting the receiver, e.g. for testing
#[cfg(test)]
pub fn new_with_receiver(
inner: ColdCache,
_: crate::units::Bytes,
) -> (Self, Receiver<CacheEntry>) {
let (tx, rx) = channel(100);
(
Self {
inner,
cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(MemoryCacheImpl::new()),
master_sender: tx,
},
rx,
)
}
}
impl<MemoryCacheImpl, ColdCache> MemoryCache<MemoryCacheImpl, ColdCache>
@@ -86,54 +212,68 @@ where
MemoryCacheImpl: 'static + InternalMemoryCache,
ColdCache: 'static + Cache,
{
pub fn new_with_cache(inner: ColdCache, init_mem_cache: MemoryCacheImpl) -> Self {
Self {
inner,
cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(init_mem_cache),
master_sender: channel(1).0,
}
}
}
async fn internal_cache_listener<MemoryCacheImpl, ColdCache>(
cache: Arc<MemoryCache<MemoryCacheImpl, ColdCache>>,
max_mem_size: crate::units::Bytes,
mut rx: Receiver<CacheEntry>,
) where
MemoryCacheImpl: InternalMemoryCache,
ColdCache: Cache,
{
let max_mem_size = mem_threshold(&max_mem_size);
while let Some(CacheEntry {
key,
data,
metadata,
on_disk_size,
}) = rx.recv().await
{
// Add to memory cache
// We can add first because we constrain our memory usage to 95%
cache
.cur_mem_size
.fetch_add(on_disk_size as u64, Ordering::Release);
cache
.mem_cache
.lock()
.await
.push(key, CacheValue::new(data, metadata, on_disk_size));
// Pop if too large
while cache.cur_mem_size.load(Ordering::Acquire) >= max_mem_size as u64 {
let popped = cache.mem_cache.lock().await.pop().map(
|(
key,
CacheValue {
data,
metadata,
on_disk_size,
},
)| (key, data, metadata, on_disk_size),
);
if let Some((_, _, _, size)) = popped {
cache.cur_mem_size.fetch_sub(size as u64, Ordering::Release);
} else {
break;
}
}
}
}
const fn mem_threshold(bytes: &crate::units::Bytes) -> usize {
bytes.get() / 20 * 19
}
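mem_threshold keeps the memory tier at 95% of the configured budget, and dividing by 20 before multiplying by 19 is what keeps the computation from overflowing on large inputs. A quick check of both properties (the usize::MAX value assumes a 64-bit target, matching the test module further down):

    fn main() {
        // 95% of a 100-byte budget
        assert_eq!(100_usize / 20 * 19, 95);
        // Dividing first keeps even usize::MAX in range; multiplying first would overflow.
        assert_eq!(usize::MAX / 20 * 19, 17_524_406_870_024_074_020);
    }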
#[async_trait]
impl<MemoryCacheImpl, ColdCache> Cache for MemoryCache<MemoryCacheImpl, ColdCache>
where
@@ -146,16 +286,16 @@ where
key: &CacheKey,
) -> Option<Result<(CacheStream, ImageMetadata), super::CacheError>> {
match self.mem_cache.lock().now_or_never() {
Some(mut mem_cache) => {
match mem_cache.get(key).map(Cow::into_owned).map(
|CacheValue { data, metadata, .. }| {
Ok((CacheStream::Memory(MemStream(data)), metadata))
},
) {
Some(v) => Some(v),
None => self.inner.get(key).await,
}
}
None => self.inner.get(key).await,
}
}
@@ -164,11 +304,419 @@ where
async fn put(
&self,
key: CacheKey,
image: Bytes,
metadata: ImageMetadata,
) -> Result<(), super::CacheError> {
self.inner
.put_with_on_completed_callback(key, image, metadata, self.master_sender.clone())
.await
}
}
#[cfg(test)]
mod test_util {
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use super::{CacheValue, InternalMemoryCache, InternalMemoryCacheInitializer};
use crate::cache::{
Cache, CacheEntry, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
};
use async_trait::async_trait;
use parking_lot::Mutex;
use tokio::io::BufReader;
use tokio::sync::mpsc::Sender;
use tokio_util::io::ReaderStream;
#[derive(Default)]
pub struct TestDiskCache(
pub Mutex<RefCell<HashMap<CacheKey, Result<(CacheStream, ImageMetadata), CacheError>>>>,
);
#[async_trait]
impl Cache for TestDiskCache {
async fn get(
&self,
key: &CacheKey,
) -> Option<Result<(CacheStream, ImageMetadata), CacheError>> {
self.0.lock().get_mut().remove(key)
}
async fn put(
&self,
key: CacheKey,
image: bytes::Bytes,
metadata: ImageMetadata,
) -> Result<(), CacheError> {
let reader = Box::pin(BufReader::new(tokio_util::io::StreamReader::new(
tokio_stream::once(Ok::<_, std::io::Error>(image)),
)));
let stream = CacheStream::Completed(ReaderStream::new(reader));
self.0.lock().get_mut().insert(key, Ok((stream, metadata)));
Ok(())
}
}
#[async_trait]
impl CallbackCache for TestDiskCache {
async fn put_with_on_completed_callback(
&self,
key: CacheKey,
data: bytes::Bytes,
metadata: ImageMetadata,
on_complete: Sender<CacheEntry>,
) -> Result<(), CacheError> {
self.put(key.clone(), data.clone(), metadata)
.await?;
let on_disk_size = data.len() as u64;
let _ = on_complete
.send(CacheEntry {
key,
data,
metadata,
on_disk_size,
})
.await;
Ok(())
}
}
#[derive(Default)]
pub struct TestMemoryCache(pub BTreeMap<CacheKey, CacheValue>);
impl InternalMemoryCacheInitializer for TestMemoryCache {
fn new() -> Self {
Self::default()
}
}
impl InternalMemoryCache for TestMemoryCache {
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
self.0.get(key).map(Cow::Borrowed)
}
fn push(&mut self, key: CacheKey, data: CacheValue) {
self.0.insert(key, data);
}
fn pop(&mut self) -> Option<(CacheKey, CacheValue)> {
let mut cache = BTreeMap::new();
std::mem::swap(&mut cache, &mut self.0);
let mut iter = cache.into_iter();
let ret = iter.next();
self.0 = iter.collect();
ret
}
}
}
#[cfg(test)]
mod cache_ops {
use std::error::Error;
use bytes::Bytes;
use futures::{FutureExt, StreamExt};
use crate::cache::mem::{CacheValue, InternalMemoryCache};
use crate::cache::{Cache, CacheEntry, CacheKey, CacheStream, ImageMetadata, MemStream};
use super::test_util::{TestDiskCache, TestMemoryCache};
use super::MemoryCache;
#[tokio::test]
async fn get_mem_cached() -> Result<(), Box<dyn Error>> {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
let value = CacheValue::new(bytes.clone(), metadata, 34);
// Populate the cache, need to drop the lock else it's considered locked
// when we actually call the cache
{
let mem_cache = &mut cache.mem_cache.lock().await;
mem_cache.push(key.clone(), value.clone());
}
let (stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
if let CacheStream::Memory(MemStream(ret_stream)) = stream {
assert_eq!(bytes, ret_stream);
} else {
panic!("wrong stream type");
}
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
#[tokio::test]
async fn get_disk_cached() -> Result<(), Box<dyn Error>> {
let (mut cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
{
let cache = &mut cache.inner;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
}
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
// Identical to the get_disk_cached test but we hold a lock on the mem_cache
#[tokio::test]
async fn get_mem_locked() -> Result<(), Box<dyn Error>> {
let (mut cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
{
let cache = &mut cache.inner;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
}
// intentionally not dropped
let _mem_cache = &mut cache.mem_cache.lock().await;
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
#[tokio::test]
async fn get_miss() {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
assert!(cache.get(&key).await.is_none());
assert!(rx.recv().now_or_never().is_none());
}
#[tokio::test]
async fn put_puts_into_disk_and_hears_from_rx() -> Result<(), Box<dyn Error>> {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
let bytes_len = bytes.len() as u64;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
// Because the callback is supposed to let the memory cache insert the
// entry into its cache, we can check that it properly stored it on the
// disk layer by checking if we can successfully fetch it.
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
// Check that we heard back
let cache_entry = rx
.recv()
.now_or_never()
.flatten()
.ok_or("failed to hear back from cache")?;
assert_eq!(
cache_entry,
CacheEntry {
key,
data: bytes,
metadata,
on_disk_size: bytes_len,
}
);
Ok(())
}
}
#[cfg(test)]
mod db_listener {
use std::error::Error;
use std::iter::FromIterator;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use bytes::Bytes;
use tokio::task;
use crate::cache::{Cache, CacheKey, ImageMetadata};
use super::test_util::{TestDiskCache, TestMemoryCache};
use super::{internal_cache_listener, MemoryCache};
#[tokio::test]
async fn put_into_memory() -> Result<(), Box<dyn Error>> {
let (cache, rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(0),
);
let cache = Arc::new(cache);
tokio::spawn(internal_cache_listener(
Arc::clone(&cache),
crate::units::Bytes(20),
rx,
));
// put small image into memory
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
cache.put(key.clone(), bytes.clone(), metadata).await?;
// let the listener run first
for _ in 0..10 {
task::yield_now().await;
}
assert_eq!(
cache.cur_mem_size.load(Ordering::SeqCst),
bytes.len() as u64
);
// Since we didn't populate the cache, fetching must be from memory, so
// this should succeed since the cache listener should push the item
// into cache
assert!(cache.get(&key).await.is_some());
Ok(())
}
#[tokio::test]
async fn pops_items() -> Result<(), Box<dyn Error>> {
let (cache, rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(0),
);
let cache = Arc::new(cache);
tokio::spawn(internal_cache_listener(
Arc::clone(&cache),
crate::units::Bytes(20),
rx,
));
// put small image into memory
let key_0 = CacheKey("a".to_string(), "b".to_string(), false);
let key_1 = CacheKey("c".to_string(), "d".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcde");
cache.put(key_0, bytes.clone(), metadata).await?;
cache.put(key_1, bytes.clone(), metadata).await?;
// let the listener run first
task::yield_now().await;
for _ in 0..10 {
task::yield_now().await;
}
// Items should be in cache now
assert_eq!(
cache.cur_mem_size.load(Ordering::SeqCst),
(bytes.len() * 2) as u64
);
let key_3 = CacheKey("e".to_string(), "f".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_iter(b"0".repeat(16).into_iter());
let bytes_len = bytes.len();
cache.put(key_3, bytes, metadata).await?;
// let the listener run first
task::yield_now().await;
for _ in 0..10 {
task::yield_now().await;
}
// Items should have been evicted, only 16 bytes should be there now
assert_eq!(cache.cur_mem_size.load(Ordering::SeqCst), bytes_len as u64);
Ok(())
}
}
#[cfg(test)]
mod mem_threshold {
use crate::units::Bytes;
use super::mem_threshold;
#[test]
fn small_amount_works() {
assert_eq!(mem_threshold(&Bytes(100)), 95);
}
#[test]
fn large_amount_cannot_overflow() {
assert_eq!(mem_threshold(&Bytes(usize::MAX)), 17_524_406_870_024_074_020);
}
}

src/cache/mod.rs vendored

@@ -5,34 +5,46 @@ use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use actix_web::http::header::HeaderValue;
use async_trait::async_trait;
use bytes::Bytes;
use chacha20::Key;
use chrono::{DateTime, FixedOffset};
use futures::{Stream, StreamExt};
use once_cell::sync::OnceCell;
use redis::ToRedisArgs;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use thiserror::Error;
use tokio::sync::mpsc::Sender;
use tokio_util::io::ReaderStream;
pub use disk::DiskCache;
pub use fs::UpstreamError;
pub use mem::MemoryCache;
use self::compat::LegacyImageMetadata;
use self::fs::MetadataFetch;
pub static ENCRYPTION_KEY: OnceCell<Key> = OnceCell::new();
mod compat;
mod disk;
mod fs;
pub mod mem;
#[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)]
pub struct CacheKey(pub String, pub String, pub bool);
impl ToRedisArgs for CacheKey {
fn write_redis_args<W>(&self, out: &mut W)
where
W: ?Sized + redis::RedisWrite,
{
out.write_arg_fmt(self);
}
}
impl Display for CacheKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.2 {
@@ -60,7 +72,7 @@ impl From<&CacheKey> for PathBuf {
#[derive(Clone)]
pub struct CachedImage(pub Bytes);
#[derive(Copy, Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct ImageMetadata {
pub content_type: Option<ImageContentType>,
pub content_length: Option<u32>,
@@ -68,7 +80,7 @@ pub struct ImageMetadata {
}
// Confirmed by Ply to be these types: https://link.eddie.sh/ZXfk0
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum ImageContentType {
Png = 0,
@@ -103,12 +115,21 @@ impl AsRef<str> for ImageContentType {
}
}
impl From<LegacyImageMetadata> for ImageMetadata {
fn from(legacy: LegacyImageMetadata) -> Self {
Self {
content_type: legacy.content_type.map(|v| v.0),
content_length: legacy.size,
last_modified: legacy.last_modified.map(|v| v.0),
}
}
}
#[derive(Debug)]
pub enum ImageRequestError {
ContentType,
ContentLength,
LastModified,
}
impl ImageMetadata {
@@ -124,14 +145,14 @@ impl ImageMetadata {
Err(_) => Err(InvalidContentType),
})
.transpose()
.map_err(|_| ImageRequestError::ContentType)?,
content_length: content_length
.map(|header_val| {
header_val
.to_str()
.map_err(|_| ImageRequestError::ContentLength)?
.parse()
.map_err(|_| ImageRequestError::ContentLength)
})
.transpose()?,
last_modified: last_modified
@@ -139,17 +160,15 @@ impl ImageMetadata {
DateTime::parse_from_rfc2822(
header_val
.to_str()
.map_err(|_| ImageRequestError::LastModified)?,
)
.map_err(|_| ImageRequestError::LastModified)
})
.transpose()?,
})
}
}
#[derive(Error, Debug)]
pub enum CacheError {
#[error(transparent)]
@@ -170,9 +189,9 @@ pub trait Cache: Send + Sync {
async fn put(
&self,
key: CacheKey,
image: Bytes,
metadata: ImageMetadata,
) -> Result<(), CacheError>;
}
#[async_trait]
@@ -189,9 +208,9 @@ impl<T: Cache> Cache for Arc<T> {
async fn put(
&self,
key: CacheKey,
image: Bytes,
metadata: ImageMetadata,
) -> Result<(), CacheError> {
self.as_ref().put(key, image, metadata).await
}
}
@@ -201,74 +220,43 @@ pub trait CallbackCache: Cache {
async fn put_with_on_completed_callback(
&self,
key: CacheKey,
image: Bytes,
metadata: ImageMetadata,
on_complete: Sender<CacheEntry>,
) -> Result<(), CacheError>;
}
#[async_trait]
impl<T: CallbackCache> CallbackCache for Arc<T> {
#[inline]
#[cfg(not(tarpaulin_include))]
async fn put_with_on_completed_callback(
&self,
key: CacheKey,
image: Bytes,
metadata: ImageMetadata,
on_complete: Sender<CacheEntry>,
) -> Result<(), CacheError> {
self.as_ref()
.put_with_on_completed_callback(key, image, metadata, on_complete)
.await
}
}
#[derive(PartialEq, Eq, Debug)]
pub struct CacheEntry {
key: CacheKey,
data: Bytes,
metadata: ImageMetadata,
on_disk_size: u64,
}
pub enum CacheStream {
Memory(MemStream),
Completed(ReaderStream<Pin<Box<dyn MetadataFetch + Send + Sync>>>),
}
impl From<CachedImage> for CacheStream {
fn from(image: CachedImage) -> Self {
Self::Memory(MemStream(image.0))
}
@@ -276,17 +264,13 @@ impl From<CachedImage> for InnerStream {
type CacheStreamItem = Result<Bytes, UpstreamError>;
impl Stream for CacheStream {
type Item = CacheStreamItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.get_mut() {
Self::Memory(stream) => stream.poll_next_unpin(cx),
Self::Completed(stream) => stream.poll_next_unpin(cx).map_err(|_| UpstreamError),
}
}
}

src/client.rs Normal file

@@ -0,0 +1,220 @@
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use actix_web::http::header::{HeaderMap, HeaderName, HeaderValue};
use actix_web::web::Data;
use bytes::Bytes;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use reqwest::header::{
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH,
CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS,
};
use reqwest::{Client, Proxy, StatusCode};
use tokio::sync::watch::{channel, Receiver};
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
use crate::cache::{Cache, CacheKey, ImageMetadata};
use crate::config::{DISABLE_CERT_VALIDATION, USE_PROXY};
pub static HTTP_CLIENT: Lazy<CachingClient> = Lazy::new(|| {
let mut inner = Client::builder()
.pool_idle_timeout(Duration::from_secs(180))
.https_only(true)
.http2_prior_knowledge();
if let Some(socket_addr) = USE_PROXY.get() {
info!(
"Using {} as a proxy for upstream requests.",
socket_addr.as_str()
);
inner = inner.proxy(Proxy::all(socket_addr.as_str()).unwrap());
}
if DISABLE_CERT_VALIDATION.load(Ordering::Acquire) {
inner = inner.danger_accept_invalid_certs(true);
}
let inner = inner.build().expect("Client initialization to work");
CachingClient {
inner,
locks: RwLock::new(HashMap::new()),
}
});
#[cfg(not(tarpaulin_include))]
pub static DEFAULT_HEADERS: Lazy<HeaderMap> = Lazy::new(|| {
let mut headers = HeaderMap::with_capacity(8);
headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
headers.insert(
ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("https://mangadex.org"),
);
headers.insert(ACCESS_CONTROL_EXPOSE_HEADERS, HeaderValue::from_static("*"));
headers.insert(
CACHE_CONTROL,
HeaderValue::from_static("public, max-age=1209600"),
);
headers.insert(
HeaderName::from_static("timing-allow-origin"),
HeaderValue::from_static("https://mangadex.org"),
);
headers
});
pub struct CachingClient {
inner: Client,
locks: RwLock<HashMap<String, Receiver<FetchResult>>>,
}
#[derive(Clone, Debug)]
pub enum FetchResult {
ServiceUnavailable,
InternalServerError,
Data(StatusCode, HeaderMap, Bytes),
Processing,
}
impl CachingClient {
pub async fn fetch_and_cache(
&'static self,
url: String,
key: CacheKey,
cache: Data<dyn Cache>,
) -> FetchResult {
let maybe_receiver = {
let lock = self.locks.read();
lock.get(&url).map(Clone::clone)
};
if let Some(mut recv) = maybe_receiver {
loop {
if !matches!(*recv.borrow(), FetchResult::Processing) {
break;
}
if recv.changed().await.is_err() {
break;
}
}
return recv.borrow().clone();
}
let notify = Arc::new(Notify::new());
tokio::spawn(self.fetch_and_cache_impl(cache, url.clone(), key, Arc::clone(&notify)));
notify.notified().await;
let mut recv = self
.locks
.read()
.get(&url)
.expect("receiver to exist since we just made one")
.clone();
loop {
if !matches!(*recv.borrow(), FetchResult::Processing) {
break;
}
if recv.changed().await.is_err() {
break;
}
}
let resp = recv.borrow().clone();
resp
}
async fn fetch_and_cache_impl(
&self,
cache: Data<dyn Cache>,
url: String,
key: CacheKey,
notify: Arc<Notify>,
) {
let (tx, rx) = channel(FetchResult::Processing);
self.locks.write().insert(url.clone(), rx);
notify.notify_one();
let resp = self.inner.get(&url).send().await;
let resp = match resp {
Ok(mut resp) => {
let content_type = resp.headers().get(CONTENT_TYPE);
let is_image = content_type
.map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
.unwrap_or_default();
if resp.status() != StatusCode::OK || !is_image {
warn!("Got non-OK or non-image response code from upstream, proxying and not caching result.");
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type.clone());
}
FetchResult::Data(
resp.status(),
headers,
resp.bytes().await.unwrap_or_default(),
)
} else {
let (content_type, length, last_mod) = {
let headers = resp.headers_mut();
(
headers.remove(CONTENT_TYPE),
headers.remove(CONTENT_LENGTH),
headers.remove(LAST_MODIFIED),
)
};
let body = resp.bytes().await.unwrap();
debug!("Inserting into cache");
let metadata =
ImageMetadata::new(content_type.clone(), length.clone(), last_mod.clone())
.unwrap();
match cache.put(key, body.clone(), metadata).await {
Ok(()) => {
debug!("Done putting into cache");
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type);
}
if let Some(content_length) = length {
headers.insert(CONTENT_LENGTH, content_length);
}
if let Some(last_modified) = last_mod {
headers.insert(LAST_MODIFIED, last_modified);
}
FetchResult::Data(StatusCode::OK, headers, body)
}
Err(e) => {
warn!("Failed to insert into cache: {}", e);
FetchResult::InternalServerError
}
}
}
}
Err(e) => {
error!("Failed to fetch image from server: {}", e);
FetchResult::ServiceUnavailable
}
};
// This shouldn't happen
tx.send(resp).unwrap();
self.locks.write().remove(&url);
}
#[inline]
pub const fn inner(&self) -> &Client {
&self.inner
}
}
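The locks map of watch receivers above is what coalesces concurrent requests for the same URL: the first caller registers a Processing entry and performs the upstream fetch, and every later caller just awaits that channel. A stripped-down sketch of the same pattern, independent of CachingClient (names and the stand-in response are illustrative):

    use std::collections::HashMap;

    use parking_lot::RwLock;
    use tokio::sync::watch;

    #[derive(Clone)]
    enum Slot {
        Processing,
        Done(String),
    }

    #[derive(Default)]
    struct Coalescer {
        in_flight: RwLock<HashMap<String, watch::Receiver<Slot>>>,
    }

    impl Coalescer {
        async fn fetch(&self, url: String) -> String {
            // If another task already registered this URL, wait for its result
            // instead of issuing a second upstream request.
            let existing = self.in_flight.read().get(&url).cloned();
            if let Some(mut rx) = existing {
                while matches!(*rx.borrow(), Slot::Processing) {
                    if rx.changed().await.is_err() {
                        break;
                    }
                }
                if let Slot::Done(body) = rx.borrow().clone() {
                    return body;
                }
            }
            // Otherwise register ourselves, do the work once, and publish the result.
            let (tx, rx) = watch::channel(Slot::Processing);
            self.in_flight.write().insert(url.clone(), rx);
            let body = format!("response for {}", url); // stand-in for the real upstream fetch
            let _ = tx.send(Slot::Done(body.clone()));
            self.in_flight.write().remove(&url);
            body
        }
    }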


@@ -8,10 +8,12 @@ use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use clap::{crate_authors, crate_description, crate_version, Parser};
use log::LevelFilter;
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::level_filters::LevelFilter as TracingLevelFilter;
use url::Url;
use crate::units::{KilobitsPerSecond, Mebibytes, Port};
@@ -19,6 +21,8 @@ use crate::units::{KilobitsPerSecond, Mebibytes, Port};
// Validate tokens is an atomic because it's faster than locking on rwlock.
pub static VALIDATE_TOKENS: AtomicBool = AtomicBool::new(false);
pub static OFFLINE_MODE: AtomicBool = AtomicBool::new(false);
pub static USE_PROXY: OnceCell<Url> = OnceCell::new();
pub static DISABLE_CERT_VALIDATION: AtomicBool = AtomicBool::new(false);
#[derive(Error, Debug)]
pub enum ConfigError {
@@ -69,6 +73,19 @@ pub fn load_config() -> Result<Config, ConfigError> {
Ordering::Release,
);
if let Some(socket) = config.proxy.clone() {
USE_PROXY
.set(socket)
.expect("USE_PROXY to be set only by this function");
}
DISABLE_CERT_VALIDATION.store(
config
.unstable_options
.contains(&UnstableOptions::DisableCertValidation),
Ordering::Release,
);
Ok(config)
}
@@ -78,7 +95,7 @@ pub struct Config {
pub cache_type: CacheType,
pub cache_path: PathBuf,
pub shutdown_timeout: NonZeroU16,
pub log_level: TracingLevelFilter,
pub client_secret: ClientSecret,
pub port: Port,
pub bind_address: SocketAddr,
@@ -90,6 +107,9 @@ pub struct Config {
pub unstable_options: Vec<UnstableOptions>,
pub override_upstream: Option<Url>,
pub enable_metrics: bool,
pub geoip_license_key: Option<ClientSecret>,
pub proxy: Option<Url>,
pub redis_url: Option<Url>,
}
impl Config {
@@ -97,15 +117,24 @@ impl Config {
let file_extended_options = file_args.extended_options.unwrap_or_default();
let log_level = match (cli_args.quiet, cli_args.verbose) {
(n, _) if n > 2 => TracingLevelFilter::OFF,
(2, _) => TracingLevelFilter::ERROR,
(1, _) => TracingLevelFilter::WARN,
// Use log level from file if no flags were provided to CLI
(0, 0) => {
file_extended_options
.logging_level
.map_or(TracingLevelFilter::INFO, |filter| match filter {
LevelFilter::Off => TracingLevelFilter::OFF,
LevelFilter::Error => TracingLevelFilter::ERROR,
LevelFilter::Warn => TracingLevelFilter::WARN,
LevelFilter::Info => TracingLevelFilter::INFO,
LevelFilter::Debug => TracingLevelFilter::DEBUG,
LevelFilter::Trace => TracingLevelFilter::TRACE,
})
}
(_, 1) => TracingLevelFilter::DEBUG,
(_, n) if n > 1 => TracingLevelFilter::TRACE,
// compiler can't figure it out
_ => unsafe { unreachable_unchecked() },
};
@@ -155,10 +184,10 @@ impl Config {
.server_settings
.external_ip
.map(|ip_addr| SocketAddr::new(ip_addr, external_port)),
ephemeral_disk_encryption: cli_args.ephemeral_disk_encryption
|| file_extended_options
.ephemeral_disk_encryption
.unwrap_or_default(),
network_speed: cli_args
.network_speed
.unwrap_or(file_args.server_settings.external_max_kilobits_per_second),
@@ -174,16 +203,71 @@ impl Config {
// Unstable options (and related) should never be in yaml config
unstable_options: cli_args.unstable_options,
override_upstream: cli_args.override_upstream,
geoip_license_key: file_args.metric_settings.and_then(|args| {
if args.enable_geoip.unwrap_or_default() {
args.geoip_license_key
} else {
None
}
}),
proxy: cli_args.proxy,
redis_url: file_extended_options.redis_url,
}
}
}
// this intentionally does not implement display
#[derive(Deserialize, Serialize, Clone)]
pub struct ClientSecret(String);
impl ClientSecret {
pub fn as_str(&self) -> &str {
self.0.as_ref()
}
}
impl std::fmt::Debug for ClientSecret {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "[client secret]")
}
}
#[derive(Deserialize, Copy, Clone, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CacheType {
OnDisk,
Lru,
Lfu,
Redis,
}
impl FromStr for CacheType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"on_disk" => Ok(Self::OnDisk),
"lru" => Ok(Self::Lru),
"lfu" => Ok(Self::Lfu),
"redis" => Ok(Self::Redis),
_ => Err(format!("Unknown option: {}", s)),
}
}
}
impl Default for CacheType {
fn default() -> Self {
Self::OnDisk
}
}
#[derive(Deserialize)]
struct YamlArgs {
// Naming is legacy
max_cache_size_in_mebibytes: Mebibytes,
server_settings: YamlServerSettings,
metric_settings: Option<YamlMetricSettings>,
// This implementation's custom options
extended_options: Option<YamlExtendedOptions>,
}
@@ -200,14 +284,10 @@ struct YamlServerSettings {
external_ip: Option<IpAddr>,
}
#[derive(Deserialize)]
struct YamlMetricSettings {
enable_geoip: Option<bool>,
geoip_license_key: Option<ClientSecret>,
}
#[derive(Deserialize, Default)]
@@ -218,36 +298,10 @@ struct YamlExtendedOptions {
enable_metrics: Option<bool>,
logging_level: Option<LevelFilter>,
cache_path: Option<PathBuf>,
redis_url: Option<Url>,
}
#[derive(Parser, Clone)]
#[clap(version = crate_version!(), author = crate_authors!(), about = crate_description!())]
struct CliArgs {
/// The port to listen on.
@@ -276,19 +330,29 @@ struct CliArgs {
/// respectively.
#[clap(short, long, parse(from_occurrences), conflicts_with = "verbose")]
pub quiet: usize,
/// Unstable options. Intentionally not documented.
#[clap(short = 'Z', long)]
pub unstable_options: Vec<UnstableOptions>,
/// Override the image server with the one provided. Do not set this unless
/// you know what you're doing.
#[clap(long)]
pub override_upstream: Option<Url>,
/// Enables ephemeral disk encryption. Items written to disk are first
/// encrypted with a key generated at runtime. There are implications to
/// performance, privacy, and usability with this flag enabled.
#[clap(short, long)]
pub ephemeral_disk_encryption: bool,
/// The path to the config file. Default value is `./settings.yaml`.
#[clap(short, long)]
pub config_path: Option<PathBuf>,
/// Whether to use an in-memory cache in addition to the disk cache. Default
/// value is "on_disk", other options are "lfu", "lru", and "redis".
#[clap(short = 't', long)]
pub cache_type: Option<CacheType>,
/// Whether or not to use a proxy for upstream requests. This affects all
/// requests except for the shutdown request.
#[clap(short = 'P', long)]
pub proxy: Option<Url>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -306,6 +370,10 @@ pub enum UnstableOptions {
/// Serves HTTP in plaintext
DisableTls,
/// Disable certificate validation. Only useful for debugging with a MITM
/// proxy
DisableCertValidation,
}
impl FromStr for UnstableOptions {
@@ -317,6 +385,7 @@ impl FromStr for UnstableOptions {
"disable-token-validation" => Ok(Self::DisableTokenValidation),
"offline-mode" => Ok(Self::OfflineMode),
"disable-tls" => Ok(Self::DisableTls),
"disable-cert-validation" => Ok(Self::DisableCertValidation),
_ => Err(format!("Unknown unstable option '{}'", s)),
}
}
@@ -329,6 +398,82 @@ impl Display for UnstableOptions {
Self::DisableTokenValidation => write!(f, "disable-token-validation"),
Self::OfflineMode => write!(f, "offline-mode"),
Self::DisableTls => write!(f, "disable-tls"),
Self::DisableCertValidation => write!(f, "disable-cert-validation"),
}
}
}
#[cfg(test)]
mod sample_yaml {
use crate::config::YamlArgs;
#[test]
fn parses() {
assert!(serde_yaml::from_str::<YamlArgs>(include_str!("../settings.sample.yaml")).is_ok());
}
}
#[cfg(test)]
mod config {
use std::path::PathBuf;
use log::LevelFilter;
use tracing::level_filters::LevelFilter as TracingLevelFilter;
use crate::config::{CacheType, ClientSecret, Config, YamlExtendedOptions, YamlServerSettings};
use crate::units::{KilobitsPerSecond, Mebibytes, Port};
use super::{CliArgs, YamlArgs};
#[test]
fn cli_has_priority() {
let cli_config = CliArgs {
port: Port::new(1234),
memory_quota: Some(Mebibytes::new(10)),
disk_quota: Some(Mebibytes::new(10)),
cache_path: Some(PathBuf::from("a")),
network_speed: KilobitsPerSecond::new(10),
verbose: 1,
quiet: 0,
unstable_options: vec![],
override_upstream: None,
ephemeral_disk_encryption: true,
config_path: None,
cache_type: Some(CacheType::Lfu),
proxy: None,
};
let yaml_args = YamlArgs {
max_cache_size_in_mebibytes: Mebibytes::new(50),
server_settings: YamlServerSettings {
secret: ClientSecret(String::new()),
port: Port::new(4321).expect("to work?"),
external_max_kilobits_per_second: KilobitsPerSecond::new(50).expect("to work?"),
external_port: None,
graceful_shutdown_wait_seconds: None,
hostname: None,
external_ip: None,
},
metric_settings: None,
extended_options: Some(YamlExtendedOptions {
memory_quota: Some(Mebibytes::new(50)),
cache_type: Some(CacheType::Lru),
ephemeral_disk_encryption: Some(false),
enable_metrics: None,
logging_level: Some(LevelFilter::Error),
cache_path: Some(PathBuf::from("b")),
redis_url: None,
}),
};
let config = Config::from_cli_and_file(cli_config, yaml_args);
assert_eq!(Some(config.port), Port::new(1234));
assert_eq!(config.memory_quota, Mebibytes::new(10));
assert_eq!(config.disk_quota, Mebibytes::new(10));
assert_eq!(config.cache_path, PathBuf::from("a"));
assert_eq!(Some(config.network_speed), KilobitsPerSecond::new(10));
assert_eq!(config.log_level, TracingLevelFilter::DEBUG);
assert_eq!(config.ephemeral_disk_encryption, true);
assert_eq!(config.cache_type, CacheType::Lfu);
}
}


@@ -5,31 +5,40 @@
use std::env::VarError;
use std::error::Error;
use std::fmt::Display;
use std::net::SocketAddr;
use std::num::ParseIntError;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use actix_web::dev::Service;
use actix_web::rt::{spawn, time, System};
use actix_web::web::{self, Data};
use actix_web::{App, HttpResponse, HttpServer};
use cache::{Cache, DiskCache};
use chacha20::Key;
use config::Config;
use maxminddb::geoip2;
use parking_lot::RwLock;
use redis::Client as RedisClient;
use rustls::server::NoClientAuth;
use rustls::ServerConfig;
use sodiumoxide::crypto::stream::xchacha20::gen_key;
use state::{RwLockServerState, ServerState};
use stop::send_stop;
use thiserror::Error;
use tracing::{debug, error, info, warn};
use crate::cache::mem::{Lfu, Lru};
use crate::cache::{MemoryCache, ENCRYPTION_KEY};
use crate::config::{CacheType, UnstableOptions, OFFLINE_MODE};
use crate::metrics::{record_country_visit, GEOIP_DATABASE};
use crate::state::DynamicServerCert;
mod cache;
mod client;
mod config;
mod metrics;
mod ping;
@@ -74,12 +83,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
.unstable_options
.contains(&UnstableOptions::DisableTls);
let bind_address = config.bind_address;
let redis_url = config.redis_url.clone();
//
// Logging and warnings
//
tracing_subscriber::fmt()
.with_max_level(config.log_level)
.init();
if let Err(e) = print_preamble_and_warnings(&config) {
error!("{}", e);
@@ -93,15 +105,30 @@ async fn main() -> Result<(), Box<dyn Error>> {
if config.ephemeral_disk_encryption {
info!("Running with at-rest encryption!");
ENCRYPTION_KEY
.set(*Key::from_slice(gen_key().as_ref()))
.unwrap();
}
if config.enable_metrics {
metrics::init();
}
if let Some(key) = config.geoip_license_key.clone() {
if let Err(e) = metrics::load_geo_ip_data(key).await {
error!("Failed to initialize geo ip db: {}", e);
}
}
// HTTP Server init // HTTP Server init
// Try bind to provided port first
let port_reservation = std::net::TcpListener::bind(bind_address);
if let Err(e) = port_reservation {
error!("Failed to bind to port!");
return Err(e.into());
};
let server = if OFFLINE_MODE.load(Ordering::Acquire) {
ServerState::init_offline()
} else {
@@ -129,7 +156,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
send_stop(&client_secret).await;
} else {
warn!("Got second Ctrl-C, forcefully exiting");
system.stop();
}
});
}
@@ -154,8 +181,16 @@ async fn main() -> Result<(), Box<dyn Error>> {
let cache = DiskCache::new(disk_quota.into(), cache_path.clone()).await;
let cache: Arc<dyn Cache> = match cache_type {
CacheType::OnDisk => cache,
CacheType::Lru => MemoryCache::<Lfu, _>::new(cache, memory_max_size),
CacheType::Lfu => MemoryCache::<Lru, _>::new(cache, memory_max_size),
CacheType::Redis => {
let url = redis_url.unwrap_or_else(|| {
url::Url::parse("redis://127.0.0.1/").expect("default redis url to be parsable")
});
info!("Trying to connect to redis instance at {}", url);
let mem_cache = RedisClient::open(url)?;
Arc::new(MemoryCache::new_with_cache(cache, mem_cache))
}
};
let cache_0 = Arc::clone(&cache);
@@ -163,6 +198,23 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Start HTTPS server
let server = HttpServer::new(move || {
App::new()
.wrap_fn(|req, srv| {
if let Some(reader) = GEOIP_DATABASE.get() {
let maybe_country = req
.connection_info()
.realip_remote_addr()
.map(SocketAddr::from_str)
.and_then(Result::ok)
.as_ref()
.map(SocketAddr::ip)
.map(|ip| reader.lookup::<geoip2::Country>(ip))
.and_then(Result::ok);
record_country_visit(maybe_country);
}
srv.call(req)
})
.service(routes::index)
.service(routes::token_data)
.service(routes::token_data_saver)
@@ -181,22 +233,25 @@ async fn main() -> Result<(), Box<dyn Error>> {
})
.shutdown_timeout(60);
// drop port reservation, might have a TOCTOU but it's not a big deal; this
// is just a best effort.
std::mem::drop(port_reservation);
if disable_tls {
server.bind(bind_address)?.run().await?;
} else {
// Rustls only supports TLS 1.2 and 1.3.
let tls_config = ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(NoClientAuth::new())
.with_cert_resolver(Arc::new(DynamicServerCert));
server.bind_rustls(bind_address, tls_config)?.run().await?;
}
// Waiting for us to finish sending stop message
while running.load(Ordering::SeqCst) {
tokio::time::sleep(Duration::from_millis(250)).await;
}
Ok(())
@@ -207,6 +262,7 @@ enum InvalidCombination {
MissingUnstableOption(&'static str, UnstableOptions),
}
#[cfg(not(tarpaulin_include))]
impl Display for InvalidCombination {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@@ -223,32 +279,38 @@ impl Display for InvalidCombination {
impl Error for InvalidCombination {}
#[cfg(not(tarpaulin_include))]
#[allow(clippy::cognitive_complexity)]
fn print_preamble_and_warnings(args: &Config) -> Result<(), Box<dyn Error>> {
let build_string = option_env!("VERGEN_GIT_SHA_SHORT")
.map(|git_sha| format!(" ({})", git_sha))
.unwrap_or_default();
println!(
concat!(
env!("CARGO_PKG_NAME"),
" ",
env!("CARGO_PKG_VERSION"),
"{} Copyright (C) 2021 ",
env!("CARGO_PKG_AUTHORS"),
"\n\n",
env!("CARGO_PKG_NAME"),
" is free software: you can redistribute it and/or modify\n\
it under the terms of the GNU General Public License as published by\n\
the Free Software Foundation, either version 3 of the License, or\n\
(at your option) any later version.\n\n",
env!("CARGO_PKG_NAME"),
" is distributed in the hope that it will be useful,\n\
but WITHOUT ANY WARRANTY; without even the implied warranty of\n\
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n\
GNU General Public License for more details.\n\n\
You should have received a copy of the GNU General Public License\n\
along with ",
env!("CARGO_PKG_NAME"),
". If not, see <https://www.gnu.org/licenses/>.\n"
),
build_string
);
if !args.unstable_options.is_empty() {
warn!("Unstable options are enabled. These options should not be used in production!");
@@ -265,6 +327,13 @@ fn print_preamble_and_warnings(args: &Config) -> Result<(), Box<dyn Error>> {
warn!("Serving insecure traffic! You better be running this for development only.");
}
if args
.unstable_options
.contains(&UnstableOptions::DisableCertValidation)
{
error!("Cert validation disabled! You REALLY only better be debugging.");
}
if args.override_upstream.is_some()
&& !args
.unstable_options


@@ -1,5 +1,31 @@
#![cfg(not(tarpaulin_include))]
use std::fs::metadata;
use std::hint::unreachable_unchecked;
use std::time::SystemTime;
use chrono::Duration;
use flate2::read::GzDecoder;
use maxminddb::geoip2::Country;
use once_cell::sync::{Lazy, OnceCell};
use prometheus::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
use tar::Archive;
use thiserror::Error;
use tracing::{debug, field::debug, info, warn};
use crate::client::HTTP_CLIENT;
use crate::config::ClientSecret;
pub static GEOIP_DATABASE: OnceCell<maxminddb::Reader<Vec<u8>>> = OnceCell::new();
static COUNTRY_VISIT_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"country_visits_total",
"The number of visits from a country",
&["country"]
)
.unwrap()
});
macro_rules! init_counters {
($(($counter:ident, $ty:ty, $name:literal, $desc:literal),)*) => {
@@ -11,7 +37,11 @@ macro_rules! init_counters {
#[allow(clippy::shadow_unrelated)]
pub fn init() {
// These need to be called at least once, otherwise the macro never
// called and thus the metrics don't get logged
$(let _a = $counter.get();)*
init_other();
}
};
}
@@ -20,13 +50,13 @@ init_counters!(
(
CACHE_HIT_COUNTER,
IntCounter,
"cache_hit_total",
"The number of cache hits."
),
(
CACHE_MISS_COUNTER,
IntCounter,
"cache_miss_total",
"The number of cache misses."
),
(
@@ -38,19 +68,118 @@ init_counters!(
(
REQUESTS_DATA_COUNTER,
IntCounter,
"requests_data_total",
"The number of requests served from the /data endpoint."
),
(
REQUESTS_DATA_SAVER_COUNTER,
IntCounter,
"requests_data_saver_total",
"The number of requests served from the /data-saver endpoint."
),
(
REQUESTS_OTHER_COUNTER,
IntCounter,
"requests_other_total",
"The total number of request not served by primary endpoints."
),
);
// initialization for any other counters that aren't simple int counters
fn init_other() {
let _a = COUNTRY_VISIT_COUNTER.local();
}
#[derive(Error, Debug)]
pub enum DbLoadError {
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
MaxMindDb(#[from] maxminddb::MaxMindDBError),
}
pub async fn load_geo_ip_data(license_key: ClientSecret) -> Result<(), DbLoadError> {
const DB_PATH: &str = "./GeoLite2-Country.mmdb";
// Check date of db
let db_date_created = metadata(DB_PATH)
.ok()
.and_then(|metadata| {
if let Ok(time) = metadata.created() {
Some(time)
} else {
debug("fs didn't report birth time, fall back to last modified instead");
metadata.modified().ok()
}
})
.unwrap_or(SystemTime::UNIX_EPOCH);
let duration = if let Ok(time) = SystemTime::now().duration_since(db_date_created) {
Duration::from_std(time).expect("duration to fit")
} else {
warn!("Clock may have gone backwards?");
Duration::max_value()
};
// DB expired, fetch a new one
if duration > Duration::weeks(1) {
fetch_db(license_key).await?;
} else {
info!("Geo IP database isn't old enough, not updating.");
}
// Result literally cannot panic here, buuuuuut if it does we'll panic
GEOIP_DATABASE
.set(maxminddb::Reader::open_readfile(DB_PATH)?)
.map_err(|_| ()) // Need to map err here or can't expect
.expect("to set the geo ip db singleton");
Ok(())
}
async fn fetch_db(license_key: ClientSecret) -> Result<(), DbLoadError> {
let resp = HTTP_CLIENT
.inner()
.get(format!("https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key={}&suffix=tar.gz", license_key.as_str()))
.send()
.await?
.bytes()
.await?;
let mut decoder = Archive::new(GzDecoder::new(resp.as_ref()));
let mut decoded_paths: Vec<_> = decoder
.entries()?
.filter_map(Result::ok)
.filter_map(|mut entry| {
let path = entry.path().ok()?.to_path_buf();
let file_name = path.file_name()?;
if file_name != "GeoLite2-Country.mmdb" {
return None;
}
entry.unpack(file_name).ok()?;
Some(path)
})
.collect();
assert_eq!(decoded_paths.len(), 1);
let path = match decoded_paths.pop() {
Some(path) => path,
None => unsafe { unreachable_unchecked() },
};
debug!("Extracted {}", path.as_path().to_string_lossy());
Ok(())
}
pub fn record_country_visit(country: Option<Country>) {
let iso_code = country
.and_then(|country| country.country.and_then(|c| c.iso_code))
.unwrap_or("unknown");
COUNTRY_VISIT_COUNTER
.get_metric_with_label_values(&[iso_code])
.unwrap()
.inc();
}
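A minimal usage sketch (not part of the diff): how a request handler might tie GEOIP_DATABASE and record_country_visit together once load_geo_ip_data has run. The helper name record_visit_for is hypothetical.

use std::net::IpAddr;

use maxminddb::geoip2::Country;

// Hypothetical helper: look up the visitor's country (if the GeoIP database
// was loaded) and bump the per-country Prometheus counter defined above.
fn record_visit_for(ip: IpAddr) {
    let country: Option<Country> = GEOIP_DATABASE
        .get()
        .and_then(|reader| reader.lookup::<Country>(ip).ok());
    record_country_visit(country);
}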


@ -1,47 +1,54 @@
+use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::Ordering;
use std::{io::BufReader, sync::Arc};
-use log::{debug, error, info, warn};
-use rustls::internal::pemfile::{certs, rsa_private_keys};
-use rustls::sign::{RSASigningKey, SigningKey};
-use rustls::Certificate;
+use rustls::sign::{CertifiedKey, RsaSigningKey, SigningKey};
+use rustls::{Certificate, PrivateKey};
+use rustls_pemfile::{certs, rsa_private_keys};
use serde::de::{MapAccess, Visitor};
use serde::{Deserialize, Serialize};
use serde_repr::Deserialize_repr;
use sodiumoxide::crypto::box_::PrecomputedKey;
+use tracing::{debug, error, info, warn};
use url::Url;
-use crate::config::{ClientSecret, Config, UnstableOptions, VALIDATE_TOKENS};
+use crate::client::HTTP_CLIENT;
+use crate::config::{ClientSecret, Config};
use crate::state::{
-RwLockServerState, PREVIOUSLY_COMPROMISED, PREVIOUSLY_PAUSED, TLS_CERTS,
-TLS_PREVIOUSLY_CREATED, TLS_SIGNING_KEY,
+RwLockServerState, CERTIFIED_KEY, PREVIOUSLY_COMPROMISED, PREVIOUSLY_PAUSED,
+TLS_PREVIOUSLY_CREATED,
};
-use crate::units::{BytesPerSecond, Mebibytes, Port};
+use crate::units::{Bytes, BytesPerSecond, Port};
use crate::CLIENT_API_VERSION;
pub const CONTROL_CENTER_PING_URL: &str = "https://api.mangadex.network/ping";
-#[derive(Serialize)]
+#[derive(Serialize, Debug)]
pub struct Request<'a> {
secret: &'a ClientSecret,
port: Port,
-disk_space: Mebibytes,
+disk_space: Bytes,
network_speed: BytesPerSecond,
build_version: usize,
tls_created_at: Option<String>,
+ip_address: Option<IpAddr>,
}
impl<'a> Request<'a> {
fn from_config_and_state(secret: &'a ClientSecret, config: &Config) -> Self {
Self {
secret,
-port: config.port,
-disk_space: config.disk_quota,
+port: config
+.external_address
+.and_then(|v| Port::new(v.port()))
+.unwrap_or(config.port),
+disk_space: config.disk_quota.into(),
network_speed: config.network_speed.into(),
build_version: CLIENT_API_VERSION,
tls_created_at: TLS_PREVIOUSLY_CREATED
.get()
.map(|v| v.load().as_ref().clone()),
+ip_address: config.external_address.as_ref().map(SocketAddr::ip),
}
}
}
@ -50,11 +57,15 @@ impl<'a> From<(&'a ClientSecret, &Config)> for Request<'a> {
fn from((secret, config): (&'a ClientSecret, &Config)) -> Self {
Self {
secret,
-port: config.port,
-disk_space: config.disk_quota,
+port: config
+.external_address
+.and_then(|v| Port::new(v.port()))
+.unwrap_or(config.port),
+disk_space: config.disk_quota.into(),
network_speed: config.network_speed.into(),
build_version: CLIENT_API_VERSION,
tls_created_at: None,
+ip_address: config.external_address.as_ref().map(SocketAddr::ip),
}
}
}
@ -62,7 +73,7 @@ impl<'a> From<(&'a ClientSecret, &Config)> for Request<'a> {
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum Response {
-Ok(OkResponse),
+Ok(Box<OkResponse>),
Error(ErrorResponse),
}
@ -74,8 +85,6 @@ pub struct OkResponse {
pub token_key: Option<String>,
pub compromised: bool,
pub paused: bool,
-#[serde(default)]
-pub force_tokens: bool,
pub tls: Option<Tls>,
}
@ -95,7 +104,7 @@ pub enum ErrorCode {
pub struct Tls {
pub created_at: String,
-pub priv_key: Arc<Box<dyn SigningKey>>,
+pub priv_key: Arc<RsaSigningKey>,
pub certs: Vec<Certificate>,
}
@ -128,11 +137,12 @@ impl<'de> Deserialize<'de> for Tls {
priv_key = rsa_private_keys(&mut BufReader::new(value.as_bytes()))
.ok()
.and_then(|mut v| {
-v.pop().and_then(|key| RSASigningKey::new(&key).ok())
-})
+v.pop()
+.and_then(|key| RsaSigningKey::new(&PrivateKey(key)).ok())
+});
}
"certificate" => {
-certificates = certs(&mut BufReader::new(value.as_bytes())).ok()
+certificates = certs(&mut BufReader::new(value.as_bytes())).ok();
}
_ => (), // Ignore extra fields
}
@ -141,8 +151,8 @@ impl<'de> Deserialize<'de> for Tls {
match (created_at, priv_key, certificates) {
(Some(created_at), Some(priv_key), Some(certificates)) => Ok(Tls {
created_at,
-priv_key: Arc::new(Box::new(priv_key)),
-certs: certificates,
+priv_key: Arc::new(priv_key),
+certs: certificates.into_iter().map(Certificate).collect(),
}),
_ => Err(serde::de::Error::custom("Could not deserialize tls info")),
}
@ -153,6 +163,7 @@ impl<'de> Deserialize<'de> for Tls {
}
}
+#[cfg(not(tarpaulin_include))]
impl std::fmt::Debug for Tls {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Tls")
@ -167,8 +178,13 @@ pub async fn update_server_state(
data: &mut Arc<RwLockServerState>,
) {
let req = Request::from_config_and_state(secret, cli);
-let client = reqwest::Client::new();
-let resp = client.post(CONTROL_CENTER_PING_URL).json(&req).send().await;
+debug!("Sending ping request: {:?}", req);
+let resp = HTTP_CLIENT
+.inner()
+.post(CONTROL_CENTER_PING_URL)
+.json(&req)
+.send()
+.await;
match resp {
Ok(resp) => match resp.json::<Response>().await {
Ok(Response::Ok(resp)) => {
@ -183,27 +199,13 @@ pub async fn update_server_state(
}
if let Some(key) = resp.token_key {
-if let Some(key) = base64::decode(&key)
+base64::decode(&key)
.ok()
.and_then(|k| PrecomputedKey::from_slice(&k))
-{
-write_guard.precomputed_key = key;
-} else {
-error!("Failed to parse token key: got {}", key);
-}
+.map_or_else(
+|| error!("Failed to parse token key: got {}", key),
+|key| write_guard.precomputed_key = key,
+);
}
-if !cli
-.unstable_options
-.contains(&UnstableOptions::DisableTokenValidation)
-&& VALIDATE_TOKENS.load(Ordering::Acquire) != resp.force_tokens
-{
-if resp.force_tokens {
-info!("Client received command to enforce token validity.");
-} else {
-info!("Client received command to no longer enforce token validity");
-}
-VALIDATE_TOKENS.store(resp.force_tokens, Ordering::Release);
-}
}
if let Some(tls) = resp.tls {
@ -211,8 +213,12 @@ pub async fn update_server_state(
.get()
.unwrap()
.swap(Arc::new(tls.created_at));
-TLS_SIGNING_KEY.get().unwrap().swap(tls.priv_key);
-TLS_CERTS.get().unwrap().swap(Arc::new(tls.certs));
+CERTIFIED_KEY.store(Some(Arc::new(CertifiedKey {
+cert: tls.certs.clone(),
+key: Arc::clone(&tls.priv_key) as Arc<dyn SigningKey>,
+ocsp: None,
+sct_list: None,
+})));
}
let previously_compromised = PREVIOUSLY_COMPROMISED.load(Ordering::Acquire);
@ -251,7 +257,7 @@ pub async fn update_server_state(
},
Err(e) => match e {
e if e.is_timeout() => {
-error!("Response timed out to control server. Is MangaDex down?")
+error!("Response timed out to control server. Is MangaDex down?");
}
e => warn!("Failed to send request: {}", e),
},
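A small sketch (not part of the diff) of the port fallback used when building the ping Request above: Port::new rejects zero, so an external_address that carries port 0 falls back to the configured port. The test name and values are illustrative only.

#[test]
fn external_port_zero_falls_back_to_config_port() {
    // Hypothetical stand-ins for config.port and config.external_address.
    let config_port = Port::new(443).expect("non-zero literal");
    let external_address: Option<std::net::SocketAddr> = "203.0.113.7:0".parse().ok();

    // Mirrors from_config_and_state: advertise the external port only when it
    // is non-zero, otherwise fall back to the configured port.
    let advertised = external_address
        .and_then(|addr| Port::new(addr.port()))
        .unwrap_or(config_port);

    assert_eq!(advertised.get(), 443);
}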


@ -1,27 +1,24 @@
+use std::hint::unreachable_unchecked;
use std::sync::atomic::Ordering;
-use std::time::Duration;
+use actix_web::body::BoxBody;
use actix_web::error::ErrorNotFound;
-use actix_web::http::header::{
-ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH,
-CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS,
-};
-use actix_web::web::Path;
+use actix_web::http::header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE, LAST_MODIFIED};
+use actix_web::web::{Data, Path};
use actix_web::HttpResponseBuilder;
-use actix_web::{get, web::Data, HttpRequest, HttpResponse, Responder};
+use actix_web::{get, HttpRequest, HttpResponse, Responder};
use base64::DecodeError;
use bytes::Bytes;
use chrono::{DateTime, Utc};
-use futures::{Stream, TryStreamExt};
-use log::{debug, error, info, trace, warn};
-use once_cell::sync::Lazy;
+use futures::Stream;
use prometheus::{Encoder, TextEncoder};
-use reqwest::{Client, StatusCode};
use serde::Deserialize;
use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES};
use thiserror::Error;
+use tracing::{debug, error, info, trace};
use crate::cache::{Cache, CacheKey, ImageMetadata, UpstreamError};
+use crate::client::{FetchResult, DEFAULT_HEADERS, HTTP_CLIENT};
use crate::config::{OFFLINE_MODE, VALIDATE_TOKENS};
use crate::metrics::{
CACHE_HIT_COUNTER, CACHE_MISS_COUNTER, REQUESTS_DATA_COUNTER, REQUESTS_DATA_SAVER_COUNTER,
@ -31,21 +28,14 @@ use crate::state::RwLockServerState;
const BASE64_CONFIG: base64::Config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
-static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
-Client::builder()
-.pool_idle_timeout(Duration::from_secs(180))
-.https_only(true)
-.http2_prior_knowledge()
-.build()
-.expect("Client initialization to work")
-});
-enum ServerResponse {
+pub enum ServerResponse {
TokenValidationError(TokenValidationError),
HttpResponse(HttpResponse),
}
impl Responder for ServerResponse {
+type Body = BoxBody;
#[inline]
fn respond_to(self, req: &HttpRequest) -> HttpResponse {
match self {
@ -58,12 +48,12 @@ impl Responder for ServerResponse {
}
}
+#[allow(clippy::unused_async)]
#[get("/")]
async fn index() -> impl Responder {
HttpResponse::Ok().body(include_str!("index.html"))
}
-#[allow(clippy::future_not_send)]
#[get("/{token}/data/{chapter_hash}/{file_name}")]
async fn token_data(
state: Data<RwLockServerState>,
@ -80,7 +70,6 @@ async fn token_data(
fetch_image(state, cache, chapter_hash, file_name, false).await
}
-#[allow(clippy::future_not_send)]
#[get("/{token}/data-saver/{chapter_hash}/{file_name}")]
async fn token_data_saver(
state: Data<RwLockServerState>,
@ -116,42 +105,43 @@ pub async fn default(state: Data<RwLockServerState>, req: HttpRequest) -> impl R
info!("Got unknown path, just proxying: {}", path);
-let resp = match HTTP_CLIENT.get(path).send().await {
+let mut resp = match HTTP_CLIENT.inner().get(path).send().await {
Ok(resp) => resp,
Err(e) => {
error!("{}", e);
return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish());
}
};
-let content_type = resp.headers().get(CONTENT_TYPE);
+let content_type = resp.headers_mut().remove(CONTENT_TYPE);
let mut resp_builder = HttpResponseBuilder::new(resp.status());
+let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type {
-resp_builder.insert_header((CONTENT_TYPE, content_type));
+headers.insert(CONTENT_TYPE, content_type);
}
-push_headers(&mut resp_builder);
-ServerResponse::HttpResponse(resp_builder.body(resp.bytes().await.unwrap_or_default()))
+// push_headers(&mut resp_builder);
+let mut resp = resp_builder.body(resp.bytes().await.unwrap_or_default());
+*resp.headers_mut() = headers;
+ServerResponse::HttpResponse(resp)
}
-#[allow(clippy::future_not_send)]
-#[get("/metrics")]
+#[allow(clippy::unused_async)]
+#[get("/prometheus")]
pub async fn metrics() -> impl Responder {
let metric_families = prometheus::gather();
let mut buffer = Vec::new();
TextEncoder::new()
.encode(&metric_families, &mut buffer)
-.unwrap();
-String::from_utf8(buffer).unwrap()
+.expect("Should never have an io error writing to a vec");
+String::from_utf8(buffer).expect("Text encoder should render valid utf-8")
}
-#[derive(Error, Debug)]
-enum TokenValidationError {
+#[derive(Error, Debug, PartialEq, Eq)]
+pub enum TokenValidationError {
#[error("Failed to decode base64 token.")]
DecodeError(#[from] DecodeError),
#[error("Nonce was too short.")]
IncompleteNonce,
-#[error("Invalid nonce.")]
-InvalidNonce,
#[error("Decryption failed")]
DecryptionFailure,
#[error("The token format was invalid.")]
@ -163,9 +153,13 @@ enum TokenValidationError {
}
impl Responder for TokenValidationError {
+type Body = BoxBody;
#[inline]
fn respond_to(self, _: &HttpRequest) -> HttpResponse {
-push_headers(&mut HttpResponse::Forbidden()).finish()
+let mut resp = HttpResponse::Forbidden().finish();
+*resp.headers_mut() = DEFAULT_HEADERS.clone();
+resp
}
}
@ -187,7 +181,11 @@ fn validate_token(
let (nonce, encrypted) = data.split_at(NONCEBYTES);
-let nonce = Nonce::from_slice(nonce).ok_or(TokenValidationError::InvalidNonce)?;
+let nonce = match Nonce::from_slice(nonce) {
+Some(nonce) => nonce,
+// We split at NONCEBYTES, so this should never happen.
+None => unsafe { unreachable_unchecked() },
+};
let decrypted = open_precomputed(encrypted, &nonce, precomputed_key)
.map_err(|_| TokenValidationError::DecryptionFailure)?;
@ -207,18 +205,6 @@ fn validate_token(
Ok(())
}
-#[inline]
-fn push_headers(builder: &mut HttpResponseBuilder) -> &mut HttpResponseBuilder {
-builder
-.insert_header((X_CONTENT_TYPE_OPTIONS, "nosniff"))
-.insert_header((ACCESS_CONTROL_ALLOW_ORIGIN, "https://mangadex.org"))
-.insert_header((ACCESS_CONTROL_EXPOSE_HEADERS, "*"))
-.insert_header((CACHE_CONTROL, "public, max-age=1209600"))
-.insert_header(("Timing-Allow-Origin", "https://mangadex.org"));
-builder
-}
#[allow(clippy::future_not_send)]
async fn fetch_image(
state: Data<RwLockServerState>,
@ -237,7 +223,7 @@ async fn fetch_image(
Some(Err(_)) => {
return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish());
}
-_ => (),
+None => (),
}
CACHE_MISS_COUNTER.inc();
@ -249,112 +235,217 @@ async fn fetch_image(
);
}
-// It's important to not get a write lock before this request, else we're
-// holding the read lock until the await resolves.
-let resp = if is_data_saver {
-HTTP_CLIENT
-.get(format!(
-"{}/data-saver/{}/{}",
-state.0.read().image_server,
-&key.0,
-&key.1
-))
-.send()
+let url = if is_data_saver {
+format!(
+"{}/data-saver/{}/{}",
+state.0.read().image_server,
+&key.0,
+&key.1,
+)
} else {
-HTTP_CLIENT
-.get(format!(
-"{}/data/{}/{}",
-state.0.read().image_server,
-&key.0,
-&key.1
-))
-.send()
-}
-.await;
+format!("{}/data/{}/{}", state.0.read().image_server, &key.0, &key.1)
+};
-match resp {
-Ok(mut resp) => {
-let content_type = resp.headers().get(CONTENT_TYPE);
-let is_image = content_type
-.map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
-.unwrap_or_default();
-if resp.status() != StatusCode::OK || !is_image {
-warn!(
-"Got non-OK or non-image response code from upstream, proxying and not caching result.",
-);
-let mut resp_builder = HttpResponseBuilder::new(resp.status());
-if let Some(content_type) = content_type {
-resp_builder.insert_header((CONTENT_TYPE, content_type));
-}
-push_headers(&mut resp_builder);
-return ServerResponse::HttpResponse(
-resp_builder.body(resp.bytes().await.unwrap_or_default()),
-);
-}
-let (content_type, length, last_mod) = {
-let headers = resp.headers_mut();
-(
-headers.remove(CONTENT_TYPE),
-headers.remove(CONTENT_LENGTH),
-headers.remove(LAST_MODIFIED),
-)
-};
-let body = resp.bytes_stream().map_err(|e| e.into());
-debug!("Inserting into cache");
-let metadata = ImageMetadata::new(content_type, length, last_mod).unwrap();
-let stream = {
-match cache.put(key, Box::new(body), metadata).await {
-Ok(stream) => stream,
-Err(e) => {
-warn!("Failed to insert into cache: {}", e);
-return ServerResponse::HttpResponse(
-HttpResponse::InternalServerError().finish(),
-);
-}
-}
-};
-debug!("Done putting into cache");
-construct_response(stream, &metadata)
+match HTTP_CLIENT.fetch_and_cache(url, key, cache).await {
+FetchResult::ServiceUnavailable => {
+ServerResponse::HttpResponse(HttpResponse::ServiceUnavailable().finish())
}
-Err(e) => {
-error!("Failed to fetch image from server: {}", e);
-ServerResponse::HttpResponse(
-push_headers(&mut HttpResponse::ServiceUnavailable()).finish(),
-)
+FetchResult::InternalServerError => {
+ServerResponse::HttpResponse(HttpResponse::InternalServerError().finish())
}
+FetchResult::Data(status, headers, data) => {
+let mut resp = HttpResponseBuilder::new(status);
+let mut resp = resp.body(data);
+*resp.headers_mut() = headers;
+ServerResponse::HttpResponse(resp)
+}
+FetchResult::Processing => panic!("Race condition found with fetch result"),
}
}
-fn construct_response(
+#[inline]
+pub fn construct_response(
data: impl Stream<Item = Result<Bytes, UpstreamError>> + Unpin + 'static,
metadata: &ImageMetadata,
) -> ServerResponse {
trace!("Constructing response");
let mut resp = HttpResponse::Ok();
+let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = metadata.content_type {
-resp.append_header((CONTENT_TYPE, content_type.as_ref()));
+headers.insert(
+CONTENT_TYPE,
+HeaderValue::from_str(content_type.as_ref()).unwrap(),
+);
}
if let Some(content_length) = metadata.content_length {
-resp.append_header((CONTENT_LENGTH, content_length));
+headers.insert(CONTENT_LENGTH, HeaderValue::from(content_length));
}
if let Some(last_modified) = metadata.last_modified {
-resp.append_header((LAST_MODIFIED, last_modified.to_rfc2822()));
+headers.insert(
+LAST_MODIFIED,
+HeaderValue::from_str(&last_modified.to_rfc2822()).unwrap(),
+);
}
-ServerResponse::HttpResponse(push_headers(&mut resp).streaming(data))
+let mut ret = resp.streaming(data);
+*ret.headers_mut() = headers;
+ServerResponse::HttpResponse(ret)
}
#[cfg(test)]
mod token_validation {
use super::{BASE64_CONFIG, DecodeError, PrecomputedKey, TokenValidationError, Utc, validate_token};
use sodiumoxide::crypto::box_::precompute;
use sodiumoxide::crypto::box_::seal_precomputed;
use sodiumoxide::crypto::box_::{gen_keypair, gen_nonce, PRECOMPUTEDKEYBYTES};
#[test]
fn invalid_base64() {
let res = validate_token(
&PrecomputedKey::from_slice(&b"1".repeat(PRECOMPUTEDKEYBYTES))
.expect("valid test token"),
"a".to_string(),
"b",
);
assert_eq!(
res,
Err(TokenValidationError::DecodeError(
DecodeError::InvalidLength
))
);
}
#[test]
fn not_long_enough_for_nonce() {
let res = validate_token(
&PrecomputedKey::from_slice(&b"1".repeat(PRECOMPUTEDKEYBYTES))
.expect("valid test token"),
"aGVsbG8gaW50ZXJuZXR-Cg==".to_string(),
"b",
);
assert_eq!(res, Err(TokenValidationError::IncompleteNonce));
}
#[test]
fn invalid_precomputed_key() {
let precomputed_1 = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let precomputed_2 = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
// Seal with precomputed_2, open with precomputed_1
let data = seal_precomputed(b"hello world", &nonce, &precomputed_2);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed_1, data, "b");
assert_eq!(res, Err(TokenValidationError::DecryptionFailure));
}
#[test]
fn invalid_token_data() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let data = seal_precomputed(b"hello world", &nonce, &precomputed);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert_eq!(res, Err(TokenValidationError::InvalidToken));
}
#[test]
fn token_must_have_valid_expiration() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() - chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert_eq!(res, Err(TokenValidationError::TokenExpired));
}
#[test]
fn token_must_have_valid_chapter_hash() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() + chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "");
assert_eq!(res, Err(TokenValidationError::InvalidChapterHash));
}
#[test]
fn valid_token_returns_ok() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() + chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert!(res.is_ok());
}
}
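A hypothetical helper (not part of the diff) distilled from the pattern the tests above repeat: seal a JSON payload with the precomputed key, prepend the nonce, and encode it with the same url-safe, unpadded base64 config that validate_token expects.

fn make_token(key: &PrecomputedKey, chapter_hash: &str, expires: DateTime<Utc>) -> String {
    let nonce = gen_nonce();
    let payload = serde_json::json!({
        "expires": expires.to_rfc3339(),
        "hash": chapter_hash,
    })
    .to_string();
    // Token layout: nonce || ciphertext, url-safe base64 without padding.
    let sealed = seal_precomputed(payload.as_bytes(), &nonce, key);
    let data: Vec<u8> = nonce.as_ref().iter().copied().chain(sealed).collect();
    base64::encode_config(data, BASE64_CONFIG)
}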


@ -1,17 +1,19 @@
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
-use crate::config::{ClientSecret, Config, UnstableOptions, OFFLINE_MODE, VALIDATE_TOKENS};
+use crate::client::HTTP_CLIENT;
+use crate::config::{ClientSecret, Config, OFFLINE_MODE};
use crate::ping::{Request, Response, CONTROL_CENTER_PING_URL};
-use arc_swap::ArcSwap;
+use arc_swap::{ArcSwap, ArcSwapOption};
-use log::{error, info, warn};
use once_cell::sync::OnceCell;
use parking_lot::RwLock;
-use rustls::sign::{CertifiedKey, SigningKey};
-use rustls::{ClientHello, ResolvesServerCert};
+use rustls::server::{ClientHello, ResolvesServerCert};
+use rustls::sign::{CertifiedKey, RsaSigningKey, SigningKey};
use rustls::Certificate;
use sodiumoxide::crypto::box_::{PrecomputedKey, PRECOMPUTEDKEYBYTES};
use thiserror::Error;
+use tracing::{error, info, warn};
use url::Url;
pub struct ServerState {
@ -25,8 +27,10 @@ pub static PREVIOUSLY_PAUSED: AtomicBool = AtomicBool::new(false);
pub static PREVIOUSLY_COMPROMISED: AtomicBool = AtomicBool::new(false);
pub static TLS_PREVIOUSLY_CREATED: OnceCell<ArcSwap<String>> = OnceCell::new();
-pub static TLS_SIGNING_KEY: OnceCell<ArcSwap<Box<dyn SigningKey>>> = OnceCell::new();
-pub static TLS_CERTS: OnceCell<ArcSwap<Vec<Certificate>>> = OnceCell::new();
+static TLS_SIGNING_KEY: OnceCell<ArcSwap<RsaSigningKey>> = OnceCell::new();
+static TLS_CERTS: OnceCell<ArcSwap<Vec<Certificate>>> = OnceCell::new();
+pub static CERTIFIED_KEY: ArcSwapOption<CertifiedKey> = ArcSwapOption::const_empty();
#[derive(Error, Debug)]
pub enum ServerInitError {
@ -46,7 +50,8 @@ pub enum ServerInitError {
impl ServerState {
pub async fn init(secret: &ClientSecret, config: &Config) -> Result<Self, ServerInitError> {
-let resp = reqwest::Client::new()
+let resp = HTTP_CLIENT
+.inner()
.post(CONTROL_CENTER_PING_URL)
.json(&Request::from((secret, config)))
.send()
@ -59,15 +64,16 @@ impl ServerState {
.token_key
.ok_or(ServerInitError::MissingTokenKey)
.and_then(|key| {
-if let Some(key) = base64::decode(&key)
+base64::decode(&key)
.ok()
.and_then(|k| PrecomputedKey::from_slice(&k))
-{
-Ok(key)
-} else {
-error!("Failed to parse token key: got {}", key);
-Err(ServerInitError::KeyParseError(key))
-}
+.map_or_else(
+|| {
+error!("Failed to parse token key: got {}", key);
+Err(ServerInitError::KeyParseError(key))
+},
+Ok,
+)
})?;
PREVIOUSLY_COMPROMISED.store(resp.compromised, Ordering::Release);
@ -83,26 +89,19 @@ impl ServerState {
if let Some(ref override_url) = config.override_upstream {
resp.image_server = override_url.clone();
warn!("Upstream URL overridden to: {}", resp.image_server);
-} else {
}
info!("This client's URL has been set to {}", resp.url);
-if config
-.unstable_options
-.contains(&UnstableOptions::DisableTokenValidation)
-{
-warn!("Token validation is explicitly disabled!");
-} else {
-if resp.force_tokens {
-info!("This client will validate tokens.");
-} else {
-info!("This client will not validate tokens.");
-}
-VALIDATE_TOKENS.store(resp.force_tokens, Ordering::Release);
-}
let tls = resp.tls.unwrap();
+CERTIFIED_KEY.store(Some(Arc::new(CertifiedKey {
+cert: tls.certs.clone(),
+key: Arc::clone(&tls.priv_key) as Arc<dyn SigningKey>,
+ocsp: None,
+sct_list: None,
+})));
std::mem::drop(
TLS_PREVIOUSLY_CREATED.set(ArcSwap::from_pointee(tls.created_at)),
);
@ -144,9 +143,10 @@ impl ServerState {
pub fn init_offline() -> Self {
assert!(OFFLINE_MODE.load(Ordering::Acquire));
Self {
-precomputed_key: PrecomputedKey::from_slice(&[41; PRECOMPUTEDKEYBYTES]).unwrap(),
-image_server: Url::from_file_path("/dev/null").unwrap(),
-url: Url::from_str("http://localhost").unwrap(),
+precomputed_key: PrecomputedKey::from_slice(&[41; PRECOMPUTEDKEYBYTES])
+.expect("expect offline config to work"),
+image_server: Url::from_file_path("/dev/null").expect("expect offline config to work"),
+url: Url::from_str("http://localhost").expect("expect offline config to work"),
url_overridden: false,
}
}
@ -157,14 +157,9 @@ pub struct RwLockServerState(pub RwLock<ServerState>);
pub struct DynamicServerCert;
impl ResolvesServerCert for DynamicServerCert {
-fn resolve(&self, _: ClientHello) -> Option<CertifiedKey> {
+fn resolve(&self, _: ClientHello) -> Option<Arc<CertifiedKey>> {
// TODO: wait for actix-web to use a new version of rustls so we can
// remove cloning the certs all the time
-Some(CertifiedKey {
-cert: TLS_CERTS.get().unwrap().load().as_ref().clone(),
-key: TLS_SIGNING_KEY.get().unwrap().load_full(),
-ocsp: None,
-sct_list: None,
-})
+CERTIFIED_KEY.load_full()
}
}
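A sketch (not part of the diff) of how DynamicServerCert might be wired into a rustls server config under the rustls 0.20-style API implied by the imports above; the builder calls are an assumption about that API, not code from this repository.

use std::sync::Arc;

// Hypothetical wiring: every TLS handshake resolves its certificate through
// DynamicServerCert, which serves whatever CERTIFIED_KEY currently holds.
let tls_config = rustls::ServerConfig::builder()
    .with_safe_defaults()
    .with_no_client_auth()
    .with_cert_resolver(Arc::new(DynamicServerCert));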


@ -1,10 +1,13 @@
-use log::{info, warn};
+#![cfg(not(tarpaulin_include))]
use reqwest::StatusCode;
use serde::Serialize;
+use tracing::{info, warn};
+use crate::client::HTTP_CLIENT;
use crate::config::ClientSecret;
-const CONTROL_CENTER_STOP_URL: &str = "https://api.mangadex.network/ping";
+const CONTROL_CENTER_STOP_URL: &str = "https://api.mangadex.network/stop";
#[derive(Serialize)]
struct StopRequest<'a> {
@ -12,11 +15,10 @@ struct StopRequest<'a> {
}
pub async fn send_stop(secret: &ClientSecret) {
-let request = StopRequest { secret };
-let client = reqwest::Client::new();
-match client
+match HTTP_CLIENT
+.inner()
.post(CONTROL_CENTER_STOP_URL)
-.json(&request)
+.json(&StopRequest { secret })
.send()
.await
{
@ -30,3 +32,17 @@ pub async fn send_stop(secret: &ClientSecret) {
Err(e) => warn!("Got error while sending stop message: {}", e),
}
}
#[cfg(test)]
mod stop {
use super::CONTROL_CENTER_STOP_URL;
#[test]
fn stop_url_does_not_have_ping_in_url() {
// This looks like a dumb test, yes, but it ensures that clients don't
// get marked compromised because apparently just sending a json obj
// with just the secret is acceptable to the ping endpoint, which messes
// up non-trivial client configs.
assert!(!CONTROL_CENTER_STOP_URL.contains("ping"))
}
}


@ -5,13 +5,17 @@ use std::str::FromStr;
use serde::{Deserialize, Serialize};
/// Wrapper type for a port number.
-#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
+#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct Port(NonZeroU16);
impl Port {
pub const fn get(self) -> u16 {
self.0.get()
}
pub fn new(amt: u16) -> Option<Self> {
NonZeroU16::new(amt).map(Self)
}
}
impl Default for Port {
@ -34,9 +38,16 @@ impl Display for Port {
}
}
-#[derive(Copy, Clone, Serialize, Deserialize, Default, Debug, Hash, Eq, PartialEq)]
+#[derive(Copy, Clone, Deserialize, Default, Debug, Hash, Eq, PartialEq)]
pub struct Mebibytes(usize);
impl Mebibytes {
#[cfg(test)]
pub fn new(size: usize) -> Self {
Self(size)
}
}
impl FromStr for Mebibytes {
type Err = ParseIntError;
@ -45,7 +56,8 @@ impl FromStr for Mebibytes {
}
}
-pub struct Bytes(usize);
+#[derive(Serialize, Debug)]
+pub struct Bytes(pub usize);
impl Bytes {
pub const fn get(&self) -> usize {
@ -62,6 +74,13 @@ impl From<Mebibytes> for Bytes {
#[derive(Copy, Clone, Deserialize, Debug, Hash, Eq, PartialEq)]
pub struct KilobitsPerSecond(NonZeroU64);
impl KilobitsPerSecond {
#[cfg(test)]
pub fn new(size: u64) -> Option<Self> {
NonZeroU64::new(size).map(Self)
}
}
impl FromStr for KilobitsPerSecond {
type Err = ParseIntError;