Compare commits

...

129 Commits

Author SHA1 Message Date
Edward Shen 1152f775b9
lint tests 2022-03-26 16:30:50 -07:00
Edward Shen e404d144a2
clippy 2022-03-26 16:28:58 -07:00
Edward Shen bc6ed7d07e
Update actix 2022-03-26 16:21:27 -07:00
Edward Shen ff0944f58c
Update dependencies 2022-03-26 16:20:00 -07:00
Edward Shen 81604a7e94
Clippy 2022-01-02 13:25:00 -08:00
Edward Shen f6a9caf653
Update Cargo.lock 2022-01-02 13:18:08 -08:00
Edward Shen 5c6f02b9a5
Update minor version 2022-01-02 12:48:03 -08:00
Edward Shen 557f141ed2
Update lru 2022-01-02 12:42:22 -08:00
Edward Shen 55f6279dce
Update to actix beta 18 2022-01-02 12:34:00 -08:00
Edward Shen 0bf76eab6b
Update deps 2022-01-02 12:12:35 -08:00
Edward Shen a838a94ce9
Update to modern code 2022-01-02 12:12:23 -08:00
Edward Shen 63eba4dc37
add db listener test 2021-07-23 20:27:52 -04:00
Edward Shen 4544061845
clarify iso code impl 2021-07-23 17:41:17 -04:00
Edward Shen 42cbd81375
convert some unwraps to expects 2021-07-22 13:46:40 -04:00
Edward Shen 07bd39e69d
Add unit tests for token validation 2021-07-22 13:37:43 -04:00
Edward Shen 8f5799211c
Try reserve port during startup 2021-07-22 13:37:32 -04:00
Edward Shen 0300135a6a
Add exclusions for code coverage 2021-07-20 16:47:04 -04:00
Edward Shen 7bbbf44328
Hololive-ify sample config file 2021-07-18 23:57:38 -04:00
Edward Shen 2878ddf0dc
Update dependencies 2021-07-18 23:45:24 -04:00
Edward Shen e4af231829
Remove tracing-futures 2021-07-18 22:05:36 -04:00
Edward Shen b04fac9b01
Added docker-compose file 2021-07-18 21:36:54 -04:00
Edward Shen 5aa72e9821
Add parser directive to Dockerfile 2021-07-18 21:33:56 -04:00
Edward Shen 6d6bf7371b
Add dockerfile 2021-07-18 21:27:29 -04:00
Edward Shen 6b1c913b5d
Remove unneeded imports 2021-07-18 18:35:22 -04:00
Edward Shen acd37297fd
Remove legacy token validation field 2021-07-18 18:32:19 -04:00
Edward Shen bd306455bc
Use ReaderStream 2021-07-18 11:37:39 -04:00
Edward Shen afa2cf55fa
Fix some future not send lints 2021-07-17 13:32:43 -04:00
Edward Shen e95afd3e32
Bump to 0.5.3 2021-07-17 13:05:13 -04:00
Edward Shen fbcf9566c1
Update readme 2021-07-17 13:04:39 -04:00
Edward Shen d42b80d7e1
Remove encryption option warning 2021-07-17 12:52:41 -04:00
Edward Shen 931da0c3ff
Add redis support 2021-07-17 12:52:02 -04:00
Edward Shen 5fdcfa5071
Add newline to Cargo.toml 2021-07-16 21:15:47 -04:00
Edward Shen 93ff76aa89
create folder if not found 2021-07-16 20:03:59 -04:00
Edward Shen f8f4098fae
Finish mem tests 2021-07-16 16:40:07 -04:00
Edward Shen bfcf131b33
Add stop url test 2021-07-16 15:26:18 -04:00
Edward Shen 5da486d43d
Add put test for mem cache 2021-07-16 15:18:10 -04:00
Edward Shen 51546eb387
Add memory cache get tests 2021-07-16 14:06:28 -04:00
Edward Shen 712257429a
Remove Compression wrapper 2021-07-16 12:32:24 -04:00
Edward Shen 5e7a82a610
Add partial test for mem cache 2021-07-16 01:13:51 -04:00
Edward Shen afb02db7b7
Remove unnecessary async keyword 2021-07-16 01:13:31 -04:00
Edward Shen b41ae8cb79
Add sync restriction on CacheStream 2021-07-16 01:13:01 -04:00
Edward Shen 54c8fe1cb3
Turn DB messages into struct from tuple 2021-07-15 21:49:19 -04:00
Edward Shen 8556f37904
Extract internal cache listener to function 2021-07-15 21:49:19 -04:00
Edward Shen fc930285f0
Bump to 0.5.2 2021-07-15 19:14:54 -04:00
Edward Shen 041760f9e9
clippy 2021-07-15 19:13:31 -04:00
Edward Shen 87271c85a7
Fix deleting legacy names 2021-07-15 19:03:39 -04:00
Edward Shen 3e4260f6e1
rename /metric endpoint to /prometheus 2021-07-15 15:54:26 -04:00
Edward Shen dc99437aec
Fix legacy path lookup 2021-07-15 13:47:55 -04:00
Edward Shen 3dbf2f8bb0
Read legacy path on disk 2021-07-15 13:25:46 -04:00
Edward Shen f7b037e8e1
Update readme 2021-07-15 12:58:40 -04:00
Edward Shen a552523a3a
increment version to 0.5.1 2021-07-15 12:39:19 -04:00
Edward Shen 5935e4220b
Fix stop url 2021-07-15 12:37:55 -04:00
Edward Shen fa9ab93c77
Add proxy support 2021-07-15 12:29:55 -04:00
Edward Shen 833a0c0468
Add automatic migration to new db location 2021-07-15 11:17:54 -04:00
Edward Shen b71253d8dc
Remove debug statement 2021-07-15 10:48:25 -04:00
Edward Shen 84941e2cb4
Fix sending bytes instead of mebibytes 2021-07-15 03:01:15 -04:00
Edward Shen 940af6508c
Default metadata path is metadata.db now 2021-07-15 02:53:00 -04:00
Edward Shen d3434e8408
Fix includes for publishing 2021-07-15 02:45:05 -04:00
Edward Shen 3786827f20
Add sqlx json 2021-07-15 02:16:33 -04:00
Edward Shen 261427a735
Bump version to 0.5.0 2021-07-15 02:14:41 -04:00
Edward Shen c4fa53fa40
Add geo ip logging support 2021-07-15 02:14:04 -04:00
Edward Shen 1c00c993bf
Rename metrics to conventions 2021-07-15 02:13:31 -04:00
Edward Shen b2650da556
Documented more CLI options 2021-07-15 02:12:20 -04:00
Edward Shen 6b3c6ce03a
Added geo ip dependencies 2021-07-15 02:10:30 -04:00
Edward Shen 032db4e1dd
Add special thanks to readme 2021-07-15 02:09:57 -04:00
Edward Shen 355fd936ab
Add debug message to outgoing ping request 2021-07-15 01:19:21 -04:00
Edward Shen 6415c3dee6
Respect external ip config 2021-07-15 00:54:31 -04:00
Edward Shen acf6dc1cb1
Add geoip config reading 2021-07-14 22:32:05 -04:00
Edward Shen d4d22ec674
Reduce tokio features 2021-07-14 21:56:46 -04:00
Edward Shen 353ee72713
Add unit tests 2021-07-14 21:56:29 -04:00
Edward Shen b1797dafd2
Fix double write bug 2021-07-14 19:11:46 -04:00
Edward Shen 9209b822a9
Simplify DiskWriter poll_flush 2021-07-14 14:20:31 -04:00
Edward Shen 5338ff81a5
Move impl block to correct location 2021-07-14 14:00:02 -04:00
Edward Shen 53015e116f
Renamed EncryptedDiskReader to EncryptedReader 2021-07-14 13:32:26 -04:00
Edward Shen 7ce974b4f9
Make EncryptedDiskReader generic over R 2021-07-14 13:32:00 -04:00
Edward Shen 973ece3604
MetadataFuture tests, fix UB 2021-07-14 13:28:09 -04:00
Edward Shen 0c78b379f1
Remove comments w.r.t. potential db optimizations
DB queries already are performed on their own thread, so spawning a
thread to do db work is unncessary.
2021-07-14 11:21:28 -04:00
Edward Shen 94375b185f
More debugging 2021-07-13 23:12:29 -04:00
Edward Shen 6ac8582183
Encrypted files work in debug mode 2021-07-13 20:38:01 -04:00
Edward Shen 656543b539
more partial work into encryption 2021-07-13 16:39:32 -04:00
Edward Shen 2ace8d3d66
Partial rewrite of encrypted writer 2021-07-13 13:16:44 -04:00
Edward Shen 160f369a72
Migrate to tracing crate 2021-07-12 23:23:51 -04:00
Edward Shen f8ee49ffd7
Add extended options for sample config 2021-07-12 22:35:06 -04:00
Edward Shen 9f76a7a1b3
Add compression middleware 2021-07-12 16:39:06 -04:00
Edward Shen 2f271a220a
remove tarpaulin-report 2021-07-12 16:35:17 -04:00
Edward Shen 868278991a
Add disk tests 2021-07-12 15:59:52 -04:00
Edward Shen e8bea39100
Remove serialize impl from legacy structs 2021-07-12 13:43:47 -04:00
Edward Shen 20e349fd79
Add sqlx envs 2021-07-12 01:44:01 -04:00
Edward Shen 80eeacd884
try fix ci 2021-07-12 01:35:32 -04:00
Edward Shen 580d05d00c
add sqlx check 2021-07-12 01:34:37 -04:00
Edward Shen acd8f234ab
ignore cfg tarpaulin for now 2021-07-12 01:08:19 -04:00
Edward Shen 4c135cd72d
Add coverage action 2021-07-12 01:05:18 -04:00
Edward Shen acc3ab2186
use hex instead of u8s in test 2021-07-12 00:49:06 -04:00
Edward Shen 8daa6bdc27
add tests for Md5Hash conversions 2021-07-12 00:48:00 -04:00
Edward Shen 69587b9ade
Clippy lints 2021-07-12 00:12:15 -04:00
Edward Shen 8f3430fb77
Add support for reading old db image ids 2021-07-11 23:33:22 -04:00
Edward Shen ec9473fa78
clippy lint 2021-07-11 23:25:17 -04:00
Edward Shen 3764af0ed5
Optimize header creation 2021-07-11 14:23:15 -04:00
Edward Shen 099c795cca
Add potential perf gains in the future 2021-07-11 14:22:59 -04:00
Edward Shen 92a66e60dc
Seek file from beginning on encrypted header 2021-07-11 13:25:02 -04:00
Edward Shen 5143dff888
Add logging 2021-07-11 13:21:57 -04:00
Edward Shen 7546948196
nightly clippy lints 2021-07-11 13:19:37 -04:00
Edward Shen 5f4be9809a
Add support for legacy files 2021-07-11 02:33:51 -04:00
Edward Shen 8040c49a1e
Add warning for encryption 2021-07-11 00:15:43 -04:00
Edward Shen 9871fc3774
bump to 0.4 2021-07-10 19:07:55 -04:00
Edward Shen f64d03493e
use static default headers 2021-07-10 19:04:27 -04:00
Edward Shen 93249397f1
Simply codebase 2021-07-10 18:53:28 -04:00
Edward Shen 154679967b
writing optimizations 2021-07-10 14:22:29 -04:00
Edward Shen b90edd72a6
Add clippy to CI 2021-07-09 21:25:08 -04:00
Edward Shen 3ec4d1c125
remove anchor from github workflows 2021-07-09 20:59:23 -04:00
Edward Shen 52ca595029
fix workflow ignores 2021-07-09 20:58:07 -04:00
Edward Shen e0bd29751a
Update paths to ignore 2021-07-09 20:56:35 -04:00
Edward Shen 7bd9189ebd
Update readme 2021-07-09 20:51:45 -04:00
Edward Shen c98a6d59f4
Have build script create .env if not found 2021-07-09 20:43:33 -04:00
Edward Shen 60bec5592a
Have build script modify .env file 2021-07-09 20:37:21 -04:00
Edward Shen de5816a44a
add env value to CI 2021-07-09 20:22:58 -04:00
Edward Shen 6fa301a4a5
rename workflow 2021-07-09 20:12:58 -04:00
Edward Shen 7b5738da8d
add standard CI 2021-07-09 20:11:43 -04:00
Edward Shen 5ab04a9e9c
add security CI 2021-07-09 20:02:22 -04:00
Edward Shen e65a7ba9ef
Update dependencies 2021-07-09 19:53:58 -04:00
Edward Shen 3a855a7e4a
add debug for client config 2021-07-09 19:51:48 -04:00
Edward Shen 5300afa205
Clean up ClientSecret usage 2021-07-09 19:48:25 -04:00
Edward Shen c5383639f5
ignore default config value 2021-07-09 19:18:09 -04:00
Edward Shen 88561f7c2c
Fix short arg conflict 2021-07-09 19:17:56 -04:00
Edward Shen ce03ce0baf
add support for yaml files 2021-07-09 19:14:53 -04:00
Edward Shen e78315025d
Use build script 2021-07-09 17:32:00 -04:00
Edward Shen a8e5d09ff0
clippy lints 2021-07-09 17:20:15 -04:00
Edward Shen b4f27c5f8c
migrate config to Config struct 2021-07-09 17:18:43 -04:00
Edward Shen 5b67431778
initial work 2021-07-09 14:36:04 -04:00
31 changed files with 4681 additions and 1677 deletions

10
.dockerignore Normal file
View File

@ -0,0 +1,10 @@
# Ignore everything
*
# Only include necessary paths (This should be synchronized with `Cargo.toml`)
!db_queries/
!src/
!settings.sample.yaml
!sqlx-data.json
!Cargo.toml
!Cargo.lock

53
.github/workflows/build_and_test.yml vendored Normal file
View File

@ -0,0 +1,53 @@
name: Build and test
on:
push:
branches: [ master ]
paths-ignore:
- "docs/**"
- "settings.sample.yaml"
- "README.md"
- "LICENSE"
pull_request:
branches: [ master ]
paths-ignore:
- "docs/**"
- "settings.sample.yaml"
- "README.md"
- "LICENSE"
env:
CARGO_TERM_COLOR: always
DATABASE_URL: sqlite:./cache/metadata.sqlite
SQLX_OFFLINE: true
jobs:
clippy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- run: rustup component add clippy
- uses: actions-rs/clippy-check@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
args: --all-features
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose
sqlx-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Install sqlx-cli
run: cargo install sqlx-cli
- name: Initialize database
run: mkdir -p cache && sqlite3 cache/metadata.sqlite < db_queries/init.sql
- name: Check sqlx statements
run: cargo sqlx prepare --check

22
.github/workflows/coverage.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: coverage
on: [push]
jobs:
test:
name: coverage
runs-on: ubuntu-latest
container:
image: xd009642/tarpaulin:develop-nightly
options: --security-opt seccomp=unconfined
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Generate code coverage
run: |
cargo +nightly tarpaulin --verbose --all-features --workspace --timeout 120 --avoid-cfg-tarpaulin --out Xml
- name: Upload to codecov.io
uses: codecov/codecov-action@v1
with:
fail_ci_if_error: true

14
.github/workflows/security_audit.yml vendored Normal file
View File

@ -0,0 +1,14 @@
name: Security audit
on:
push:
paths:
- '**/Cargo.toml'
- '**/Cargo.lock'
jobs:
security_audit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions-rs/audit-check@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}

5
.gitignore vendored
View File

@ -3,4 +3,7 @@
/cache
flamegraph*.svg
perf.data*
dhat.out.*
dhat.out.*
settings.yaml
tarpaulin-report.html
GeoLite2-Country.mmdb

1583
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,49 +1,69 @@
[package]
name = "mangadex-home"
version = "0.3.0"
version = "0.5.4"
license = "GPL-3.0-or-later"
authors = ["Edward Shen <code@eddie.sh>"]
edition = "2018"
include = ["src/**/*", "db_queries", "LICENSE", "README.md"]
include = [
"src/**/*",
"db_queries",
"LICENSE",
"README.md",
"sqlx-data.json",
"settings.sample.yaml"
]
description = "A MangaDex@Home implementation in Rust."
repository = "https://github.com/edward-shen/mangadex-home-rs"
[profile.release]
lto = true
codegen-units = 1
# debug = 1
debug = 1
[dependencies]
actix-web = { version = "4.0.0-beta.4", features = [ "rustls" ] }
# Pin because we're using unstable versions
actix-web = { version = "4", features = [ "rustls" ] }
arc-swap = "1"
async-trait = "0.1"
base64 = "0.13"
bincode = "1"
bytes = "1"
bytes = { version = "1", features = [ "serde" ] }
chacha20 = "0.7"
chrono = { version = "0.4", features = [ "serde" ] }
clap = { version = "3.0.0-beta.2", features = [ "wrap_help" ] }
clap = { version = "3", features = [ "wrap_help", "derive", "cargo" ] }
ctrlc = "3"
dotenv = "0.15"
flate2 = { version = "1", features = [ "tokio" ] }
futures = "0.3"
once_cell = "1"
log = "0.4"
log = { version = "0.4", features = [ "serde" ] }
lfu_cache = "1"
lru = "0.6"
lru = "0.7"
maxminddb = "0.20"
md-5 = "0.9"
parking_lot = "0.11"
prometheus = { version = "0.12", features = [ "process" ] }
redis = "0.21"
reqwest = { version = "0.11", default_features = false, features = [ "json", "stream", "rustls-tls" ] }
rustls = "0.19"
rustls = "0.20"
rustls-pemfile = "0.2"
serde = "1"
serde_json = "1"
serde_repr = "0.1"
simple_logger = "1"
serde_yaml = "0.8"
sodiumoxide = "0.2"
sqlx = { version = "0.5", features = [ "runtime-actix-rustls", "sqlite", "time", "chrono", "macros" ] }
sqlx = { version = "0.5", features = [ "runtime-actix-rustls", "sqlite", "time", "chrono", "macros", "offline" ] }
tar = "0.4"
thiserror = "1"
tokio = { version = "1", features = [ "full", "parking_lot" ] }
tokio = { version = "1", features = [ "rt-multi-thread", "macros", "fs", "time", "sync", "parking_lot" ] }
tokio-stream = { version = "0.1", features = [ "sync" ] }
tokio-util = { version = "0.6", features = [ "codec" ] }
tracing = "0.1"
tracing-subscriber = { version = "0.2", features = [ "parking_lot" ] }
url = { version = "2", features = [ "serde" ] }
[build-dependencies]
vergen = "5"
vergen = "5"
[dev-dependencies]
tempfile = "3"

11
Dockerfile Normal file
View File

@ -0,0 +1,11 @@
# syntax=docker/dockerfile:1
FROM rust:alpine as builder
COPY . .
RUN apk add --no-cache file make musl-dev \
&& cargo install --path . \
&& strip /usr/local/cargo/bin/mangadex-home
FROM alpine:latest
COPY --from=builder /usr/local/cargo/bin/mangadex-home /usr/local/bin/mangadex-home
CMD ["mangadex-home"]

107
README.md
View File

@ -2,75 +2,86 @@ A Rust implementation of a MangaDex@Home client.
This client contains the following features:
- Multi-threaded
- HTTP/2 support
- No support for TLS 1.1 or 1.0
- Easy migration from the official client
- Fully compliant with MangaDex@Home specifications
- Multi-threaded, high performance, and low overhead client
- HTTP/2 support for API users, HTTP/2 only for upstream connections
- Secure and privacy oriented features:
- Only supports TLS 1.2 or newer; HTTP is not enabled by default
- Options for no logging and no metrics
- Support for on-disk XChaCha20 encryption with ephemeral key generation
- Supports an internal LFU, LRU, or a redis instance for in-memory caching
## Building
Since we use SQLx there are a few things you'll need to do. First, you'll need
to run the init cache script, which initializes the db cache at
`./cache/metadata.sqlite`. Then you'll need to add the location of that to a
`.env` file:
```sh
# In the project root
./init_cache.sh
echo "DATABASE_URL=sqlite:./cache/metadata.sqlite" >> .env
cargo build
cargo test
```
## Cache implementation
You may need to set a client secret, see Configuration for more information.
This client implements a multi-tier in-memory and on-disk LRU cache constrained
by quotas. In essence, it acts as an unified LRU, where in-memory items are
evicted and pushed into the on-disk LRU and fetching a item from the on-disk LRU
promotes it to the in-memory LRU.
# Migration
Note that the capacity of each LRU is dynamic, depending on the maximum byte
capacity that you permit each cache to be. A large item may evict multiple
smaller items to fit within this constraint, for example.
Migration from the official client was made to be as painless as possible. There
are caveats though:
- If you ever want to return to using the official client, you will need to
clear your cache.
- As this is an unofficial client implementation, the only support you can
probably get is from me.
Note that these quotas are closer to a rough estimate, and is not guaranteed to
be strictly below these values, so it's recommended to under set your config
values to make sure you don't exceed the actual quota.
Otherwise, the steps to migration is easy:
1. Place the binary in the same folder as your `images` folder and
`settings.yaml`.
2. Rename `images` to `cache`.
# Client implementation
This client follows a secure-first approach. As such, your statistics may report
a _ever-so-slightly_ higher-than-average failure rate. Specifically, this client
choses to:
- Not support TLS 1.1 or 1.0, which would be a primary source of
incompatibility.
- Not provide a server identification string in the header of served requests.
- HTTPS by enabled by default, HTTP is provided (and unsupported).
That being said, this client should be backwards compatibility with the official
client data and config. That means you should be able to replace the binary and
preserve all your settings and cache.
## Installation
Either build it from source or run `cargo install mangadex-home`.
## Running
Run `mangadex-home`, and make sure the advertised port is open on your firewall.
Do note that some configuration fields are required. See the next section for
details.
## Configuration
Most configuration options can be either provided on the command line, sourced
from a `.env` file, or sourced directly from the environment. Do not that the
client secret is an exception. You must provide the client secret from the
environment or from the `.env` file, as providing client secrets in a shell is a
operation security risk.
Most configuration options can be either provided on the command line or sourced
from a file named `settings.yaml` from the directory you ran the command from,
which will be created on first run.
The following options are required:
Note that the client secret (`CLIENT_SECRET`) is the only configuration option
that can only can be provided from the environment, an `.env` file, or the
`settings.yaml` file. In other words, you _cannot_ provide this value from the
command line.
- Client Secret
- Memory cache quota
- Disk cache quota
- Advertised network speed
## Special thanks
The following are optional as a default value will be set for you:
This project could not have been completed without the assistance of the
following:
- Port
- Disk cache path
#### Development Assistance (Alphabetical Order)
### Advanced configuration
- carbotaniuman#6974
- LFlair#1337
- Plykiya#1738
- Tristan 9#6752
- The Rust Discord community
This implementation prefers to act more secure by default. As a result, some
features that the official specification requires are not enabled by default.
If you don't know why these features are disabled by default, then don't enable
these, as they may generally weaken the security stance of the client for more
compatibility.
#### Beta testers
- Sending Server version string
- NigelVH#7162
---
If using the geo IP logging feature, then this product includes GeoLite2 data
created by MaxMind, available from https://www.maxmind.com.

View File

@ -3,8 +3,10 @@ use std::error::Error;
use vergen::{vergen, Config, ShaKind};
fn main() -> Result<(), Box<dyn Error>> {
// Initialize vergen stuff
let mut config = Config::default();
*config.git_mut().sha_kind_mut() = ShaKind::Short;
vergen(config)?;
Ok(())
}

View File

@ -0,0 +1 @@
insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing

9
docker-compose.yml Normal file
View File

@ -0,0 +1,9 @@
version: "3.9"
services:
mangadex-home:
build: .
ports:
- "443:443"
volumes:
- ./cache:/cache
- ./settings.yaml:/settings.yaml

14
docs/ciphers.md Normal file
View File

@ -0,0 +1,14 @@
# Ciphers
This client relies on rustls, which only supports a subset of TLS ciphers.
Specifically, only TLS 1.2 ECDSA GCM ciphers as well as all TLS 1.3 ciphers are
supported. This means that clients that only support older, more insecure
ciphers may not be able to connect to this client.
In practice, this means this client's failure rate may be higher than expected.
This is okay, and within specifications.
## Why even bother?
Well, Australia has officially banned hentai... so I gotta make sure my mates
over there won't get in trouble if I'm connecting to them.

14
docs/unstable_options.md Normal file
View File

@ -0,0 +1,14 @@
# Unstable Options
Unstable options are options that are either experimental, dangerous, for
development only, or a mix of the three. The following table describes each
option. Generally speaking, you should never need to enable these unless you
know what you're doing.
| Option | Experimental? | Dangerous? | For development? |
| -------------------------- | ------------- | ---------- | ---------------- |
| `override-upstream` | | | Yes |
| `use-lfu` | Yes | | |
| `disable-token-validation` | | Yes | Yes |
| `offline-mode` | | | Yes |
| `disable-tls` | | Yes | Yes |

View File

@ -1,10 +0,0 @@
#!/usr/bin/env bash
# This script needs to be run once in order for compile time macros to not
# complain about a missing DB
# We can trust that our program will initialize the db at runtime the same way
# as it pulls from the same file for initialization
mkdir cache
sqlite3 cache/metadata.sqlite < db_queries/init.sql

114
settings.sample.yaml Normal file
View File

@ -0,0 +1,114 @@
---
# ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⢼⣈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⠀⠀⠀⠀⠀⣀⣀⣀⣀⡀⠀⠀⠀⢸⡧⣀⠀⣄⡠⠒⠉⠀⠀⠀⢀⡈⢑⢦⠀⠀⠀⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⣴⡗⠒⠉⠉⠉⠉⠉⠀⠀⠀⠀⠈⠉⠉⠉⠑⠣⡀⠉⠚⡷⡆⠀⣀⣀⣀⠤⡺⠈⠫⠗⠢⡄⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠔⠊⠁⢰⠃⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢰⠢⣧⠘⣎⠐⠠⠟⠋⠀⠀⠀⠀⢄⢸⠀⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠴⠋⠀⠀⠀⠀⡜⠀⢱⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⡆⡏⠀⠸⡄⢀⠀⠀⠪⠴⠒⠒⠚⠘⢤⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠔⠁⠀⠀⠀⠀⠀⠀⡇⠀⠀⢣⡀⢀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⢱⠃⠀⠀⡇⡇⠉⠖⡤⠤⠤⠤⠴⢒⠺⠀⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠋⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠱⡼⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠘⣆⡎⠀⠀⢀⡇⠙⠦⣌⣹⣏⠉⠉⠁⣀⠠⣂⠀⠀⠀
#⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡴⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢧⠀⠀⠀⣇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⢻⣿⣿⣿⡟⠓⠤⣀⡀⠉⠓⠭⠭⠔⠒⠉⠈⡆⠀⠀
#⠀⠀⠀⠀⠀⠀⢀⣀⣀⣀⣀⣀⡸⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢣⣠⠴⢻⠀⠀⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣧⢸⡈⠉⠉⢣⠀⠀⠀⠉⠑⠢⢄⣀⣀⡤⠖⠋⠳⡀⠀
#⠀⠀⠀⠀⠀⡔⠁⠀⢇⠀⠀⠀⠈⠳⡀⠀⠀⢀⠂⠀⠀⠀⠀⢀⠃⢠⠏⠀⠀⠈⡆⢠⢷⠀⠀⠀⠀⠀⠀⠀⠀⢰⠀⣿⡎⡇⠀⡠⠈⡇⠀⠀⠀⠀⠀⠀⠈⣦⠃⠀⠀⠈⢦⠀
#⠀⠀⠀⠀⡰⠃⠀⠀⠘⡄⠀⠀⠀⠀⢱⠀⠀⡜⠀⠀⠀⠀⠀⢸⠐⡮⢀⠀⠀⠀⢱⡸⡄⢧⠀⠀⠀⡀⠀⠀⠀⢸⣇⡿⢳⡧⠊⠀⠀⡇⡇⠀⠆⠀⠀⠀⠀⢱⡀⡠⠂⠀⠈⡇
#⠀⠀⠀⢰⠁⠀⠀⡀⠀⠘⡄⠀⠀⠀⢸⠀⢠⠃⠀⠀⢀⠀⠀⣼⢠⠃⠀⠁⠢⢄⠀⠳⣇⠀⠣⡀⠀⢣⡀⠀⠀⢸⢹⡧⢺⠀⠀⠀⠀⡷⢹⠀⢠⠀⠀⠀⠀⠈⡏⠳⡄⠀⠀⢳
#⠀⠀⠀⢸⠀⠀⠀⠈⠢⢄⠈⣢⠔⠒⠙⠒⢼⠀⠀⠀⢸⠀⢀⠿⣸⠀⠀⠀⠀⠀⠉⠢⢌⠀⠀⠈⠉⠒⠯⠉⠒⠚⠚⣠⠾⠶⠿⠷⢶⣧⡈⢆⢸⠀⠀⠀⠀⠀⢣⠀⢱⠀⠀⡎
#⠀⠀⠀⢸⠀⠀⠸⠀⠀⠀⢹⠁⠀⠀⠀⡀⡞⠀⠀⠀⢸⠀⢸⠀⢿⠀⣠⣴⠾⠛⠛⠓⠢⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠁⡀⠘⣿⣶⣄⠈⢻⣆⢻⠀⡆⠀⠀⠀⢸⠠⣸⠂⢠⠃
#⠀⠀⠀⠘⡄⠀⠀⢡⠀⠀⡼⠀⡠⠐⠁⠀⡇⠀⠀⠀⠈⡆⢸⠀⢨⡾⠋⠀⠀⢻⣿⣿⣷⣦⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⡿⢻⣿⣿⠻⣇⠀⢻⣾⠀⡇⠀⠀⠀⠈⡞⠁⣠⠋⠀
#⠀⠀⠀⠀⠱⡀⠀⠀⠑⢄⡑⣅⠀⠀⠀⠀⡇⠀⠀⠀⠀⠘⣼⠀⣿⠁⠀⢠⡷⢾⣿⣿⡟⠛⡇⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠈⠛⠁⠀⢸⠀⠈⢹⡸⠀⠀⠀⠀⠀⡧⠚⡇⠀⠀
#⠀⠀⠀⠀⠀⠈⠢⢄⣀⠀⠈⠉⢑⣶⠴⠒⡇⠀⠀⠀⠀⠀⡟⠧⡇⠀⠀⠸⡁⠀⠙⠋⠀⠀⡞⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⠦⣤⣤⡴⠃⠀⠀⠼⠣⡀⠀⡇⠀⠀⡷⢄⣇⠀⠀
#⠀⠀⠀⠀⠀⠀⢀⠞⠀⡏⠉⢉⣵⣳⠀⠀⡇⠀⠀⠀⠀⠀⢱⠀⠁⠀⠀⠀⠑⠤⠤⡠⠤⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠠⠡⢁⠀⠀⢱⠀⡇⠀⢠⡇⠀⢻⡀⠀
#⠀⠀⠀⠀⠀⢠⠎⠀⢸⣇⡔⠉⠀⢹⡀⠀⡇⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠐⡀⢀⠀⠄⡀⠀⠀⠀⠀⠀⠀⢀⣀⣀⣀⣀⣀⣀⣤⠈⠠⠡⠁⠂⠌⠀⢸⠀⡗⠀⢸⠇⠀⢀⡇⠀
#⠀⠀⠀⠀⢠⠃⠀⠀⡎⡏⠀⠀⠀⠀⡇⠀⡇⠀⡆⠀⠀⠀⠘⡄⠀⠀⠈⠌⠠⠂⠌⠐⠀⢀⠎⠉⠒⠉⠉⠉⠉⠙⠛⠧⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⡼⢀⠇⠀⢸⣀⠴⠋⢱⠀
#⠀⠀⠀⢠⠃⠀⠀⢰⠙⢣⡀⠀⠀⣇⡇⠀⢧⠀⡇⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⡿⠀⠀⠀⠀⠀⢀⡼⠃⡜⠀⠀⡏⢱⠀⠐⠈⡇
#⠀⠀⢠⢃⠀⠀⢠⠇⠀⢸⡉⠓⠶⢿⠃⠀⢸⠀⡇⠀⠀⠀⡄⢹⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡴⠁⠀⠀⠀⢀⣴⡟⠁⡰⠁⠀⢰⢧⠈⡆⠀⠇⢇
#⠀⠀⡜⡄⠀⢀⡎⠀⠀⠀⡇⠀⠀⢸⠀⠀⠈⡇⣿⠀⠀⠀⢧⠈⡗⠢⢤⣀⠀⠀⠀⠀⠀⠀⠙⢄⡀⠀⠀⠀⠀⠀⣀⡤⠚⠁⢀⣠⡤⠞⠋⢀⠇⡴⠁⠀⠀⠾⣼⠀⢱⠀⢸⢸
#⠀⠀⡇⡇⠀⡜⠀⠀⠀⠀⡇⠀⠀⣾⠀⠀⠀⢹⡏⡆⠀⠀⢸⢆⠸⡄⠀⠀⠉⢑⣦⣤⡀⠀⠀⠀⠉⠑⠒⣒⣋⣉⣡⣤⠒⠊⠉⡇⠀⠀⠀⣾⣊⠀⠀⠀⠈⢠⢻⠀⢸⣀⣿⡜
#⠀⠀⣷⢇⢸⠁⠀⠀⠀⠀⡇⠀⢰⢹⠀⠀⠀⠀⢿⠹⡀⠀⠸⡀⠳⣵⡀⡠⠚⠉⠙⢿⣿⣷⣦⣀⠀⠀⠀⣱⣿⣿⠀⠈⠉⠲⣄⢧⣠⠒⢌⡇⡠⣃⣀⡠⠔⠁⠀⡇⢸⡟⢸⠇
#⠀⠀⢻⠘⣼⠀⠀⠀⠀⢰⠁⣠⠃⢸⠀⠀⠀⠀⠘⠀⠳⡀⠀⡇⠀⢀⠟⠦⠀⡀⠀⢸⣛⣻⣿⣿⣿⣶⣭⣿⣿⣻⡆⠀⠀⠀⠈⢦⠸⣽⢝⠿⡫⡁⢸⡇⠀⠀⠀⢣⠘⠁⠘⠀
#⠀⠀⠘⠆⠸⠄⠀⠀⢠⠏⡰⠁⠀⡞⠀⠀⠀⠀⠀⠀⠀⠙⢄⣸⣶⣷⣶⣶⣶⣤⣤⣼⣿⣽⣯⣿⣿⣿⣷⣾⣿⣿⣿⣾⣤⣴⣶⣾⣷⣇⠺⠤⠕⠈⢉⠇⠀⠀⠀⠘⡄
#
# MangaDex@Home configuration file
#
# Thanks for contributing to MangaDex@Home, friend!
# Beat up a pineapple, and don't forget your AsaCoco!
#
# Default values are commented out.
# The size in mebibytes of the cache You can use megabytes instead in a pinch,
# but just know the two are **NOT** the same.
max_cache_size_in_mebibytes: 0
server_settings:
# The client secret. Keep this secret at all costs :P
secret: suichan wa kyou mo kawaii!
# The port for the webserver to listen on. 443 is recommended for max appeal.
# port: 443
# This controls the value the server receives for your upload speed.
external_max_kilobits_per_second: 1
#
# Advanced settings
#
# The external hostname to listen on. Keep this at 0.0.0.0 unless you know
# what you're doing.
# hostname: 0.0.0.0
# The external port to broadcast to the backend. Keep this at 0 unless you
# know what you're doing. 0 means broadcast the same value as `port`.
# external_port: 0
# How long to wait at most for the graceful shutdown (Ctrl-C or SIGINT).
# graceful_shutdown_wait_seconds: 60
# The external ip to broadcast to the webserver. The default of null (~) means
# the backend will infer it from where it was sent from, which may fail in the
# presence of multiple IPs.
# external_ip: ~
# Settings for geo IP analytics
metric_settings:
# Whether to enable geo IP analytics
# enable_geoip: false
# geoip_license_key: none
# These settings are unique to the Rust client, and may be ignored or behave
# differently from the official client.
extended_options:
# Which cache type to use. By default, this is `on_disk`, but one can select
# `lfu`, `lru`, or `redis` to use a LFU, LRU, or redis instance in addition
# to the on-disk cache to improve lookup times. Generally speaking, using one
# is almost always better, but by how much depends on how much memory you let
# the node use, how large is your node, and which caching implementation you
# use.
# cache_type: on_disk
# The redis url to connect with. Does nothing if the cache type isn't`redis`.
# redis_url: "redis://127.0.0.1/"
# The amount of memory the client should use when using an in-memory cache.
# This does nothing if only the on-disk cache is used.
# memory_quota: 0
# Whether or not to expose a prometheus endpoint at /metrics. This is a
# completely open endpoint, so best practice is to make sure this is only
# readable from the internal network.
# enable_metrics: false
# If you'd like to specify a different path location for the cache, you can do
# so here.
# cache_path: "./cache"
# What logging level to use. Valid options are "error", "warn", "info",
# "debug", "trace", and "off", which disables logging.
# logging_level: info
# Enables disk encryption where the key is stored in memory. In other words,
# when the MD@H program is stopped, all cached files are irrecoverable.
# Practically speaking, this isn't all too useful (and definitely hurts
# performance), but for peace of mind, this may be useful.
# ephemeral_disk_encryption: false

75
sqlx-data.json Normal file
View File

@ -0,0 +1,75 @@
{
"db": "SQLite",
"24b536161a0ed44d0595052ad069c023631ffcdeadb15a01ee294717f87cdd42": {
"query": "update Images set accessed = ? where id = ?",
"describe": {
"columns": [],
"parameters": {
"Right": 2
},
"nullable": []
}
},
"2a8aa6dd2c59241a451cd73f23547d0e94930e35654692839b5d11bb8b87703e": {
"query": "insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing",
"describe": {
"columns": [],
"parameters": {
"Right": 3
},
"nullable": []
}
},
"311721fec7824c2fc3ecf53f714949a49245c11a6b622efdb04fdac24be41ba3": {
"query": "SELECT IFNULL(SUM(size), 0) AS size FROM Images",
"describe": {
"columns": [
{
"name": "size",
"ordinal": 0,
"type_info": "Int"
}
],
"parameters": {
"Right": 0
},
"nullable": [
true
]
}
},
"44234188e873a467ecf2c60dfb4731011e0b7afc4472339ed2ae33aee8b0c9dd": {
"query": "select id, size from Images order by accessed asc limit 1000",
"describe": {
"columns": [
{
"name": "id",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "size",
"ordinal": 1,
"type_info": "Int64"
}
],
"parameters": {
"Right": 0
},
"nullable": [
false,
false
]
}
},
"a60501a30fd75b2a2a59f089e850343af075436a5c543a267ecb4fa841593ce9": {
"query": "create table if not exists Images(\n id varchar primary key not null,\n size integer not null,\n accessed timestamp not null default CURRENT_TIMESTAMP\n);\ncreate index if not exists Images_accessed on Images(accessed);",
"describe": {
"columns": [],
"parameters": {
"Right": 0
},
"nullable": []
}
}
}

152
src/cache/compat.rs vendored Normal file
View File

@ -0,0 +1,152 @@
//! These structs have alternative deserialize and serializations
//! implementations to assist reading from the official client file format.
use std::str::FromStr;
use chrono::{DateTime, FixedOffset};
use serde::de::{Unexpected, Visitor};
use serde::{Deserialize, Serialize};
use super::ImageContentType;
#[derive(Copy, Clone, Deserialize)]
pub struct LegacyImageMetadata {
pub(crate) content_type: Option<LegacyImageContentType>,
pub(crate) size: Option<u32>,
pub(crate) last_modified: Option<LegacyDateTime>,
}
#[derive(Copy, Clone, Serialize)]
pub struct LegacyDateTime(pub DateTime<FixedOffset>);
impl<'de> Deserialize<'de> for LegacyDateTime {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct LegacyDateTimeVisitor;
impl<'de> Visitor<'de> for LegacyDateTimeVisitor {
type Value = LegacyDateTime;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a valid image type")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
DateTime::parse_from_rfc2822(v)
.map(LegacyDateTime)
.map_err(|_| E::invalid_value(Unexpected::Str(v), &"a valid image type"))
}
}
deserializer.deserialize_str(LegacyDateTimeVisitor)
}
}
#[derive(Copy, Clone)]
pub struct LegacyImageContentType(pub ImageContentType);
impl<'de> Deserialize<'de> for LegacyImageContentType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct LegacyImageContentTypeVisitor;
impl<'de> Visitor<'de> for LegacyImageContentTypeVisitor {
type Value = LegacyImageContentType;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a valid image type")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
ImageContentType::from_str(v)
.map(LegacyImageContentType)
.map_err(|_| E::invalid_value(Unexpected::Str(v), &"a valid image type"))
}
}
deserializer.deserialize_str(LegacyImageContentTypeVisitor)
}
}
#[cfg(test)]
mod parse {
use std::error::Error;
use chrono::DateTime;
use crate::cache::ImageContentType;
use super::LegacyImageMetadata;
#[test]
fn from_valid_legacy_format() -> Result<(), Box<dyn Error>> {
let legacy_header = r#"{"content_type":"image/jpeg","last_modified":"Sat, 10 Apr 2021 10:55:22 GMT","size":117888}"#;
let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
assert_eq!(
metadata.content_type.map(|v| v.0),
Some(ImageContentType::Jpeg)
);
assert_eq!(metadata.size, Some(117_888));
assert_eq!(
metadata.last_modified.map(|v| v.0),
Some(DateTime::parse_from_rfc2822(
"Sat, 10 Apr 2021 10:55:22 GMT"
)?)
);
Ok(())
}
#[test]
fn empty_metadata() -> Result<(), Box<dyn Error>> {
let legacy_header = "{}";
let metadata: LegacyImageMetadata = serde_json::from_str(legacy_header)?;
assert!(metadata.content_type.is_none());
assert!(metadata.size.is_none());
assert!(metadata.last_modified.is_none());
Ok(())
}
#[test]
fn invalid_image_mime_value() {
let legacy_header = r#"{"content_type":"image/not-a-real-image"}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn invalid_date_time() {
let legacy_header = r#"{"last_modified":"idk last tuesday?"}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn invalid_size() {
let legacy_header = r#"{"size":-1}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn wrong_image_type() {
let legacy_header = r#"{"content_type":25}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
#[test]
fn wrong_date_time_type() {
let legacy_header = r#"{"last_modified":false}"#;
assert!(serde_json::from_str::<LegacyImageMetadata>(legacy_header).is_err());
}
}

597
src/cache/disk.rs vendored
View File

@ -1,30 +1,39 @@
//! Low memory caching stuff
use std::path::PathBuf;
use std::convert::TryFrom;
use std::hint::unreachable_unchecked;
use std::os::unix::prelude::OsStrExt;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt;
use log::{error, warn, LevelFilter};
use log::LevelFilter;
use md5::digest::generic_array::GenericArray;
use md5::{Digest, Md5};
use sodiumoxide::hex;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, SqlitePool};
use tokio::fs::remove_file;
use sqlx::{ConnectOptions, Sqlite, SqlitePool, Transaction};
use tokio::fs::{create_dir_all, remove_file, rename, File};
use tokio::join;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, instrument, warn};
use super::{
BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
};
use crate::units::Bytes;
use super::{Cache, CacheEntry, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata};
#[derive(Debug)]
pub struct DiskCache {
disk_path: PathBuf,
disk_cur_size: AtomicU64,
db_update_channel_sender: Sender<DbMessage>,
}
#[derive(Debug)]
enum DbMessage {
Get(Arc<PathBuf>),
Put(Arc<PathBuf>, u64),
@ -34,29 +43,47 @@ impl DiskCache {
/// Constructs a new low memory cache at the provided path and capacity.
/// This internally spawns a task that will wait for filesystem
/// notifications when a file has been written.
pub async fn new(disk_max_size: u64, disk_path: PathBuf) -> Arc<Self> {
let (db_tx, db_rx) = channel(128);
pub async fn new(disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
if let Err(e) = create_dir_all(&disk_path).await {
error!("Failed to create cache folder: {}", e);
}
let cache_path = disk_path.to_string_lossy();
// Migrate old to new path
if rename(
format!("{}/metadata.sqlite", cache_path),
format!("{}/metadata.db", cache_path),
)
.await
.is_ok()
{
info!("Found old metadata file, migrating to new location.");
}
let db_pool = {
let db_url = format!("sqlite:{}/metadata.sqlite", disk_path.to_string_lossy());
let db_url = format!("sqlite:{}/metadata.db", cache_path);
let mut options = SqliteConnectOptions::from_str(&db_url)
.unwrap()
.create_if_missing(true);
options.log_statements(LevelFilter::Trace);
let db = SqlitePool::connect_with(options).await.unwrap();
// Run db init
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut db.acquire().await.unwrap())
.await
.unwrap();
db
SqlitePool::connect_with(options).await.unwrap()
};
Self::from_db_pool(db_pool, disk_max_size, disk_path).await
}
async fn from_db_pool(pool: SqlitePool, disk_max_size: Bytes, disk_path: PathBuf) -> Arc<Self> {
let (db_tx, db_rx) = channel(128);
// Run db init
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut pool.acquire().await.unwrap())
.await
.unwrap();
// This is intentional.
#[allow(clippy::cast_sign_loss)]
let disk_cur_size = {
let mut conn = db_pool.acquire().await.unwrap();
let mut conn = pool.acquire().await.unwrap();
sqlx::query!("SELECT IFNULL(SUM(size), 0) AS size FROM Images")
.fetch_one(&mut conn)
.await
@ -74,12 +101,25 @@ impl DiskCache {
tokio::spawn(db_listener(
Arc::clone(&new_self),
db_rx,
db_pool,
disk_max_size / 20 * 19,
pool,
disk_max_size.get() as u64 / 20 * 19,
));
new_self
}
#[cfg(test)]
fn in_memory() -> (Self, Receiver<DbMessage>) {
let (db_tx, db_rx) = channel(128);
(
Self {
disk_path: PathBuf::new(),
disk_cur_size: AtomicU64::new(0),
db_update_channel_sender: db_tx,
},
db_rx,
)
}
}
/// Spawn a new task that will listen for updates to the db, pruning if the size
@ -90,9 +130,10 @@ async fn db_listener(
db_pool: SqlitePool,
max_on_disk_size: u64,
) {
// This is in a receiver stream to process up to 128 simultaneous db updates
// in one transaction
let mut recv_stream = ReceiverStream::new(db_rx).ready_chunks(128);
while let Some(messages) = recv_stream.next().await {
let now = chrono::Utc::now();
let mut transaction = match db_pool.begin().await {
Ok(transaction) => transaction,
Err(e) => {
@ -100,38 +141,12 @@ async fn db_listener(
continue;
}
};
for message in messages {
match message {
DbMessage::Get(entry) => {
let key = entry.as_os_str().to_str();
let query =
sqlx::query!("update Images set accessed = ? where id = ?", now, key)
.execute(&mut transaction)
.await;
if let Err(e) = query {
warn!("Failed to update timestamp in db for {:?}: {}", key, e);
}
}
DbMessage::Get(entry) => handle_db_get(&entry, &mut transaction).await,
DbMessage::Put(entry, size) => {
let key = entry.as_os_str().to_str();
{
// This is intentional.
#[allow(clippy::cast_possible_wrap)]
let size = size as i64;
let query = sqlx::query!(
"insert into Images (id, size, accessed) values (?, ?, ?) on conflict do nothing",
key,
size,
now,
)
.execute(&mut transaction)
.await;
if let Err(e) = query {
warn!("Failed to add {:?} to db: {}", key, e);
}
}
cache.disk_cur_size.fetch_add(size, Ordering::Release);
handle_db_put(&entry, size, &cache, &mut transaction).await;
}
}
}
@ -145,21 +160,10 @@ async fn db_listener(
let on_disk_size = (cache.disk_cur_size.load(Ordering::Acquire) + 4095) / 4096 * 4096;
if on_disk_size >= max_on_disk_size {
let mut conn = match db_pool.acquire().await {
Ok(conn) => conn,
Err(e) => {
error!(
"Failed to get a DB connection and cannot prune disk cache: {}",
e
);
continue;
}
};
let items = {
let request =
sqlx::query!("select id, size from Images order by accessed asc limit 1000")
.fetch_all(&mut conn)
.fetch_all(&db_pool)
.await;
match request {
Ok(items) => items,
@ -176,8 +180,9 @@ async fn db_listener(
let mut size_freed = 0;
#[allow(clippy::cast_sign_loss)]
for item in items {
debug!("deleting file due to exceeding cache size");
size_freed += item.size as u64;
tokio::spawn(remove_file(item.id));
tokio::spawn(remove_file_handler(item.id));
}
cache.disk_cur_size.fetch_sub(size_freed, Ordering::Release);
@ -185,6 +190,126 @@ async fn db_listener(
}
}
/// Returns if a file was successfully deleted.
async fn remove_file_handler(key: String) -> bool {
let error = if let Err(e) = remove_file(&key).await {
e
} else {
return true;
};
if error.kind() != std::io::ErrorKind::NotFound {
warn!("Failed to delete file `{}` from cache: {}", &key, error);
return false;
}
if let Ok(bytes) = hex::decode(&key) {
if bytes.len() != 16 {
warn!("Failed to delete file `{}`; invalid hash size.", &key);
return false;
}
let hash = Md5Hash(*GenericArray::from_slice(&bytes));
let path: PathBuf = hash.into();
if let Err(e) = remove_file(&path).await {
warn!(
"Failed to delete file `{}` from cache: {}",
path.to_string_lossy(),
e
);
false
} else {
true
}
} else {
warn!("Failed to delete file `{}`; not a md5hash.", &key);
false
}
}
#[instrument(level = "debug", skip(transaction))]
async fn handle_db_get(entry: &Path, transaction: &mut Transaction<'_, Sqlite>) {
let key = entry.as_os_str().to_str();
let now = chrono::Utc::now();
let query = sqlx::query!("update Images set accessed = ? where id = ?", now, key)
.execute(transaction)
.await;
if let Err(e) = query {
warn!("Failed to update timestamp in db for {:?}: {}", key, e);
}
}
#[instrument(level = "debug", skip(transaction, cache))]
async fn handle_db_put(
entry: &Path,
size: u64,
cache: &DiskCache,
transaction: &mut Transaction<'_, Sqlite>,
) {
let key = entry.as_os_str().to_str();
let now = chrono::Utc::now();
// This is intentional.
#[allow(clippy::cast_possible_wrap)]
let casted_size = size as i64;
let query = sqlx::query_file!("./db_queries/insert_image.sql", key, casted_size, now)
.execute(transaction)
.await;
if let Err(e) = query {
warn!("Failed to add to db: {}", e);
}
cache.disk_cur_size.fetch_add(size, Ordering::Release);
}
/// Represents a Md5 hash that can be converted to and from a path. This is used
/// for compatibility with the official client, where the image id and on-disk
/// path is determined by file path.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Md5Hash(GenericArray<u8, <Md5 as md5::Digest>::OutputSize>);
impl Md5Hash {
fn to_hex_string(self) -> String {
format!("{:x}", self.0)
}
}
impl TryFrom<&Path> for Md5Hash {
type Error = ();
fn try_from(path: &Path) -> Result<Self, Self::Error> {
let mut iter = path.iter();
let file_name = iter.next_back().ok_or(())?;
let chapter_hash = iter.next_back().ok_or(())?;
let is_data_saver = iter.next_back().ok_or(())? == "saver";
let mut hasher = Md5::new();
if is_data_saver {
hasher.update("saver");
}
hasher.update(chapter_hash.as_bytes());
hasher.update(".");
hasher.update(file_name.as_bytes());
Ok(Self(hasher.finalize()))
}
}
impl From<Md5Hash> for PathBuf {
fn from(hash: Md5Hash) -> Self {
let hex_value = hash.to_hex_string();
let path = hex_value[0..3]
.chars()
.rev()
.map(|char| Self::from(char.to_string()))
.reduce(|first, second| first.join(second));
match path {
Some(p) => p.join(hex_value),
None => unsafe { unreachable_unchecked() }, // literally not possible
}
}
}
#[async_trait]
impl Cache for DiskCache {
async fn get(
@ -196,12 +321,33 @@ impl Cache for DiskCache {
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(key)));
let path_0 = Arc::clone(&path);
tokio::spawn(async move { channel.send(DbMessage::Get(path_0)).await });
let legacy_path = Md5Hash::try_from(path_0.as_path())
.map(PathBuf::from)
.map(|path| self.disk_path.clone().join(path))
.map(Arc::new);
super::fs::read_file(&path).await.map(|res| {
let (inner, maybe_header, metadata) = res?;
CacheStream::new(inner, maybe_header)
.map(|stream| (stream, metadata))
// Get file and path of first existing location path
let (file, path) = if let Ok(legacy_path) = legacy_path {
let maybe_files = join!(
File::open(legacy_path.as_path()),
File::open(path.as_path()),
);
match maybe_files {
(Ok(f), _) => Some((f, legacy_path)),
(_, Ok(f)) => Some((f, path)),
_ => return None,
}
} else {
File::open(path.as_path())
.await
.ok()
.map(|file| (file, path))
}?;
tokio::spawn(async move { channel.send(DbMessage::Get(path)).await });
super::fs::read_file(file).await.map(|res| {
res.map(|(stream, _, metadata)| (stream, metadata))
.map_err(|_| CacheError::DecryptionFailure)
})
}
@ -209,9 +355,9 @@ impl Cache for DiskCache {
async fn put(
&self,
key: CacheKey,
image: BoxedImageStream,
image: bytes::Bytes,
metadata: ImageMetadata,
) -> Result<CacheStream, CacheError> {
) -> Result<(), CacheError> {
let channel = self.db_update_channel_sender.clone();
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@ -224,9 +370,6 @@ impl Cache for DiskCache {
super::fs::write_file(&path, key, image, metadata, db_callback, None)
.await
.map_err(CacheError::from)
.and_then(|(inner, maybe_header)| {
CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
})
}
}
@ -235,10 +378,10 @@ impl CallbackCache for DiskCache {
async fn put_with_on_completed_callback(
&self,
key: CacheKey,
image: BoxedImageStream,
image: bytes::Bytes,
metadata: ImageMetadata,
on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
) -> Result<CacheStream, CacheError> {
on_complete: Sender<CacheEntry>,
) -> Result<(), CacheError> {
let channel = self.db_update_channel_sender.clone();
let path = Arc::new(self.disk_path.clone().join(PathBuf::from(&key)));
@ -252,8 +395,298 @@ impl CallbackCache for DiskCache {
super::fs::write_file(&path, key, image, metadata, db_callback, Some(on_complete))
.await
.map_err(CacheError::from)
.and_then(|(inner, maybe_header)| {
CacheStream::new(inner, maybe_header).map_err(|_| CacheError::DecryptionFailure)
})
}
}
#[cfg(test)]
mod db_listener {
use super::{db_listener, DbMessage};
use crate::DiskCache;
use futures::TryStreamExt;
use sqlx::{Row, SqlitePool};
use std::error::Error;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc::channel;
#[tokio::test]
async fn can_handle_multiple_events() -> Result<(), Box<dyn Error>> {
let (mut cache, rx) = DiskCache::in_memory();
let (mut tx, _) = channel(1);
// Swap the tx with the new one, else the receiver will never end
std::mem::swap(&mut cache.db_update_channel_sender, &mut tx);
assert_eq!(tx.capacity(), 128);
let cache = Arc::new(cache);
let db = SqlitePool::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&db)
.await?;
// Populate the queue with messages
for c in 'a'..='z' {
tx.send(DbMessage::Put(Arc::new(PathBuf::from(c.to_string())), 10))
.await?;
tx.send(DbMessage::Get(Arc::new(PathBuf::from(c.to_string()))))
.await?;
}
// Explicitly close the channel so that the listener terminates
std::mem::drop(tx);
db_listener(cache, rx, db.clone(), u64::MAX).await;
let count = Arc::new(AtomicUsize::new(0));
sqlx::query("select * from Images")
.fetch(&db)
.try_for_each_concurrent(None, |row| {
let count = Arc::clone(&count);
async move {
assert_eq!(row.get::<i32, _>("size"), 10);
count.fetch_add(1, Ordering::Release);
Ok(())
}
})
.await?;
assert_eq!(count.load(Ordering::Acquire), 26);
Ok(())
}
}
#[cfg(test)]
mod remove_file_handler {
use std::error::Error;
use tempfile::tempdir;
use tokio::fs::{create_dir_all, remove_dir_all};
use super::{remove_file_handler, File};
#[tokio::test]
async fn should_not_panic_on_invalid_path() {
assert!(!remove_file_handler("/this/is/a/non-existent/path/".to_string()).await);
}
#[tokio::test]
async fn should_not_panic_on_invalid_hash() {
assert!(!remove_file_handler("68b329da9893e34099c7d8ad5cb9c940".to_string()).await);
}
#[tokio::test]
async fn should_not_panic_on_malicious_hashes() {
assert!(!remove_file_handler("68b329da9893e34".to_string()).await);
assert!(
!remove_file_handler("68b329da9893e34099c7d8ad5cb9c940aaaaaaaaaaaaaaaaaa".to_string())
.await
);
}
#[tokio::test]
async fn should_delete_existing_file() -> Result<(), Box<dyn Error>> {
let temp_dir = tempdir()?;
let mut dir_path = temp_dir.path().to_path_buf();
dir_path.push("abc123.png");
// create a file, it can be empty
File::create(&dir_path).await?;
assert!(remove_file_handler(dir_path.to_string_lossy().into_owned()).await);
Ok(())
}
#[tokio::test]
async fn should_delete_existing_hash() -> Result<(), Box<dyn Error>> {
create_dir_all("b/8/6").await?;
File::create("b/8/6/68b329da9893e34099c7d8ad5cb9c900").await?;
assert!(remove_file_handler("68b329da9893e34099c7d8ad5cb9c900".to_string()).await);
remove_dir_all("b").await?;
Ok(())
}
}
#[cfg(test)]
mod disk_cache {
use std::error::Error;
use std::path::PathBuf;
use std::sync::atomic::Ordering;
use chrono::Utc;
use sqlx::SqlitePool;
use crate::units::Bytes;
use super::DiskCache;
#[tokio::test]
async fn db_is_initialized() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
let _cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
let res = sqlx::query("select * from Images").execute(&conn).await;
assert!(res.is_ok());
Ok(())
}
#[tokio::test]
async fn db_initializes_empty() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 0);
Ok(())
}
#[tokio::test]
async fn db_can_load_from_existing() -> Result<(), Box<dyn Error>> {
let conn = SqlitePool::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&conn)
.await?;
let now = Utc::now();
sqlx::query_file!("./db_queries/insert_image.sql", "a", 4, now)
.execute(&conn)
.await?;
let now = Utc::now();
sqlx::query_file!("./db_queries/insert_image.sql", "b", 15, now)
.execute(&conn)
.await?;
let cache = DiskCache::from_db_pool(conn.clone(), Bytes(1000), PathBuf::new()).await;
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 19);
Ok(())
}
}
#[cfg(test)]
mod db {
use chrono::{DateTime, Utc};
use sqlx::{Connection, Row, SqliteConnection};
use std::error::Error;
use super::{handle_db_get, handle_db_put, DiskCache, FromStr, Ordering, PathBuf, StreamExt};
#[tokio::test]
#[cfg_attr(miri, ignore)]
async fn get() -> Result<(), Box<dyn Error>> {
let (cache, _) = DiskCache::in_memory();
let path = PathBuf::from_str("a/b/c")?;
let mut conn = SqliteConnection::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut conn)
.await?;
// Add an entry
let mut transaction = conn.begin().await?;
handle_db_put(&path, 10, &cache, &mut transaction).await;
transaction.commit().await?;
let time_fence = Utc::now();
let mut transaction = conn.begin().await?;
handle_db_get(&path, &mut transaction).await;
transaction.commit().await?;
let mut rows: Vec<_> = sqlx::query("select * from Images")
.fetch(&mut conn)
.collect()
.await;
assert_eq!(rows.len(), 1);
let entry = rows.pop().unwrap()?;
assert!(time_fence < entry.get::<'_, DateTime<Utc>, _>("accessed"));
Ok(())
}
#[tokio::test]
#[cfg_attr(miri, ignore)]
async fn put() -> Result<(), Box<dyn Error>> {
let (cache, _) = DiskCache::in_memory();
let path = PathBuf::from_str("a/b/c")?;
let mut conn = SqliteConnection::connect("sqlite::memory:").await?;
sqlx::query_file!("./db_queries/init.sql")
.execute(&mut conn)
.await?;
let mut transaction = conn.begin().await?;
let transaction_time = Utc::now();
handle_db_put(&path, 10, &cache, &mut transaction).await;
transaction.commit().await?;
let mut rows: Vec<_> = sqlx::query("select * from Images")
.fetch(&mut conn)
.collect()
.await;
assert_eq!(rows.len(), 1);
let entry = rows.pop().unwrap()?;
assert_eq!(entry.get::<'_, &str, _>("id"), "a/b/c");
assert_eq!(entry.get::<'_, i64, _>("size"), 10);
let accessed: DateTime<Utc> = entry.get("accessed");
assert!(transaction_time < accessed);
assert!(accessed < Utc::now());
assert_eq!(cache.disk_cur_size.load(Ordering::SeqCst), 10);
Ok(())
}
}
#[cfg(test)]
mod md5_hash {
use super::{Digest, GenericArray, Md5, Md5Hash, Path, PathBuf, TryFrom};
#[test]
fn to_cache_path() {
let hash = Md5Hash(
*GenericArray::<_, <Md5 as md5::Digest>::OutputSize>::from_slice(&[
0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd,
0xef, 0xab,
]),
);
assert_eq!(
PathBuf::from(hash).to_str(),
Some("c/b/a/abcdefabcdefabcdefabcdefabcdefab")
)
}
#[test]
fn from_data_path() {
let mut expected_hasher = Md5::new();
expected_hasher.update("foo.bar.png");
assert_eq!(
Md5Hash::try_from(Path::new("data/foo/bar.png")),
Ok(Md5Hash(expected_hasher.finalize()))
);
}
#[test]
fn from_data_saver_path() {
let mut expected_hasher = Md5::new();
expected_hasher.update("saverfoo.bar.png");
assert_eq!(
Md5Hash::try_from(Path::new("saver/foo/bar.png")),
Ok(Md5Hash(expected_hasher.finalize()))
);
}
#[test]
fn can_handle_long_paths() {
assert_eq!(
Md5Hash::try_from(Path::new("a/b/c/d/e/f/g/saver/foo/bar.png")),
Md5Hash::try_from(Path::new("saver/foo/bar.png")),
);
}
#[test]
fn from_invalid_paths() {
assert!(Md5Hash::try_from(Path::new("foo/bar.png")).is_err());
assert!(Md5Hash::try_from(Path::new("bar.png")).is_err());
assert!(Md5Hash::try_from(Path::new("")).is_err());
}
}

886
src/cache/fs.rs vendored

File diff suppressed because it is too large Load Diff

692
src/cache/mem.rs vendored
View File

@ -1,19 +1,74 @@
use std::borrow::Cow;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use super::{
BoxedImageStream, Cache, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
InnerStream, MemStream,
};
use super::{Cache, CacheEntry, CacheKey, CacheStream, CallbackCache, ImageMetadata, MemStream};
use async_trait::async_trait;
use bytes::Bytes;
use futures::FutureExt;
use lfu_cache::LfuCache;
use lru::LruCache;
use tokio::sync::mpsc::{channel, Sender};
use redis::{
Client as RedisClient, Commands, FromRedisValue, RedisError, RedisResult, ToRedisArgs,
};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
use tracing::warn;
type CacheValue = (Bytes, ImageMetadata, u64);
#[derive(Clone, Serialize, Deserialize)]
pub struct CacheValue {
data: Bytes,
metadata: ImageMetadata,
on_disk_size: u64,
}
impl CacheValue {
#[inline]
fn new(data: Bytes, metadata: ImageMetadata, on_disk_size: u64) -> Self {
Self {
data,
metadata,
on_disk_size,
}
}
}
impl FromRedisValue for CacheValue {
fn from_redis_value(v: &redis::Value) -> RedisResult<Self> {
use bincode::ErrorKind;
if let redis::Value::Data(data) = v {
bincode::deserialize(data).map_err(|err| match *err {
ErrorKind::Io(e) => RedisError::from(e),
ErrorKind::Custom(e) => RedisError::from((
redis::ErrorKind::ResponseError,
"bincode deserialize failed",
e,
)),
e => RedisError::from((
redis::ErrorKind::ResponseError,
"bincode deserialized failed",
e.to_string(),
)),
})
} else {
Err(RedisError::from((
redis::ErrorKind::TypeError,
"Got non data type from redis db",
)))
}
}
}
impl ToRedisArgs for CacheValue {
fn write_redis_args<W>(&self, out: &mut W)
where
W: ?Sized + redis::RedisWrite,
{
out.write_arg(&bincode::serialize(self).expect("serialization to work"));
}
}
/// Use LRU as the eviction strategy
pub type Lru = LruCache<CacheKey, CacheValue>;
@ -21,22 +76,29 @@ pub type Lru = LruCache<CacheKey, CacheValue>;
pub type Lfu = LfuCache<CacheKey, CacheValue>;
/// Adapter trait for memory cache backends
pub trait InternalMemoryCacheInitializer: InternalMemoryCache {
fn new() -> Self;
}
pub trait InternalMemoryCache: Sync + Send {
fn unbounded() -> Self;
fn get(&mut self, key: &CacheKey) -> Option<&CacheValue>;
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>>;
fn push(&mut self, key: CacheKey, data: CacheValue);
fn pop(&mut self) -> Option<(CacheKey, CacheValue)>;
}
impl InternalMemoryCache for Lfu {
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCacheInitializer for Lfu {
#[inline]
fn unbounded() -> Self {
fn new() -> Self {
Self::unbounded()
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for Lfu {
#[inline]
fn get(&mut self, key: &CacheKey) -> Option<&CacheValue> {
self.get(key)
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
self.get(key).map(Cow::Borrowed)
}
#[inline]
@ -50,15 +112,19 @@ impl InternalMemoryCache for Lfu {
}
}
impl InternalMemoryCache for Lru {
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCacheInitializer for Lru {
#[inline]
fn unbounded() -> Self {
fn new() -> Self {
Self::unbounded()
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for Lru {
#[inline]
fn get(&mut self, key: &CacheKey) -> Option<&CacheValue> {
self.get(key)
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
self.get(key).map(Cow::Borrowed)
}
#[inline]
@ -72,13 +138,73 @@ impl InternalMemoryCache for Lru {
}
}
#[cfg(not(tarpaulin_include))]
impl InternalMemoryCache for RedisClient {
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
Commands::get(self, key).ok().map(Cow::Owned)
}
fn push(&mut self, key: CacheKey, data: CacheValue) {
if let Err(e) = Commands::set::<_, _, ()>(self, key, data) {
warn!("Failed to push to redis: {}", e);
}
}
fn pop(&mut self) -> Option<(CacheKey, CacheValue)> {
unimplemented!("redis should handle its own memory")
}
}
/// Memory accelerated disk cache. Uses the internal cache implementation in
/// memory to speed up reads.
pub struct MemoryCache<MemoryCacheImpl, ColdCache> {
inner: ColdCache,
cur_mem_size: AtomicU64,
mem_cache: Mutex<MemoryCacheImpl>,
master_sender: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
master_sender: Sender<CacheEntry>,
}
impl<MemoryCacheImpl, ColdCache> MemoryCache<MemoryCacheImpl, ColdCache>
where
MemoryCacheImpl: 'static + InternalMemoryCacheInitializer,
ColdCache: 'static + Cache,
{
pub fn new(inner: ColdCache, max_mem_size: crate::units::Bytes) -> Arc<Self> {
let (tx, rx) = channel(100);
let new_self = Arc::new(Self {
inner,
cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(MemoryCacheImpl::new()),
master_sender: tx,
});
tokio::spawn(internal_cache_listener(
Arc::clone(&new_self),
max_mem_size,
rx,
));
new_self
}
/// Returns an instance of the cache with the receiver for callback events
/// Really only useful for inspecting the receiver, e.g. for testing
#[cfg(test)]
pub fn new_with_receiver(
inner: ColdCache,
_: crate::units::Bytes,
) -> (Self, Receiver<CacheEntry>) {
let (tx, rx) = channel(100);
(
Self {
inner,
cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(MemoryCacheImpl::new()),
master_sender: tx,
},
rx,
)
}
}
impl<MemoryCacheImpl, ColdCache> MemoryCache<MemoryCacheImpl, ColdCache>
@ -86,54 +212,68 @@ where
MemoryCacheImpl: 'static + InternalMemoryCache,
ColdCache: 'static + Cache,
{
pub async fn new(inner: ColdCache, max_mem_size: u64) -> Arc<Self> {
let (tx, mut rx) = channel(100);
let new_self = Arc::new(Self {
pub fn new_with_cache(inner: ColdCache, init_mem_cache: MemoryCacheImpl) -> Self {
Self {
inner,
cur_mem_size: AtomicU64::new(0),
mem_cache: Mutex::new(MemoryCacheImpl::unbounded()),
master_sender: tx,
});
let new_self_0 = Arc::clone(&new_self);
tokio::spawn(async move {
let new_self = new_self_0;
let max_mem_size = max_mem_size / 20 * 19;
while let Some((key, bytes, metadata, size)) = rx.recv().await {
// Add to memory cache
// We can add first because we constrain our memory usage to 95%
new_self
.cur_mem_size
.fetch_add(size as u64, Ordering::Release);
new_self
.mem_cache
.lock()
.await
.push(key, (bytes, metadata, size));
// Pop if too large
while new_self.cur_mem_size.load(Ordering::Acquire) >= max_mem_size {
let popped = new_self
.mem_cache
.lock()
.await
.pop()
.map(|(key, (bytes, metadata, size))| (key, bytes, metadata, size));
if let Some((_, _, _, size)) = popped {
new_self
.cur_mem_size
.fetch_sub(size as u64, Ordering::Release);
} else {
break;
}
}
}
});
new_self
mem_cache: Mutex::new(init_mem_cache),
master_sender: channel(1).0,
}
}
}
async fn internal_cache_listener<MemoryCacheImpl, ColdCache>(
cache: Arc<MemoryCache<MemoryCacheImpl, ColdCache>>,
max_mem_size: crate::units::Bytes,
mut rx: Receiver<CacheEntry>,
) where
MemoryCacheImpl: InternalMemoryCache,
ColdCache: Cache,
{
let max_mem_size = mem_threshold(&max_mem_size);
while let Some(CacheEntry {
key,
data,
metadata,
on_disk_size,
}) = rx.recv().await
{
// Add to memory cache
// We can add first because we constrain our memory usage to 95%
cache
.cur_mem_size
.fetch_add(on_disk_size as u64, Ordering::Release);
cache
.mem_cache
.lock()
.await
.push(key, CacheValue::new(data, metadata, on_disk_size));
// Pop if too large
while cache.cur_mem_size.load(Ordering::Acquire) >= max_mem_size as u64 {
let popped = cache.mem_cache.lock().await.pop().map(
|(
key,
CacheValue {
data,
metadata,
on_disk_size,
},
)| (key, data, metadata, on_disk_size),
);
if let Some((_, _, _, size)) = popped {
cache.cur_mem_size.fetch_sub(size as u64, Ordering::Release);
} else {
break;
}
}
}
}
const fn mem_threshold(bytes: &crate::units::Bytes) -> usize {
bytes.get() / 20 * 19
}
#[async_trait]
impl<MemoryCacheImpl, ColdCache> Cache for MemoryCache<MemoryCacheImpl, ColdCache>
where
@ -146,16 +286,16 @@ where
key: &CacheKey,
) -> Option<Result<(CacheStream, ImageMetadata), super::CacheError>> {
match self.mem_cache.lock().now_or_never() {
Some(mut mem_cache) => match mem_cache.get(key).map(|(bytes, metadata, _)| {
Ok((InnerStream::Memory(MemStream(bytes.clone())), *metadata))
}) {
Some(v) => Some(v.and_then(|(inner, metadata)| {
CacheStream::new(inner, None)
.map(|v| (v, metadata))
.map_err(|_| CacheError::DecryptionFailure)
})),
None => self.inner.get(key).await,
},
Some(mut mem_cache) => {
match mem_cache.get(key).map(Cow::into_owned).map(
|CacheValue { data, metadata, .. }| {
Ok((CacheStream::Memory(MemStream(data)), metadata))
},
) {
Some(v) => Some(v),
None => self.inner.get(key).await,
}
}
None => self.inner.get(key).await,
}
}
@ -164,11 +304,419 @@ where
async fn put(
&self,
key: CacheKey,
image: BoxedImageStream,
image: Bytes,
metadata: ImageMetadata,
) -> Result<CacheStream, super::CacheError> {
) -> Result<(), super::CacheError> {
self.inner
.put_with_on_completed_callback(key, image, metadata, self.master_sender.clone())
.await
}
}
#[cfg(test)]
mod test_util {
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use super::{CacheValue, InternalMemoryCache, InternalMemoryCacheInitializer};
use crate::cache::{
Cache, CacheEntry, CacheError, CacheKey, CacheStream, CallbackCache, ImageMetadata,
};
use async_trait::async_trait;
use parking_lot::Mutex;
use tokio::io::BufReader;
use tokio::sync::mpsc::Sender;
use tokio_util::io::ReaderStream;
#[derive(Default)]
pub struct TestDiskCache(
pub Mutex<RefCell<HashMap<CacheKey, Result<(CacheStream, ImageMetadata), CacheError>>>>,
);
#[async_trait]
impl Cache for TestDiskCache {
async fn get(
&self,
key: &CacheKey,
) -> Option<Result<(CacheStream, ImageMetadata), CacheError>> {
self.0.lock().get_mut().remove(key)
}
async fn put(
&self,
key: CacheKey,
image: bytes::Bytes,
metadata: ImageMetadata,
) -> Result<(), CacheError> {
let reader = Box::pin(BufReader::new(tokio_util::io::StreamReader::new(
tokio_stream::once(Ok::<_, std::io::Error>(image)),
)));
let stream = CacheStream::Completed(ReaderStream::new(reader));
self.0.lock().get_mut().insert(key, Ok((stream, metadata)));
Ok(())
}
}
#[async_trait]
impl CallbackCache for TestDiskCache {
async fn put_with_on_completed_callback(
&self,
key: CacheKey,
data: bytes::Bytes,
metadata: ImageMetadata,
on_complete: Sender<CacheEntry>,
) -> Result<(), CacheError> {
self.put(key.clone(), data.clone(), metadata)
.await?;
let on_disk_size = data.len() as u64;
let _ = on_complete
.send(CacheEntry {
key,
data,
metadata,
on_disk_size,
})
.await;
Ok(())
}
}
#[derive(Default)]
pub struct TestMemoryCache(pub BTreeMap<CacheKey, CacheValue>);
impl InternalMemoryCacheInitializer for TestMemoryCache {
fn new() -> Self {
Self::default()
}
}
impl InternalMemoryCache for TestMemoryCache {
fn get(&mut self, key: &CacheKey) -> Option<Cow<CacheValue>> {
self.0.get(key).map(Cow::Borrowed)
}
fn push(&mut self, key: CacheKey, data: CacheValue) {
self.0.insert(key, data);
}
fn pop(&mut self) -> Option<(CacheKey, CacheValue)> {
let mut cache = BTreeMap::new();
std::mem::swap(&mut cache, &mut self.0);
let mut iter = cache.into_iter();
let ret = iter.next();
self.0 = iter.collect();
ret
}
}
}
#[cfg(test)]
mod cache_ops {
use std::error::Error;
use bytes::Bytes;
use futures::{FutureExt, StreamExt};
use crate::cache::mem::{CacheValue, InternalMemoryCache};
use crate::cache::{Cache, CacheEntry, CacheKey, CacheStream, ImageMetadata, MemStream};
use super::test_util::{TestDiskCache, TestMemoryCache};
use super::MemoryCache;
#[tokio::test]
async fn get_mem_cached() -> Result<(), Box<dyn Error>> {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
let value = CacheValue::new(bytes.clone(), metadata, 34);
// Populate the cache, need to drop the lock else it's considered locked
// when we actually call the cache
{
let mem_cache = &mut cache.mem_cache.lock().await;
mem_cache.push(key.clone(), value.clone());
}
let (stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
if let CacheStream::Memory(MemStream(ret_stream)) = stream {
assert_eq!(bytes, ret_stream);
} else {
panic!("wrong stream type");
}
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
#[tokio::test]
async fn get_disk_cached() -> Result<(), Box<dyn Error>> {
let (mut cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
{
let cache = &mut cache.inner;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
}
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
// Identical to the get_disk_cached test but we hold a lock on the mem_cache
#[tokio::test]
async fn get_mem_locked() -> Result<(), Box<dyn Error>> {
let (mut cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
{
let cache = &mut cache.inner;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
}
// intentionally not dropped
let _mem_cache = &mut cache.mem_cache.lock().await;
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
assert!(rx.recv().now_or_never().is_none());
Ok(())
}
#[tokio::test]
async fn get_miss() {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
assert!(cache.get(&key).await.is_none());
assert!(rx.recv().now_or_never().is_none());
}
#[tokio::test]
async fn put_puts_into_disk_and_hears_from_rx() -> Result<(), Box<dyn Error>> {
let (cache, mut rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(10),
);
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
let bytes_len = bytes.len() as u64;
cache
.put(key.clone(), bytes.clone(), metadata)
.await?;
// Because the callback is supposed to let the memory cache insert the
// entry into its cache, we can check that it properly stored it on the
// disk layer by checking if we can successfully fetch it.
let (mut stream, ret_metadata) = cache.get(&key).await.unwrap()?;
assert_eq!(metadata, ret_metadata);
assert!(matches!(stream, CacheStream::Completed(_)));
assert_eq!(stream.next().await, Some(Ok(bytes.clone())));
// Check that we heard back
let cache_entry = rx
.recv()
.now_or_never()
.flatten()
.ok_or("failed to hear back from cache")?;
assert_eq!(
cache_entry,
CacheEntry {
key,
data: bytes,
metadata,
on_disk_size: bytes_len,
}
);
Ok(())
}
}
#[cfg(test)]
mod db_listener {
use std::error::Error;
use std::iter::FromIterator;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use bytes::Bytes;
use tokio::task;
use crate::cache::{Cache, CacheKey, ImageMetadata};
use super::test_util::{TestDiskCache, TestMemoryCache};
use super::{internal_cache_listener, MemoryCache};
#[tokio::test]
async fn put_into_memory() -> Result<(), Box<dyn Error>> {
let (cache, rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(0),
);
let cache = Arc::new(cache);
tokio::spawn(internal_cache_listener(
Arc::clone(&cache),
crate::units::Bytes(20),
rx,
));
// put small image into memory
let key = CacheKey("a".to_string(), "b".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcd");
cache.put(key.clone(), bytes.clone(), metadata).await?;
// let the listener run first
for _ in 0..10 {
task::yield_now().await;
}
assert_eq!(
cache.cur_mem_size.load(Ordering::SeqCst),
bytes.len() as u64
);
// Since we didn't populate the cache, fetching must be from memory, so
// this should succeed since the cache listener should push the item
// into cache
assert!(cache.get(&key).await.is_some());
Ok(())
}
#[tokio::test]
async fn pops_items() -> Result<(), Box<dyn Error>> {
let (cache, rx) = MemoryCache::<TestMemoryCache, _>::new_with_receiver(
TestDiskCache::default(),
crate::units::Bytes(0),
);
let cache = Arc::new(cache);
tokio::spawn(internal_cache_listener(
Arc::clone(&cache),
crate::units::Bytes(20),
rx,
));
// put small image into memory
let key_0 = CacheKey("a".to_string(), "b".to_string(), false);
let key_1 = CacheKey("c".to_string(), "d".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_static(b"abcde");
cache.put(key_0, bytes.clone(), metadata).await?;
cache.put(key_1, bytes.clone(), metadata).await?;
// let the listener run first
task::yield_now().await;
for _ in 0..10 {
task::yield_now().await;
}
// Items should be in cache now
assert_eq!(
cache.cur_mem_size.load(Ordering::SeqCst),
(bytes.len() * 2) as u64
);
let key_3 = CacheKey("e".to_string(), "f".to_string(), false);
let metadata = ImageMetadata {
content_type: None,
content_length: Some(1),
last_modified: None,
};
let bytes = Bytes::from_iter(b"0".repeat(16).into_iter());
let bytes_len = bytes.len();
cache.put(key_3, bytes, metadata).await?;
// let the listener run first
task::yield_now().await;
for _ in 0..10 {
task::yield_now().await;
}
// Items should have been evicted, only 16 bytes should be there now
assert_eq!(cache.cur_mem_size.load(Ordering::SeqCst), bytes_len as u64);
Ok(())
}
}
#[cfg(test)]
mod mem_threshold {
use crate::units::Bytes;
use super::mem_threshold;
#[test]
fn small_amount_works() {
assert_eq!(mem_threshold(&Bytes(100)), 95);
}
#[test]
fn large_amount_cannot_overflow() {
assert_eq!(mem_threshold(&Bytes(usize::MAX)), 17_524_406_870_024_074_020);
}
}

138
src/cache/mod.rs vendored
View File

@ -5,34 +5,46 @@ use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use actix_web::http::HeaderValue;
use actix_web::http::header::HeaderValue;
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use chacha20::Key;
use chrono::{DateTime, FixedOffset};
use fs::ConcurrentFsStream;
use futures::{Stream, StreamExt};
use once_cell::sync::OnceCell;
use redis::ToRedisArgs;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use sodiumoxide::crypto::secretstream::{Header, Key, Pull, Stream as SecretStream};
use thiserror::Error;
use tokio::io::AsyncRead;
use tokio::sync::mpsc::Sender;
use tokio_util::codec::{BytesCodec, FramedRead};
use tokio_util::io::ReaderStream;
pub use disk::DiskCache;
pub use fs::UpstreamError;
pub use mem::MemoryCache;
use self::compat::LegacyImageMetadata;
use self::fs::MetadataFetch;
pub static ENCRYPTION_KEY: OnceCell<Key> = OnceCell::new();
mod compat;
mod disk;
mod fs;
pub mod mem;
#[derive(PartialEq, Eq, Hash, Clone)]
#[derive(PartialEq, Eq, Hash, Clone, Debug, PartialOrd, Ord)]
pub struct CacheKey(pub String, pub String, pub bool);
impl ToRedisArgs for CacheKey {
fn write_redis_args<W>(&self, out: &mut W)
where
W: ?Sized + redis::RedisWrite,
{
out.write_arg_fmt(self);
}
}
impl Display for CacheKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.2 {
@ -60,7 +72,7 @@ impl From<&CacheKey> for PathBuf {
#[derive(Clone)]
pub struct CachedImage(pub Bytes);
#[derive(Copy, Clone, Serialize, Deserialize)]
#[derive(Copy, Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct ImageMetadata {
pub content_type: Option<ImageContentType>,
pub content_length: Option<u32>,
@ -68,7 +80,7 @@ pub struct ImageMetadata {
}
// Confirmed by Ply to be these types: https://link.eddie.sh/ZXfk0
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr)]
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum ImageContentType {
Png = 0,
@ -103,12 +115,21 @@ impl AsRef<str> for ImageContentType {
}
}
#[allow(clippy::pub_enum_variant_names)]
impl From<LegacyImageMetadata> for ImageMetadata {
fn from(legacy: LegacyImageMetadata) -> Self {
Self {
content_type: legacy.content_type.map(|v| v.0),
content_length: legacy.size,
last_modified: legacy.last_modified.map(|v| v.0),
}
}
}
#[derive(Debug)]
pub enum ImageRequestError {
InvalidContentType,
InvalidContentLength,
InvalidLastModified,
ContentType,
ContentLength,
LastModified,
}
impl ImageMetadata {
@ -124,14 +145,14 @@ impl ImageMetadata {
Err(_) => Err(InvalidContentType),
})
.transpose()
.map_err(|_| ImageRequestError::InvalidContentType)?,
.map_err(|_| ImageRequestError::ContentType)?,
content_length: content_length
.map(|header_val| {
header_val
.to_str()
.map_err(|_| ImageRequestError::InvalidContentLength)?
.map_err(|_| ImageRequestError::ContentLength)?
.parse()
.map_err(|_| ImageRequestError::InvalidContentLength)
.map_err(|_| ImageRequestError::ContentLength)
})
.transpose()?,
last_modified: last_modified
@ -139,17 +160,15 @@ impl ImageMetadata {
DateTime::parse_from_rfc2822(
header_val
.to_str()
.map_err(|_| ImageRequestError::InvalidLastModified)?,
.map_err(|_| ImageRequestError::LastModified)?,
)
.map_err(|_| ImageRequestError::InvalidLastModified)
.map_err(|_| ImageRequestError::LastModified)
})
.transpose()?,
})
}
}
type BoxedImageStream = Box<dyn Stream<Item = Result<Bytes, CacheError>> + Unpin + Send>;
#[derive(Error, Debug)]
pub enum CacheError {
#[error(transparent)]
@ -170,9 +189,9 @@ pub trait Cache: Send + Sync {
async fn put(
&self,
key: CacheKey,
image: BoxedImageStream,
image: Bytes,
metadata: ImageMetadata,
) -> Result<CacheStream, CacheError>;
) -> Result<(), CacheError>;
}
#[async_trait]
@ -189,9 +208,9 @@ impl<T: Cache> Cache for Arc<T> {
async fn put(
&self,
key: CacheKey,
image: BoxedImageStream,
image: Bytes,
metadata: ImageMetadata,
) -> Result<CacheStream, CacheError> {
) -> Result<(), CacheError> {
self.as_ref().put(key, image, metadata).await
}
}
@ -201,74 +220,43 @@ pub trait CallbackCache: Cache {
async fn put_with_on_completed_callback(
&self,
key: CacheKey,
image: BoxedImageStream,
image: Bytes,
metadata: ImageMetadata,
on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
) -> Result<CacheStream, CacheError>;
on_complete: Sender<CacheEntry>,
) -> Result<(), CacheError>;
}
#[async_trait]
impl<T: CallbackCache> CallbackCache for Arc<T> {
#[inline]
#[cfg(not(tarpaulin_include))]
async fn put_with_on_completed_callback(
&self,
key: CacheKey,
image: BoxedImageStream,
image: Bytes,
metadata: ImageMetadata,
on_complete: Sender<(CacheKey, Bytes, ImageMetadata, u64)>,
) -> Result<CacheStream, CacheError> {
on_complete: Sender<CacheEntry>,
) -> Result<(), CacheError> {
self.as_ref()
.put_with_on_completed_callback(key, image, metadata, on_complete)
.await
}
}
pub struct CacheStream {
inner: InnerStream,
decrypt: Option<SecretStream<Pull>>,
#[derive(PartialEq, Eq, Debug)]
pub struct CacheEntry {
key: CacheKey,
data: Bytes,
metadata: ImageMetadata,
on_disk_size: u64,
}
impl CacheStream {
pub(self) fn new(inner: InnerStream, header: Option<Header>) -> Result<Self, ()> {
Ok(Self {
inner,
decrypt: header
.and_then(|header| ENCRYPTION_KEY.get().map(|key| SecretStream::init_pull(&header, key)))
.transpose()?,
})
}
}
impl Stream for CacheStream {
type Item = CacheStreamItem;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx).map(|data| {
// False positive (`data`): https://link.eddie.sh/r1fXX
#[allow(clippy::option_if_let_else)]
if let Some(keystream) = self.decrypt.as_mut() {
data.map(|bytes_res| {
bytes_res.and_then(|bytes| {
keystream
.pull(&bytes, None)
.map(|(data, _tag)| Bytes::from(data))
.map_err(|_| UpstreamError)
})
})
} else {
data
}
})
}
}
pub(self) enum InnerStream {
Concurrent(ConcurrentFsStream),
pub enum CacheStream {
Memory(MemStream),
Completed(FramedRead<Pin<Box<dyn AsyncRead + Send>>, BytesCodec>),
Completed(ReaderStream<Pin<Box<dyn MetadataFetch + Send + Sync>>>),
}
impl From<CachedImage> for InnerStream {
impl From<CachedImage> for CacheStream {
fn from(image: CachedImage) -> Self {
Self::Memory(MemStream(image.0))
}
@ -276,17 +264,13 @@ impl From<CachedImage> for InnerStream {
type CacheStreamItem = Result<Bytes, UpstreamError>;
impl Stream for InnerStream {
impl Stream for CacheStream {
type Item = CacheStreamItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.get_mut() {
Self::Concurrent(stream) => stream.poll_next_unpin(cx),
Self::Memory(stream) => stream.poll_next_unpin(cx),
Self::Completed(stream) => stream
.poll_next_unpin(cx)
.map_ok(BytesMut::freeze)
.map_err(|_| UpstreamError),
Self::Completed(stream) => stream.poll_next_unpin(cx).map_err(|_| UpstreamError),
}
}
}

220
src/client.rs Normal file
View File

@ -0,0 +1,220 @@
use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use actix_web::http::header::{HeaderMap, HeaderName, HeaderValue};
use actix_web::web::Data;
use bytes::Bytes;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use reqwest::header::{
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH,
CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS,
};
use reqwest::{Client, Proxy, StatusCode};
use tokio::sync::watch::{channel, Receiver};
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
use crate::cache::{Cache, CacheKey, ImageMetadata};
use crate::config::{DISABLE_CERT_VALIDATION, USE_PROXY};
pub static HTTP_CLIENT: Lazy<CachingClient> = Lazy::new(|| {
let mut inner = Client::builder()
.pool_idle_timeout(Duration::from_secs(180))
.https_only(true)
.http2_prior_knowledge();
if let Some(socket_addr) = USE_PROXY.get() {
info!(
"Using {} as a proxy for upstream requests.",
socket_addr.as_str()
);
inner = inner.proxy(Proxy::all(socket_addr.as_str()).unwrap());
}
if DISABLE_CERT_VALIDATION.load(Ordering::Acquire) {
inner = inner.danger_accept_invalid_certs(true);
}
let inner = inner.build().expect("Client initialization to work");
CachingClient {
inner,
locks: RwLock::new(HashMap::new()),
}
});
#[cfg(not(tarpaulin_include))]
pub static DEFAULT_HEADERS: Lazy<HeaderMap> = Lazy::new(|| {
let mut headers = HeaderMap::with_capacity(8);
headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
headers.insert(
ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("https://mangadex.org"),
);
headers.insert(ACCESS_CONTROL_EXPOSE_HEADERS, HeaderValue::from_static("*"));
headers.insert(
CACHE_CONTROL,
HeaderValue::from_static("public, max-age=1209600"),
);
headers.insert(
HeaderName::from_static("timing-allow-origin"),
HeaderValue::from_static("https://mangadex.org"),
);
headers
});
pub struct CachingClient {
inner: Client,
locks: RwLock<HashMap<String, Receiver<FetchResult>>>,
}
#[derive(Clone, Debug)]
pub enum FetchResult {
ServiceUnavailable,
InternalServerError,
Data(StatusCode, HeaderMap, Bytes),
Processing,
}
impl CachingClient {
pub async fn fetch_and_cache(
&'static self,
url: String,
key: CacheKey,
cache: Data<dyn Cache>,
) -> FetchResult {
let maybe_receiver = {
let lock = self.locks.read();
lock.get(&url).map(Clone::clone)
};
if let Some(mut recv) = maybe_receiver {
loop {
if !matches!(*recv.borrow(), FetchResult::Processing) {
break;
}
if recv.changed().await.is_err() {
break;
}
}
return recv.borrow().clone();
}
let notify = Arc::new(Notify::new());
tokio::spawn(self.fetch_and_cache_impl(cache, url.clone(), key, Arc::clone(&notify)));
notify.notified().await;
let mut recv = self
.locks
.read()
.get(&url)
.expect("receiver to exist since we just made one")
.clone();
loop {
if !matches!(*recv.borrow(), FetchResult::Processing) {
break;
}
if recv.changed().await.is_err() {
break;
}
}
let resp = recv.borrow().clone();
resp
}
async fn fetch_and_cache_impl(
&self,
cache: Data<dyn Cache>,
url: String,
key: CacheKey,
notify: Arc<Notify>,
) {
let (tx, rx) = channel(FetchResult::Processing);
self.locks.write().insert(url.clone(), rx);
notify.notify_one();
let resp = self.inner.get(&url).send().await;
let resp = match resp {
Ok(mut resp) => {
let content_type = resp.headers().get(CONTENT_TYPE);
let is_image = content_type
.map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
.unwrap_or_default();
if resp.status() != StatusCode::OK || !is_image {
warn!("Got non-OK or non-image response code from upstream, proxying and not caching result.");
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type.clone());
}
FetchResult::Data(
resp.status(),
headers,
resp.bytes().await.unwrap_or_default(),
)
} else {
let (content_type, length, last_mod) = {
let headers = resp.headers_mut();
(
headers.remove(CONTENT_TYPE),
headers.remove(CONTENT_LENGTH),
headers.remove(LAST_MODIFIED),
)
};
let body = resp.bytes().await.unwrap();
debug!("Inserting into cache");
let metadata =
ImageMetadata::new(content_type.clone(), length.clone(), last_mod.clone())
.unwrap();
match cache.put(key, body.clone(), metadata).await {
Ok(()) => {
debug!("Done putting into cache");
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type {
headers.insert(CONTENT_TYPE, content_type);
}
if let Some(content_length) = length {
headers.insert(CONTENT_LENGTH, content_length);
}
if let Some(last_modified) = last_mod {
headers.insert(LAST_MODIFIED, last_modified);
}
FetchResult::Data(StatusCode::OK, headers, body)
}
Err(e) => {
warn!("Failed to insert into cache: {}", e);
FetchResult::InternalServerError
}
}
}
}
Err(e) => {
error!("Failed to fetch image from server: {}", e);
FetchResult::ServiceUnavailable
}
};
// This shouldn't happen
tx.send(resp).unwrap();
self.locks.write().remove(&url);
}
#[inline]
pub const fn inner(&self) -> &Client {
&self.inner
}
}

View File

@ -1,67 +1,340 @@
use std::fmt::{Display, Formatter};
use std::num::{NonZeroU16, NonZeroU64};
use std::path::PathBuf;
use std::fs::{File, OpenOptions};
use std::hint::unreachable_unchecked;
use std::io::{ErrorKind, Write};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::num::NonZeroU16;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::{AtomicBool, Ordering};
use clap::{crate_authors, crate_description, crate_version, Clap};
use clap::{crate_authors, crate_description, crate_version, Parser};
use log::LevelFilter;
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::level_filters::LevelFilter as TracingLevelFilter;
use url::Url;
use crate::units::{KilobitsPerSecond, Mebibytes, Port};
// Validate tokens is an atomic because it's faster than locking on rwlock.
pub static VALIDATE_TOKENS: AtomicBool = AtomicBool::new(false);
// We use an atomic here because it's better for us to not pass the config
// everywhere.
pub static SEND_SERVER_VERSION: AtomicBool = AtomicBool::new(false);
pub static OFFLINE_MODE: AtomicBool = AtomicBool::new(false);
pub static USE_PROXY: OnceCell<Url> = OnceCell::new();
pub static DISABLE_CERT_VALIDATION: AtomicBool = AtomicBool::new(false);
#[derive(Clap, Clone)]
#[clap(version = crate_version!(), author = crate_authors!(), about = crate_description!())]
pub struct CliArgs {
/// The port to listen on.
#[clap(short, long, default_value = "42069", env = "PORT")]
pub port: NonZeroU16,
/// How large, in bytes, the in-memory cache should be. Note that this does
/// not include runtime memory usage.
#[clap(long, env = "MEM_CACHE_QUOTA_BYTES", conflicts_with = "low-memory")]
pub memory_quota: Option<NonZeroU64>,
/// How large, in bytes, the on-disk cache should be. Note that actual
/// values may be larger for metadata information.
#[clap(long, env = "DISK_CACHE_QUOTA_BYTES")]
pub disk_quota: u64,
/// Sets the location of the disk cache.
#[clap(long, default_value = "./cache", env = "DISK_CACHE_PATH")]
#[derive(Error, Debug)]
pub enum ConfigError {
#[error("No config found. One has been created for you to modify.")]
NotInitialized,
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_yaml::Error),
}
pub fn load_config() -> Result<Config, ConfigError> {
// Load cli args first
let cli_args: CliArgs = CliArgs::parse();
// Load yaml file next
let config_file: Result<YamlArgs, _> = {
let config_path = cli_args
.config_path
.as_deref()
.unwrap_or_else(|| Path::new("./settings.yaml"));
match File::open(config_path) {
Ok(file) => serde_yaml::from_reader(file),
Err(e) if e.kind() == ErrorKind::NotFound => {
let mut file = OpenOptions::new()
.write(true)
.create_new(true)
.open(config_path)
.unwrap();
let default_config = include_str!("../settings.sample.yaml");
file.write_all(default_config.as_bytes()).unwrap();
return Err(ConfigError::NotInitialized);
}
Err(e) => return Err(e.into()),
}
};
// generate config
let config = Config::from_cli_and_file(cli_args, config_file?);
// initialize globals
OFFLINE_MODE.store(
config
.unstable_options
.contains(&UnstableOptions::OfflineMode),
Ordering::Release,
);
if let Some(socket) = config.proxy.clone() {
USE_PROXY
.set(socket)
.expect("USE_PROXY to be set only by this function");
}
DISABLE_CERT_VALIDATION.store(
config
.unstable_options
.contains(&UnstableOptions::DisableCertValidation),
Ordering::Release,
);
Ok(config)
}
#[derive(Debug)]
/// Represents a fully parsed config, from a variety of sources.
pub struct Config {
pub cache_type: CacheType,
pub cache_path: PathBuf,
pub shutdown_timeout: NonZeroU16,
pub log_level: TracingLevelFilter,
pub client_secret: ClientSecret,
pub port: Port,
pub bind_address: SocketAddr,
pub external_address: Option<SocketAddr>,
pub ephemeral_disk_encryption: bool,
pub network_speed: KilobitsPerSecond,
pub disk_quota: Mebibytes,
pub memory_quota: Mebibytes,
pub unstable_options: Vec<UnstableOptions>,
pub override_upstream: Option<Url>,
pub enable_metrics: bool,
pub geoip_license_key: Option<ClientSecret>,
pub proxy: Option<Url>,
pub redis_url: Option<Url>,
}
impl Config {
fn from_cli_and_file(cli_args: CliArgs, file_args: YamlArgs) -> Self {
let file_extended_options = file_args.extended_options.unwrap_or_default();
let log_level = match (cli_args.quiet, cli_args.verbose) {
(n, _) if n > 2 => TracingLevelFilter::OFF,
(2, _) => TracingLevelFilter::ERROR,
(1, _) => TracingLevelFilter::WARN,
// Use log level from file if no flags were provided to CLI
(0, 0) => {
file_extended_options
.logging_level
.map_or(TracingLevelFilter::INFO, |filter| match filter {
LevelFilter::Off => TracingLevelFilter::OFF,
LevelFilter::Error => TracingLevelFilter::ERROR,
LevelFilter::Warn => TracingLevelFilter::WARN,
LevelFilter::Info => TracingLevelFilter::INFO,
LevelFilter::Debug => TracingLevelFilter::DEBUG,
LevelFilter::Trace => TracingLevelFilter::TRACE,
})
}
(_, 1) => TracingLevelFilter::DEBUG,
(_, n) if n > 1 => TracingLevelFilter::TRACE,
// compiler can't figure it out
_ => unsafe { unreachable_unchecked() },
};
let bind_port = cli_args
.port
.unwrap_or(file_args.server_settings.port)
.get();
// This needs to be outside because rust isn't smart enough yet to
// realize a disjointed borrow of a moved value is ok. This will be
// fixed in Rust 2021.
let external_port = file_args
.server_settings
.external_port
.map_or(bind_port, Port::get);
Self {
cache_type: cli_args
.cache_type
.or(file_extended_options.cache_type)
.unwrap_or_default(),
cache_path: cli_args
.cache_path
.or(file_extended_options.cache_path)
.unwrap_or_else(|| PathBuf::from_str("./cache").unwrap()),
shutdown_timeout: file_args
.server_settings
.graceful_shutdown_wait_seconds
.unwrap_or(unsafe { NonZeroU16::new_unchecked(60) }),
log_level,
// secret should never be in CLI
client_secret: if let Ok(v) = std::env::var("CLIENT_SECRET") {
ClientSecret(v)
} else {
file_args.server_settings.secret
},
port: cli_args.port.unwrap_or(file_args.server_settings.port),
bind_address: SocketAddr::new(
file_args
.server_settings
.hostname
.unwrap_or_else(|| IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))),
bind_port,
),
external_address: file_args
.server_settings
.external_ip
.map(|ip_addr| SocketAddr::new(ip_addr, external_port)),
ephemeral_disk_encryption: cli_args.ephemeral_disk_encryption
|| file_extended_options
.ephemeral_disk_encryption
.unwrap_or_default(),
network_speed: cli_args
.network_speed
.unwrap_or(file_args.server_settings.external_max_kilobits_per_second),
disk_quota: cli_args
.disk_quota
.unwrap_or(file_args.max_cache_size_in_mebibytes),
memory_quota: cli_args
.memory_quota
.or(file_extended_options.memory_quota)
.unwrap_or_default(),
enable_metrics: file_extended_options.enable_metrics.unwrap_or_default(),
// Unstable options (and related) should never be in yaml config
unstable_options: cli_args.unstable_options,
override_upstream: cli_args.override_upstream,
geoip_license_key: file_args.metric_settings.and_then(|args| {
if args.enable_geoip.unwrap_or_default() {
args.geoip_license_key
} else {
None
}
}),
proxy: cli_args.proxy,
redis_url: file_extended_options.redis_url,
}
}
}
// this intentionally does not implement display
#[derive(Deserialize, Serialize, Clone)]
pub struct ClientSecret(String);
impl ClientSecret {
pub fn as_str(&self) -> &str {
self.0.as_ref()
}
}
impl std::fmt::Debug for ClientSecret {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "[client secret]")
}
}
#[derive(Deserialize, Copy, Clone, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum CacheType {
OnDisk,
Lru,
Lfu,
Redis,
}
impl FromStr for CacheType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"on_disk" => Ok(Self::OnDisk),
"lru" => Ok(Self::Lru),
"lfu" => Ok(Self::Lfu),
"redis" => Ok(Self::Redis),
_ => Err(format!("Unknown option: {}", s)),
}
}
}
impl Default for CacheType {
fn default() -> Self {
Self::OnDisk
}
}
#[derive(Deserialize)]
struct YamlArgs {
// Naming is legacy
max_cache_size_in_mebibytes: Mebibytes,
server_settings: YamlServerSettings,
metric_settings: Option<YamlMetricSettings>,
// This implementation's custom options
extended_options: Option<YamlExtendedOptions>,
}
// Naming is legacy
#[derive(Deserialize)]
struct YamlServerSettings {
secret: ClientSecret,
#[serde(default)]
port: Port,
external_max_kilobits_per_second: KilobitsPerSecond,
external_port: Option<Port>,
graceful_shutdown_wait_seconds: Option<NonZeroU16>,
hostname: Option<IpAddr>,
external_ip: Option<IpAddr>,
}
#[derive(Deserialize)]
struct YamlMetricSettings {
enable_geoip: Option<bool>,
geoip_license_key: Option<ClientSecret>,
}
#[derive(Deserialize, Default)]
struct YamlExtendedOptions {
memory_quota: Option<Mebibytes>,
cache_type: Option<CacheType>,
ephemeral_disk_encryption: Option<bool>,
enable_metrics: Option<bool>,
logging_level: Option<LevelFilter>,
cache_path: Option<PathBuf>,
redis_url: Option<Url>,
}
#[derive(Parser, Clone)]
#[clap(version = crate_version!(), author = crate_authors!(), about = crate_description!())]
struct CliArgs {
/// The port to listen on.
#[clap(short, long)]
pub port: Option<Port>,
/// How large, in mebibytes, the in-memory cache should be. Note that this
/// does not include runtime memory usage.
#[clap(long)]
pub memory_quota: Option<Mebibytes>,
/// How large, in mebibytes, the on-disk cache should be. Note that actual
/// values may be larger for metadata information.
#[clap(long)]
pub disk_quota: Option<Mebibytes>,
/// Sets the location of the disk cache.
#[clap(long)]
pub cache_path: Option<PathBuf>,
/// The network speed to advertise to Mangadex@Home control server.
#[clap(long, env = "MAX_NETWORK_SPEED")]
pub network_speed: NonZeroU64,
/// Whether or not to provide the Server HTTP header to clients. This is
/// useful for debugging, but is generally not recommended for security
/// reasons.
#[clap(long, env = "ENABLE_SERVER_STRING", takes_value = false)]
pub enable_server_string: bool,
/// Changes the caching behavior to avoid buffering images in memory, and
/// instead use the filesystem as the buffer backing. This is useful for
/// clients in low (< 1GB) RAM environments.
#[clap(
short,
long,
conflicts_with("memory-quota"),
env = "LOW_MEMORY_MODE",
takes_value = false
)]
pub low_memory: bool,
#[clap(long)]
pub network_speed: Option<KilobitsPerSecond>,
/// Changes verbosity. Default verbosity is INFO, while increasing counts of
/// verbose flags increases the verbosity to DEBUG and TRACE, respectively.
#[clap(short, long, parse(from_occurrences))]
#[clap(short, long, parse(from_occurrences), conflicts_with = "quiet")]
pub verbose: usize,
/// Changes verbosity. Default verbosity is INFO, while increasing counts of
/// quiet flags decreases the verbosity to WARN, ERROR, and no logs,
/// respectively.
#[clap(short, long, parse(from_occurrences), conflicts_with = "verbose")]
pub quiet: usize,
/// Unstable options. Intentionally not documented.
#[clap(short = 'Z', long)]
pub unstable_options: Vec<UnstableOptions>,
/// Override the image server with the one provided. Do not set this unless
/// you know what you're doing.
#[clap(long)]
pub override_upstream: Option<Url>,
/// Enables ephemeral disk encryption. Items written to disk are first
@ -69,6 +342,17 @@ pub struct CliArgs {
/// performance, privacy, and usability with this flag enabled.
#[clap(short, long)]
pub ephemeral_disk_encryption: bool,
/// The path to the config file. Default value is `./settings.yaml`.
#[clap(short, long)]
pub config_path: Option<PathBuf>,
/// Whether to use an in-memory cache in addition to the disk cache. Default
/// value is "on_disk", other options are "lfu", "lru", and "redis".
#[clap(short = 't', long)]
pub cache_type: Option<CacheType>,
/// Whether or not to use a proxy for upstream requests. This affects all
/// requests except for the shutdown request.
#[clap(short = 'P', long)]
pub proxy: Option<Url>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@ -77,10 +361,6 @@ pub enum UnstableOptions {
/// you know what you're dealing with.
OverrideUpstream,
/// Use an LFU implementation for the in-memory cache instead of the default
/// LRU implementation.
UseLfu,
/// Disables token validation. Don't use this unless you know the
/// ramifications of this command.
DisableTokenValidation,
@ -90,6 +370,10 @@ pub enum UnstableOptions {
/// Serves HTTP in plaintext
DisableTls,
/// Disable certificate validation. Only useful for debugging with a MITM
/// proxy
DisableCertValidation,
}
impl FromStr for UnstableOptions {
@ -98,10 +382,10 @@ impl FromStr for UnstableOptions {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"override-upstream" => Ok(Self::OverrideUpstream),
"use-lfu" => Ok(Self::UseLfu),
"disable-token-validation" => Ok(Self::DisableTokenValidation),
"offline-mode" => Ok(Self::OfflineMode),
"disable-tls" => Ok(Self::DisableTls),
"disable-cert-validation" => Ok(Self::DisableCertValidation),
_ => Err(format!("Unknown unstable option '{}'", s)),
}
}
@ -111,10 +395,85 @@ impl Display for UnstableOptions {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::OverrideUpstream => write!(f, "override-upstream"),
Self::UseLfu => write!(f, "use-lfu"),
Self::DisableTokenValidation => write!(f, "disable-token-validation"),
Self::OfflineMode => write!(f, "offline-mode"),
Self::DisableTls => write!(f, "disable-tls"),
Self::DisableCertValidation => write!(f, "disable-cert-validation"),
}
}
}
#[cfg(test)]
mod sample_yaml {
use crate::config::YamlArgs;
#[test]
fn parses() {
assert!(serde_yaml::from_str::<YamlArgs>(include_str!("../settings.sample.yaml")).is_ok());
}
}
#[cfg(test)]
mod config {
use std::path::PathBuf;
use log::LevelFilter;
use tracing::level_filters::LevelFilter as TracingLevelFilter;
use crate::config::{CacheType, ClientSecret, Config, YamlExtendedOptions, YamlServerSettings};
use crate::units::{KilobitsPerSecond, Mebibytes, Port};
use super::{CliArgs, YamlArgs};
#[test]
fn cli_has_priority() {
let cli_config = CliArgs {
port: Port::new(1234),
memory_quota: Some(Mebibytes::new(10)),
disk_quota: Some(Mebibytes::new(10)),
cache_path: Some(PathBuf::from("a")),
network_speed: KilobitsPerSecond::new(10),
verbose: 1,
quiet: 0,
unstable_options: vec![],
override_upstream: None,
ephemeral_disk_encryption: true,
config_path: None,
cache_type: Some(CacheType::Lfu),
proxy: None,
};
let yaml_args = YamlArgs {
max_cache_size_in_mebibytes: Mebibytes::new(50),
server_settings: YamlServerSettings {
secret: ClientSecret(String::new()),
port: Port::new(4321).expect("to work?"),
external_max_kilobits_per_second: KilobitsPerSecond::new(50).expect("to work?"),
external_port: None,
graceful_shutdown_wait_seconds: None,
hostname: None,
external_ip: None,
},
metric_settings: None,
extended_options: Some(YamlExtendedOptions {
memory_quota: Some(Mebibytes::new(50)),
cache_type: Some(CacheType::Lru),
ephemeral_disk_encryption: Some(false),
enable_metrics: None,
logging_level: Some(LevelFilter::Error),
cache_path: Some(PathBuf::from("b")),
redis_url: None,
}),
};
let config = Config::from_cli_and_file(cli_config, yaml_args);
assert_eq!(Some(config.port), Port::new(1234));
assert_eq!(config.memory_quota, Mebibytes::new(10));
assert_eq!(config.disk_quota, Mebibytes::new(10));
assert_eq!(config.cache_path, PathBuf::from("a"));
assert_eq!(Some(config.network_speed), KilobitsPerSecond::new(10));
assert_eq!(config.log_level, TracingLevelFilter::DEBUG);
assert_eq!(config.ephemeral_disk_encryption, true);
assert_eq!(config.cache_type, CacheType::Lfu);
}
}

View File

@ -2,50 +2,52 @@
// We're end users, so these is ok
#![allow(clippy::module_name_repetitions)]
use std::env::{self, VarError};
use std::env::VarError;
use std::error::Error;
use std::fmt::Display;
use std::hint::unreachable_unchecked;
use std::num::{NonZeroU64, ParseIntError};
use std::process;
use std::net::SocketAddr;
use std::num::ParseIntError;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use actix_web::dev::Service;
use actix_web::rt::{spawn, time, System};
use actix_web::web::{self, Data};
use actix_web::{App, HttpResponse, HttpServer};
use cache::{Cache, DiskCache};
use clap::Clap;
use config::CliArgs;
use log::{debug, error, info, warn, LevelFilter};
use chacha20::Key;
use config::Config;
use maxminddb::geoip2;
use parking_lot::RwLock;
use rustls::{NoClientAuth, ServerConfig};
use simple_logger::SimpleLogger;
use sodiumoxide::crypto::secretstream::gen_key;
use redis::Client as RedisClient;
use rustls::server::NoClientAuth;
use rustls::ServerConfig;
use sodiumoxide::crypto::stream::xchacha20::gen_key;
use state::{RwLockServerState, ServerState};
use stop::send_stop;
use thiserror::Error;
use tracing::{debug, error, info, warn};
use crate::cache::mem::{Lfu, Lru};
use crate::cache::{MemoryCache, ENCRYPTION_KEY};
use crate::config::{UnstableOptions, OFFLINE_MODE};
use crate::config::{CacheType, UnstableOptions, OFFLINE_MODE};
use crate::metrics::{record_country_visit, GEOIP_DATABASE};
use crate::state::DynamicServerCert;
mod cache;
mod client;
mod config;
mod metrics;
mod ping;
mod routes;
mod state;
mod stop;
mod units;
#[macro_export]
macro_rules! client_api_version {
() => {
"31"
};
}
const CLIENT_API_VERSION: usize = 31;
#[derive(Error, Debug)]
enum ServerError {
@ -65,77 +67,76 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Config loading
//
let cli_args = CliArgs::parse();
let port = cli_args.port;
let memory_max_size = cli_args
.memory_quota
.map(NonZeroU64::get)
.unwrap_or_default();
let disk_quota = cli_args.disk_quota;
let cache_path = cli_args.cache_path.clone();
let low_mem_mode = cli_args.low_memory;
let use_lfu = cli_args.unstable_options.contains(&UnstableOptions::UseLfu);
let disable_tls = cli_args
let config = match config::load_config() {
Ok(c) => c,
Err(e) => {
eprintln!("{}", e);
return Err(Box::new(e) as Box<_>);
}
};
let memory_quota = config.memory_quota;
let disk_quota = config.disk_quota;
let cache_type = config.cache_type;
let cache_path = config.cache_path.clone();
let disable_tls = config
.unstable_options
.contains(&UnstableOptions::DisableTls);
OFFLINE_MODE.store(
cli_args
.unstable_options
.contains(&UnstableOptions::OfflineMode),
Ordering::Release,
);
let bind_address = config.bind_address;
let redis_url = config.redis_url.clone();
//
// Logging and warnings
//
let log_level = match (cli_args.quiet, cli_args.verbose) {
(n, _) if n > 2 => LevelFilter::Off,
(2, _) => LevelFilter::Error,
(1, _) => LevelFilter::Warn,
(0, 0) => LevelFilter::Info,
(_, 1) => LevelFilter::Debug,
(_, n) if n > 1 => LevelFilter::Trace,
// compiler can't figure it out
_ => unsafe { unreachable_unchecked() },
};
tracing_subscriber::fmt()
.with_max_level(config.log_level)
.init();
SimpleLogger::new().with_level(log_level).init()?;
if let Err(e) = print_preamble_and_warnings(&cli_args) {
if let Err(e) = print_preamble_and_warnings(&config) {
error!("{}", e);
return Err(e);
}
let client_secret = if let Ok(v) = env::var("CLIENT_SECRET") {
v
} else {
error!("Client secret not found in ENV. Please set CLIENT_SECRET.");
process::exit(1);
};
let client_secret_1 = client_secret.clone();
debug!("{:?}", &config);
if cli_args.ephemeral_disk_encryption {
let client_secret = config.client_secret.clone();
let client_secret_1 = config.client_secret.clone();
if config.ephemeral_disk_encryption {
info!("Running with at-rest encryption!");
ENCRYPTION_KEY.set(gen_key()).unwrap();
ENCRYPTION_KEY
.set(*Key::from_slice(gen_key().as_ref()))
.unwrap();
}
metrics::init();
if config.enable_metrics {
metrics::init();
}
if let Some(key) = config.geoip_license_key.clone() {
if let Err(e) = metrics::load_geo_ip_data(key).await {
error!("Failed to initialize geo ip db: {}", e);
}
}
// HTTP Server init
// Try bind to provided port first
let port_reservation = std::net::TcpListener::bind(bind_address);
if let Err(e) = port_reservation {
error!("Failed to bind to port!");
return Err(e.into());
};
let server = if OFFLINE_MODE.load(Ordering::Acquire) {
ServerState::init_offline()
} else {
ServerState::init(&client_secret, &cli_args).await?
ServerState::init(&client_secret, &config).await?
};
let data_0 = Arc::new(RwLockServerState(RwLock::new(server)));
let data_1 = Arc::clone(&data_0);
// What's nice is that Rustls only supports TLS 1.2 and 1.3.
let mut tls_config = ServerConfig::new(NoClientAuth::new());
tls_config.cert_resolver = Arc::new(DynamicServerCert);
//
// At this point, the server is ready to start, and starts the necessary
// threads.
@ -155,7 +156,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
send_stop(&client_secret).await;
} else {
warn!("Got second Ctrl-C, forcefully exiting");
system.stop()
system.stop();
}
});
}
@ -171,18 +172,25 @@ async fn main() -> Result<(), Box<dyn Error>> {
loop {
interval.tick().await;
debug!("Sending ping!");
ping::update_server_state(&client_secret_1, &cli_args, &mut data).await;
ping::update_server_state(&client_secret_1, &config, &mut data).await;
}
});
}
let cache = DiskCache::new(disk_quota, cache_path.clone()).await;
let cache: Arc<dyn Cache> = if low_mem_mode {
cache
} else if use_lfu {
MemoryCache::<Lfu, _>::new(cache, memory_max_size).await
} else {
MemoryCache::<Lru, _>::new(cache, memory_max_size).await
let memory_max_size = memory_quota.into();
let cache = DiskCache::new(disk_quota.into(), cache_path.clone()).await;
let cache: Arc<dyn Cache> = match cache_type {
CacheType::OnDisk => cache,
CacheType::Lru => MemoryCache::<Lfu, _>::new(cache, memory_max_size),
CacheType::Lfu => MemoryCache::<Lru, _>::new(cache, memory_max_size),
CacheType::Redis => {
let url = redis_url.unwrap_or_else(|| {
url::Url::parse("redis://127.0.0.1/").expect("default redis url to be parsable")
});
info!("Trying to connect to redis instance at {}", url);
let mem_cache = RedisClient::open(url)?;
Arc::new(MemoryCache::new_with_cache(cache, mem_cache))
}
};
let cache_0 = Arc::clone(&cache);
@ -190,6 +198,23 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Start HTTPS server
let server = HttpServer::new(move || {
App::new()
.wrap_fn(|req, srv| {
if let Some(reader) = GEOIP_DATABASE.get() {
let maybe_country = req
.connection_info()
.realip_remote_addr()
.map(SocketAddr::from_str)
.and_then(Result::ok)
.as_ref()
.map(SocketAddr::ip)
.map(|ip| reader.lookup::<geoip2::Country>(ip))
.and_then(Result::ok);
record_country_visit(maybe_country);
}
srv.call(req)
})
.service(routes::index)
.service(routes::token_data)
.service(routes::token_data_saver)
@ -208,18 +233,25 @@ async fn main() -> Result<(), Box<dyn Error>> {
})
.shutdown_timeout(60);
// drop port reservation, might have a TOCTOU but it's not a big deal; this
// is just a best effort.
std::mem::drop(port_reservation);
if disable_tls {
server.bind(format!("0.0.0.0:{}", port))?.run().await?;
server.bind(bind_address)?.run().await?;
} else {
server
.bind_rustls(format!("0.0.0.0:{}", port), tls_config)?
.run()
.await?;
// Rustls only supports TLS 1.2 and 1.3.
let tls_config = ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(NoClientAuth::new())
.with_cert_resolver(Arc::new(DynamicServerCert));
server.bind_rustls(bind_address, tls_config)?.run().await?;
}
// Waiting for us to finish sending stop message
while running.load(Ordering::SeqCst) {
std::thread::sleep(Duration::from_millis(250));
tokio::time::sleep(Duration::from_millis(250)).await;
}
Ok(())
@ -230,6 +262,7 @@ enum InvalidCombination {
MissingUnstableOption(&'static str, UnstableOptions),
}
#[cfg(not(tarpaulin_include))]
impl Display for InvalidCombination {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@ -246,32 +279,38 @@ impl Display for InvalidCombination {
impl Error for InvalidCombination {}
fn print_preamble_and_warnings(args: &CliArgs) -> Result<(), Box<dyn Error>> {
println!(concat!(
env!("CARGO_PKG_NAME"),
" ",
env!("CARGO_PKG_VERSION"),
" (",
env!("VERGEN_GIT_SHA_SHORT"),
")",
" Copyright (C) 2021 ",
env!("CARGO_PKG_AUTHORS"),
"\n\n",
env!("CARGO_PKG_NAME"),
" is free software: you can redistribute it and/or modify\n\
it under the terms of the GNU General Public License as published by\n\
the Free Software Foundation, either version 3 of the License, or\n\
(at your option) any later version.\n\n",
env!("CARGO_PKG_NAME"),
" is distributed in the hope that it will be useful,\n\
but WITHOUT ANY WARRANTY; without even the implied warranty of\n\
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n\
GNU General Public License for more details.\n\n\
You should have received a copy of the GNU General Public License\n\
along with ",
env!("CARGO_PKG_NAME"),
". If not, see <https://www.gnu.org/licenses/>.\n"
));
#[cfg(not(tarpaulin_include))]
#[allow(clippy::cognitive_complexity)]
fn print_preamble_and_warnings(args: &Config) -> Result<(), Box<dyn Error>> {
let build_string = option_env!("VERGEN_GIT_SHA_SHORT")
.map(|git_sha| format!(" ({})", git_sha))
.unwrap_or_default();
println!(
concat!(
env!("CARGO_PKG_NAME"),
" ",
env!("CARGO_PKG_VERSION"),
"{} Copyright (C) 2021 ",
env!("CARGO_PKG_AUTHORS"),
"\n\n",
env!("CARGO_PKG_NAME"),
" is free software: you can redistribute it and/or modify\n\
it under the terms of the GNU General Public License as published by\n\
the Free Software Foundation, either version 3 of the License, or\n\
(at your option) any later version.\n\n",
env!("CARGO_PKG_NAME"),
" is distributed in the hope that it will be useful,\n\
but WITHOUT ANY WARRANTY; without even the implied warranty of\n\
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n\
GNU General Public License for more details.\n\n\
You should have received a copy of the GNU General Public License\n\
along with ",
env!("CARGO_PKG_NAME"),
". If not, see <https://www.gnu.org/licenses/>.\n"
),
build_string
);
if !args.unstable_options.is_empty() {
warn!("Unstable options are enabled. These options should not be used in production!");
@ -288,6 +327,13 @@ fn print_preamble_and_warnings(args: &CliArgs) -> Result<(), Box<dyn Error>> {
warn!("Serving insecure traffic! You better be running this for development only.");
}
if args
.unstable_options
.contains(&UnstableOptions::DisableCertValidation)
{
error!("Cert validation disabled! You REALLY only better be debugging.");
}
if args.override_upstream.is_some()
&& !args
.unstable_options

View File

@ -1,5 +1,31 @@
use once_cell::sync::Lazy;
use prometheus::{register_int_counter, IntCounter};
#![cfg(not(tarpaulin_include))]
use std::fs::metadata;
use std::hint::unreachable_unchecked;
use std::time::SystemTime;
use chrono::Duration;
use flate2::read::GzDecoder;
use maxminddb::geoip2::Country;
use once_cell::sync::{Lazy, OnceCell};
use prometheus::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
use tar::Archive;
use thiserror::Error;
use tracing::{debug, field::debug, info, warn};
use crate::client::HTTP_CLIENT;
use crate::config::ClientSecret;
pub static GEOIP_DATABASE: OnceCell<maxminddb::Reader<Vec<u8>>> = OnceCell::new();
static COUNTRY_VISIT_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"country_visits_total",
"The number of visits from a country",
&["country"]
)
.unwrap()
});
macro_rules! init_counters {
($(($counter:ident, $ty:ty, $name:literal, $desc:literal),)*) => {
@ -11,7 +37,11 @@ macro_rules! init_counters {
#[allow(clippy::shadow_unrelated)]
pub fn init() {
// These need to be called at least once, otherwise the macro never
// called and thus the metrics don't get logged
$(let _a = $counter.get();)*
init_other();
}
};
}
@ -20,13 +50,13 @@ init_counters!(
(
CACHE_HIT_COUNTER,
IntCounter,
"cache_hit",
"cache_hit_total",
"The number of cache hits."
),
(
CACHE_MISS_COUNTER,
IntCounter,
"cache_miss",
"cache_miss_total",
"The number of cache misses."
),
(
@ -38,19 +68,118 @@ init_counters!(
(
REQUESTS_DATA_COUNTER,
IntCounter,
"requests_data",
"requests_data_total",
"The number of requests served from the /data endpoint."
),
(
REQUESTS_DATA_SAVER_COUNTER,
IntCounter,
"requests_data_saver",
"requests_data_saver_total",
"The number of requests served from the /data-saver endpoint."
),
(
REQUESTS_OTHER_COUNTER,
IntCounter,
"requests_other",
"requests_other_total",
"The total number of request not served by primary endpoints."
),
);
// initialization for any other counters that aren't simple int counters
fn init_other() {
let _a = COUNTRY_VISIT_COUNTER.local();
}
#[derive(Error, Debug)]
pub enum DbLoadError {
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
MaxMindDb(#[from] maxminddb::MaxMindDBError),
}
pub async fn load_geo_ip_data(license_key: ClientSecret) -> Result<(), DbLoadError> {
const DB_PATH: &str = "./GeoLite2-Country.mmdb";
// Check date of db
let db_date_created = metadata(DB_PATH)
.ok()
.and_then(|metadata| {
if let Ok(time) = metadata.created() {
Some(time)
} else {
debug("fs didn't report birth time, fall back to last modified instead");
metadata.modified().ok()
}
})
.unwrap_or(SystemTime::UNIX_EPOCH);
let duration = if let Ok(time) = SystemTime::now().duration_since(db_date_created) {
Duration::from_std(time).expect("duration to fit")
} else {
warn!("Clock may have gone backwards?");
Duration::max_value()
};
// DB expired, fetch a new one
if duration > Duration::weeks(1) {
fetch_db(license_key).await?;
} else {
info!("Geo IP database isn't old enough, not updating.");
}
// Result literally cannot panic here, buuuuuut if it does we'll panic
GEOIP_DATABASE
.set(maxminddb::Reader::open_readfile(DB_PATH)?)
.map_err(|_| ()) // Need to map err here or can't expect
.expect("to set the geo ip db singleton");
Ok(())
}
async fn fetch_db(license_key: ClientSecret) -> Result<(), DbLoadError> {
let resp = HTTP_CLIENT
.inner()
.get(format!("https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key={}&suffix=tar.gz", license_key.as_str()))
.send()
.await?
.bytes()
.await?;
let mut decoder = Archive::new(GzDecoder::new(resp.as_ref()));
let mut decoded_paths: Vec<_> = decoder
.entries()?
.filter_map(Result::ok)
.filter_map(|mut entry| {
let path = entry.path().ok()?.to_path_buf();
let file_name = path.file_name()?;
if file_name != "GeoLite2-Country.mmdb" {
return None;
}
entry.unpack(file_name).ok()?;
Some(path)
})
.collect();
assert_eq!(decoded_paths.len(), 1);
let path = match decoded_paths.pop() {
Some(path) => path,
None => unsafe { unreachable_unchecked() },
};
debug!("Extracted {}", path.as_path().to_string_lossy());
Ok(())
}
pub fn record_country_visit(country: Option<Country>) {
let iso_code = country
.and_then(|country| country.country.and_then(|c| c.iso_code))
.unwrap_or("unknown");
COUNTRY_VISIT_COUNTER
.get_metric_with_label_values(&[iso_code])
.unwrap()
.inc();
}

View File

@ -1,65 +1,71 @@
use std::num::{NonZeroU16, NonZeroU64};
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::Ordering;
use std::{io::BufReader, sync::Arc};
use log::{debug, error, info, warn};
use rustls::internal::pemfile::{certs, rsa_private_keys};
use rustls::sign::{RSASigningKey, SigningKey};
use rustls::Certificate;
use rustls::sign::{CertifiedKey, RsaSigningKey, SigningKey};
use rustls::{Certificate, PrivateKey};
use rustls_pemfile::{certs, rsa_private_keys};
use serde::de::{MapAccess, Visitor};
use serde::{Deserialize, Serialize};
use serde_repr::Deserialize_repr;
use sodiumoxide::crypto::box_::PrecomputedKey;
use tracing::{debug, error, info, warn};
use url::Url;
use crate::config::{CliArgs, VALIDATE_TOKENS};
use crate::client::HTTP_CLIENT;
use crate::config::{ClientSecret, Config};
use crate::state::{
RwLockServerState, PREVIOUSLY_COMPROMISED, PREVIOUSLY_PAUSED, TLS_CERTS,
TLS_PREVIOUSLY_CREATED, TLS_SIGNING_KEY,
RwLockServerState, CERTIFIED_KEY, PREVIOUSLY_COMPROMISED, PREVIOUSLY_PAUSED,
TLS_PREVIOUSLY_CREATED,
};
use crate::{client_api_version, config::UnstableOptions};
use crate::units::{Bytes, BytesPerSecond, Port};
use crate::CLIENT_API_VERSION;
pub const CONTROL_CENTER_PING_URL: &str = "https://api.mangadex.network/ping";
#[derive(Serialize, Debug)]
pub struct Request<'a> {
secret: &'a str,
port: NonZeroU16,
disk_space: u64,
network_speed: NonZeroU64,
build_version: u64,
secret: &'a ClientSecret,
port: Port,
disk_space: Bytes,
network_speed: BytesPerSecond,
build_version: usize,
tls_created_at: Option<String>,
ip_address: Option<IpAddr>,
}
impl<'a> Request<'a> {
fn from_config_and_state(secret: &'a str, config: &CliArgs) -> Self {
fn from_config_and_state(secret: &'a ClientSecret, config: &Config) -> Self {
Self {
secret,
port: config.port,
disk_space: config.disk_quota,
network_speed: config.network_speed,
build_version: client_api_version!()
.parse()
.expect("to parse the build version"),
port: config
.external_address
.and_then(|v| Port::new(v.port()))
.unwrap_or(config.port),
disk_space: config.disk_quota.into(),
network_speed: config.network_speed.into(),
build_version: CLIENT_API_VERSION,
tls_created_at: TLS_PREVIOUSLY_CREATED
.get()
.map(|v| v.load().as_ref().clone()),
ip_address: config.external_address.as_ref().map(SocketAddr::ip),
}
}
}
#[allow(clippy::fallible_impl_from)]
impl<'a> From<(&'a str, &CliArgs)> for Request<'a> {
fn from((secret, config): (&'a str, &CliArgs)) -> Self {
impl<'a> From<(&'a ClientSecret, &Config)> for Request<'a> {
fn from((secret, config): (&'a ClientSecret, &Config)) -> Self {
Self {
secret,
port: config.port,
disk_space: config.disk_quota,
network_speed: config.network_speed,
build_version: client_api_version!()
.parse()
.expect("to parse the build version"),
port: config
.external_address
.and_then(|v| Port::new(v.port()))
.unwrap_or(config.port),
disk_space: config.disk_quota.into(),
network_speed: config.network_speed.into(),
build_version: CLIENT_API_VERSION,
tls_created_at: None,
ip_address: config.external_address.as_ref().map(SocketAddr::ip),
}
}
}
@ -67,7 +73,7 @@ impl<'a> From<(&'a str, &CliArgs)> for Request<'a> {
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum Response {
Ok(OkResponse),
Ok(Box<OkResponse>),
Error(ErrorResponse),
}
@ -79,8 +85,6 @@ pub struct OkResponse {
pub token_key: Option<String>,
pub compromised: bool,
pub paused: bool,
#[serde(default)]
pub force_tokens: bool,
pub tls: Option<Tls>,
}
@ -100,7 +104,7 @@ pub enum ErrorCode {
pub struct Tls {
pub created_at: String,
pub priv_key: Arc<Box<dyn SigningKey>>,
pub priv_key: Arc<RsaSigningKey>,
pub certs: Vec<Certificate>,
}
@ -133,11 +137,12 @@ impl<'de> Deserialize<'de> for Tls {
priv_key = rsa_private_keys(&mut BufReader::new(value.as_bytes()))
.ok()
.and_then(|mut v| {
v.pop().and_then(|key| RSASigningKey::new(&key).ok())
})
v.pop()
.and_then(|key| RsaSigningKey::new(&PrivateKey(key)).ok())
});
}
"certificate" => {
certificates = certs(&mut BufReader::new(value.as_bytes())).ok()
certificates = certs(&mut BufReader::new(value.as_bytes())).ok();
}
_ => (), // Ignore extra fields
}
@ -146,8 +151,8 @@ impl<'de> Deserialize<'de> for Tls {
match (created_at, priv_key, certificates) {
(Some(created_at), Some(priv_key), Some(certificates)) => Ok(Tls {
created_at,
priv_key: Arc::new(Box::new(priv_key)),
certs: certificates,
priv_key: Arc::new(priv_key),
certs: certificates.into_iter().map(Certificate).collect(),
}),
_ => Err(serde::de::Error::custom("Could not deserialize tls info")),
}
@ -158,6 +163,7 @@ impl<'de> Deserialize<'de> for Tls {
}
}
#[cfg(not(tarpaulin_include))]
impl std::fmt::Debug for Tls {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Tls")
@ -166,10 +172,19 @@ impl std::fmt::Debug for Tls {
}
}
pub async fn update_server_state(secret: &str, cli: &CliArgs, data: &mut Arc<RwLockServerState>) {
pub async fn update_server_state(
secret: &ClientSecret,
cli: &Config,
data: &mut Arc<RwLockServerState>,
) {
let req = Request::from_config_and_state(secret, cli);
let client = reqwest::Client::new();
let resp = client.post(CONTROL_CENTER_PING_URL).json(&req).send().await;
debug!("Sending ping request: {:?}", req);
let resp = HTTP_CLIENT
.inner()
.post(CONTROL_CENTER_PING_URL)
.json(&req)
.send()
.await;
match resp {
Ok(resp) => match resp.json::<Response>().await {
Ok(Response::Ok(resp)) => {
@ -184,27 +199,13 @@ pub async fn update_server_state(secret: &str, cli: &CliArgs, data: &mut Arc<RwL
}
if let Some(key) = resp.token_key {
if let Some(key) = base64::decode(&key)
base64::decode(&key)
.ok()
.and_then(|k| PrecomputedKey::from_slice(&k))
{
write_guard.precomputed_key = key;
} else {
error!("Failed to parse token key: got {}", key);
}
}
if !cli
.unstable_options
.contains(&UnstableOptions::DisableTokenValidation)
&& VALIDATE_TOKENS.load(Ordering::Acquire) != resp.force_tokens
{
if resp.force_tokens {
info!("Client received command to enforce token validity.");
} else {
info!("Client received command to no longer enforce token validity");
}
VALIDATE_TOKENS.store(resp.force_tokens, Ordering::Release);
.map_or_else(
|| error!("Failed to parse token key: got {}", key),
|key| write_guard.precomputed_key = key,
);
}
if let Some(tls) = resp.tls {
@ -212,8 +213,12 @@ pub async fn update_server_state(secret: &str, cli: &CliArgs, data: &mut Arc<RwL
.get()
.unwrap()
.swap(Arc::new(tls.created_at));
TLS_SIGNING_KEY.get().unwrap().swap(tls.priv_key);
TLS_CERTS.get().unwrap().swap(Arc::new(tls.certs));
CERTIFIED_KEY.store(Some(Arc::new(CertifiedKey {
cert: tls.certs.clone(),
key: Arc::clone(&tls.priv_key) as Arc<dyn SigningKey>,
ocsp: None,
sct_list: None,
})));
}
let previously_compromised = PREVIOUSLY_COMPROMISED.load(Ordering::Acquire);
@ -252,7 +257,7 @@ pub async fn update_server_state(secret: &str, cli: &CliArgs, data: &mut Arc<RwL
},
Err(e) => match e {
e if e.is_timeout() => {
error!("Response timed out to control server. Is MangaDex down?")
error!("Response timed out to control server. Is MangaDex down?");
}
e => warn!("Failed to send request: {}", e),
},

View File

@ -1,61 +1,41 @@
use std::hint::unreachable_unchecked;
use std::sync::atomic::Ordering;
use std::time::Duration;
use actix_web::body::BoxBody;
use actix_web::error::ErrorNotFound;
use actix_web::http::header::{
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, CACHE_CONTROL, CONTENT_LENGTH,
CONTENT_TYPE, LAST_MODIFIED, X_CONTENT_TYPE_OPTIONS,
};
use actix_web::web::Path;
use actix_web::http::header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE, LAST_MODIFIED};
use actix_web::web::{Data, Path};
use actix_web::HttpResponseBuilder;
use actix_web::{get, web::Data, HttpRequest, HttpResponse, Responder};
use actix_web::{get, HttpRequest, HttpResponse, Responder};
use base64::DecodeError;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use futures::{Stream, TryStreamExt};
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use futures::Stream;
use prometheus::{Encoder, TextEncoder};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use sodiumoxide::crypto::box_::{open_precomputed, Nonce, PrecomputedKey, NONCEBYTES};
use thiserror::Error;
use tracing::{debug, error, info, trace};
use crate::cache::{Cache, CacheKey, ImageMetadata, UpstreamError};
use crate::client_api_version;
use crate::config::{OFFLINE_MODE, SEND_SERVER_VERSION, VALIDATE_TOKENS};
use crate::client::{FetchResult, DEFAULT_HEADERS, HTTP_CLIENT};
use crate::config::{OFFLINE_MODE, VALIDATE_TOKENS};
use crate::metrics::{
CACHE_HIT_COUNTER, CACHE_MISS_COUNTER, REQUESTS_DATA_COUNTER, REQUESTS_DATA_SAVER_COUNTER,
REQUESTS_OTHER_COUNTER, REQUESTS_TOTAL_COUNTER,
};
use crate::state::RwLockServerState;
pub const BASE64_CONFIG: base64::Config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
const BASE64_CONFIG: base64::Config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
Client::builder()
.pool_idle_timeout(Duration::from_secs(180))
.https_only(true)
.http2_prior_knowledge()
.build()
.expect("Client initialization to work")
});
const SERVER_ID_STRING: &str = concat!(
env!("CARGO_CRATE_NAME"),
" ",
env!("CARGO_PKG_VERSION"),
" (",
client_api_version!(),
") - Conforming to spec revision b82043289",
);
enum ServerResponse {
pub enum ServerResponse {
TokenValidationError(TokenValidationError),
HttpResponse(HttpResponse),
}
impl Responder for ServerResponse {
type Body = BoxBody;
#[inline]
fn respond_to(self, req: &HttpRequest) -> HttpResponse {
match self {
@ -68,12 +48,12 @@ impl Responder for ServerResponse {
}
}
#[allow(clippy::unused_async)]
#[get("/")]
async fn index() -> impl Responder {
HttpResponse::Ok().body(include_str!("index.html"))
}
#[allow(clippy::future_not_send)]
#[get("/{token}/data/{chapter_hash}/{file_name}")]
async fn token_data(
state: Data<RwLockServerState>,
@ -90,7 +70,6 @@ async fn token_data(
fetch_image(state, cache, chapter_hash, file_name, false).await
}
#[allow(clippy::future_not_send)]
#[get("/{token}/data-saver/{chapter_hash}/{file_name}")]
async fn token_data_saver(
state: Data<RwLockServerState>,
@ -126,42 +105,43 @@ pub async fn default(state: Data<RwLockServerState>, req: HttpRequest) -> impl R
info!("Got unknown path, just proxying: {}", path);
let resp = match HTTP_CLIENT.get(path).send().await {
let mut resp = match HTTP_CLIENT.inner().get(path).send().await {
Ok(resp) => resp,
Err(e) => {
error!("{}", e);
return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish());
}
};
let content_type = resp.headers().get(CONTENT_TYPE);
let content_type = resp.headers_mut().remove(CONTENT_TYPE);
let mut resp_builder = HttpResponseBuilder::new(resp.status());
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = content_type {
resp_builder.insert_header((CONTENT_TYPE, content_type));
headers.insert(CONTENT_TYPE, content_type);
}
push_headers(&mut resp_builder);
// push_headers(&mut resp_builder);
ServerResponse::HttpResponse(resp_builder.body(resp.bytes().await.unwrap_or_default()))
let mut resp = resp_builder.body(resp.bytes().await.unwrap_or_default());
*resp.headers_mut() = headers;
ServerResponse::HttpResponse(resp)
}
#[allow(clippy::future_not_send)]
#[get("/metrics")]
#[allow(clippy::unused_async)]
#[get("/prometheus")]
pub async fn metrics() -> impl Responder {
let metric_families = prometheus::gather();
let mut buffer = Vec::new();
TextEncoder::new()
.encode(&metric_families, &mut buffer)
.unwrap();
String::from_utf8(buffer).unwrap()
.expect("Should never have an io error writing to a vec");
String::from_utf8(buffer).expect("Text encoder should render valid utf-8")
}
#[derive(Error, Debug)]
enum TokenValidationError {
#[derive(Error, Debug, PartialEq, Eq)]
pub enum TokenValidationError {
#[error("Failed to decode base64 token.")]
DecodeError(#[from] DecodeError),
#[error("Nonce was too short.")]
IncompleteNonce,
#[error("Invalid nonce.")]
InvalidNonce,
#[error("Decryption failed")]
DecryptionFailure,
#[error("The token format was invalid.")]
@ -173,9 +153,13 @@ enum TokenValidationError {
}
impl Responder for TokenValidationError {
type Body = BoxBody;
#[inline]
fn respond_to(self, _: &HttpRequest) -> HttpResponse {
push_headers(&mut HttpResponse::Forbidden()).finish()
let mut resp = HttpResponse::Forbidden().finish();
*resp.headers_mut() = DEFAULT_HEADERS.clone();
resp
}
}
@ -197,7 +181,11 @@ fn validate_token(
let (nonce, encrypted) = data.split_at(NONCEBYTES);
let nonce = Nonce::from_slice(nonce).ok_or(TokenValidationError::InvalidNonce)?;
let nonce = match Nonce::from_slice(nonce) {
Some(nonce) => nonce,
// We split at NONCEBYTES, so this should never happen.
None => unsafe { unreachable_unchecked() },
};
let decrypted = open_precomputed(encrypted, &nonce, precomputed_key)
.map_err(|_| TokenValidationError::DecryptionFailure)?;
@ -217,22 +205,6 @@ fn validate_token(
Ok(())
}
#[inline]
fn push_headers(builder: &mut HttpResponseBuilder) -> &mut HttpResponseBuilder {
builder
.insert_header((X_CONTENT_TYPE_OPTIONS, "nosniff"))
.insert_header((ACCESS_CONTROL_ALLOW_ORIGIN, "https://mangadex.org"))
.insert_header((ACCESS_CONTROL_EXPOSE_HEADERS, "*"))
.insert_header((CACHE_CONTROL, "public, max-age=1209600"))
.insert_header(("Timing-Allow-Origin", "https://mangadex.org"));
if SEND_SERVER_VERSION.load(Ordering::Acquire) {
builder.insert_header(("Server", SERVER_ID_STRING));
}
builder
}
#[allow(clippy::future_not_send)]
async fn fetch_image(
state: Data<RwLockServerState>,
@ -251,7 +223,7 @@ async fn fetch_image(
Some(Err(_)) => {
return ServerResponse::HttpResponse(HttpResponse::BadGateway().finish());
}
_ => (),
None => (),
}
CACHE_MISS_COUNTER.inc();
@ -263,112 +235,217 @@ async fn fetch_image(
);
}
// It's important to not get a write lock before this request, else we're
// holding the read lock until the await resolves.
let resp = if is_data_saver {
HTTP_CLIENT
.get(format!(
"{}/data-saver/{}/{}",
state.0.read().image_server,
&key.0,
&key.1
))
.send()
let url = if is_data_saver {
format!(
"{}/data-saver/{}/{}",
state.0.read().image_server,
&key.0,
&key.1,
)
} else {
HTTP_CLIENT
.get(format!(
"{}/data/{}/{}",
state.0.read().image_server,
&key.0,
&key.1
))
.send()
}
.await;
format!("{}/data/{}/{}", state.0.read().image_server, &key.0, &key.1)
};
match resp {
Ok(mut resp) => {
let content_type = resp.headers().get(CONTENT_TYPE);
let is_image = content_type
.map(|v| String::from_utf8_lossy(v.as_ref()).contains("image/"))
.unwrap_or_default();
if resp.status() != StatusCode::OK || !is_image {
warn!(
"Got non-OK or non-image response code from upstream, proxying and not caching result.",
);
let mut resp_builder = HttpResponseBuilder::new(resp.status());
if let Some(content_type) = content_type {
resp_builder.insert_header((CONTENT_TYPE, content_type));
}
push_headers(&mut resp_builder);
return ServerResponse::HttpResponse(
resp_builder.body(resp.bytes().await.unwrap_or_default()),
);
}
let (content_type, length, last_mod) = {
let headers = resp.headers_mut();
(
headers.remove(CONTENT_TYPE),
headers.remove(CONTENT_LENGTH),
headers.remove(LAST_MODIFIED),
)
};
let body = resp.bytes_stream().map_err(|e| e.into());
debug!("Inserting into cache");
let metadata = ImageMetadata::new(content_type, length, last_mod).unwrap();
let stream = {
match cache.put(key, Box::new(body), metadata).await {
Ok(stream) => stream,
Err(e) => {
warn!("Failed to insert into cache: {}", e);
return ServerResponse::HttpResponse(
HttpResponse::InternalServerError().finish(),
);
}
}
};
debug!("Done putting into cache");
construct_response(stream, &metadata)
match HTTP_CLIENT.fetch_and_cache(url, key, cache).await {
FetchResult::ServiceUnavailable => {
ServerResponse::HttpResponse(HttpResponse::ServiceUnavailable().finish())
}
Err(e) => {
error!("Failed to fetch image from server: {}", e);
ServerResponse::HttpResponse(
push_headers(&mut HttpResponse::ServiceUnavailable()).finish(),
)
FetchResult::InternalServerError => {
ServerResponse::HttpResponse(HttpResponse::InternalServerError().finish())
}
FetchResult::Data(status, headers, data) => {
let mut resp = HttpResponseBuilder::new(status);
let mut resp = resp.body(data);
*resp.headers_mut() = headers;
ServerResponse::HttpResponse(resp)
}
FetchResult::Processing => panic!("Race condition found with fetch result"),
}
}
fn construct_response(
#[inline]
pub fn construct_response(
data: impl Stream<Item = Result<Bytes, UpstreamError>> + Unpin + 'static,
metadata: &ImageMetadata,
) -> ServerResponse {
trace!("Constructing response");
let mut resp = HttpResponse::Ok();
let mut headers = DEFAULT_HEADERS.clone();
if let Some(content_type) = metadata.content_type {
resp.append_header((CONTENT_TYPE, content_type.as_ref()));
headers.insert(
CONTENT_TYPE,
HeaderValue::from_str(content_type.as_ref()).unwrap(),
);
}
if let Some(content_length) = metadata.content_length {
resp.append_header((CONTENT_LENGTH, content_length));
headers.insert(CONTENT_LENGTH, HeaderValue::from(content_length));
}
if let Some(last_modified) = metadata.last_modified {
resp.append_header((LAST_MODIFIED, last_modified.to_rfc2822()));
headers.insert(
LAST_MODIFIED,
HeaderValue::from_str(&last_modified.to_rfc2822()).unwrap(),
);
}
ServerResponse::HttpResponse(push_headers(&mut resp).streaming(data))
let mut ret = resp.streaming(data);
*ret.headers_mut() = headers;
ServerResponse::HttpResponse(ret)
}
#[cfg(test)]
mod token_validation {
use super::{BASE64_CONFIG, DecodeError, PrecomputedKey, TokenValidationError, Utc, validate_token};
use sodiumoxide::crypto::box_::precompute;
use sodiumoxide::crypto::box_::seal_precomputed;
use sodiumoxide::crypto::box_::{gen_keypair, gen_nonce, PRECOMPUTEDKEYBYTES};
#[test]
fn invalid_base64() {
let res = validate_token(
&PrecomputedKey::from_slice(&b"1".repeat(PRECOMPUTEDKEYBYTES))
.expect("valid test token"),
"a".to_string(),
"b",
);
assert_eq!(
res,
Err(TokenValidationError::DecodeError(
DecodeError::InvalidLength
))
);
}
#[test]
fn not_long_enough_for_nonce() {
let res = validate_token(
&PrecomputedKey::from_slice(&b"1".repeat(PRECOMPUTEDKEYBYTES))
.expect("valid test token"),
"aGVsbG8gaW50ZXJuZXR-Cg==".to_string(),
"b",
);
assert_eq!(res, Err(TokenValidationError::IncompleteNonce));
}
#[test]
fn invalid_precomputed_key() {
let precomputed_1 = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let precomputed_2 = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
// Seal with precomputed_2, open with precomputed_1
let data = seal_precomputed(b"hello world", &nonce, &precomputed_2);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed_1, data, "b");
assert_eq!(res, Err(TokenValidationError::DecryptionFailure));
}
#[test]
fn invalid_token_data() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let data = seal_precomputed(b"hello world", &nonce, &precomputed);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert_eq!(res, Err(TokenValidationError::InvalidToken));
}
#[test]
fn token_must_have_valid_expiration() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() - chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert_eq!(res, Err(TokenValidationError::TokenExpired));
}
#[test]
fn token_must_have_valid_chapter_hash() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() + chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "");
assert_eq!(res, Err(TokenValidationError::InvalidChapterHash));
}
#[test]
fn valid_token_returns_ok() {
let precomputed = {
let (pk, sk) = gen_keypair();
precompute(&pk, &sk)
};
let nonce = gen_nonce();
let time = Utc::now() + chrono::Duration::weeks(1);
let data = seal_precomputed(
serde_json::json!({
"expires": time.to_rfc3339(),
"hash": "b",
})
.to_string()
.as_bytes(),
&nonce,
&precomputed,
);
let data: Vec<u8> = nonce.as_ref().iter().copied().chain(data).collect();
let data = base64::encode_config(data, BASE64_CONFIG);
let res = validate_token(&precomputed, data, "b");
assert!(res.is_ok());
}
}

View File

@ -1,17 +1,19 @@
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use crate::config::{CliArgs, UnstableOptions, OFFLINE_MODE, SEND_SERVER_VERSION, VALIDATE_TOKENS};
use crate::client::HTTP_CLIENT;
use crate::config::{ClientSecret, Config, OFFLINE_MODE};
use crate::ping::{Request, Response, CONTROL_CENTER_PING_URL};
use arc_swap::ArcSwap;
use log::{error, info, warn};
use arc_swap::{ArcSwap, ArcSwapOption};
use once_cell::sync::OnceCell;
use parking_lot::RwLock;
use rustls::sign::{CertifiedKey, SigningKey};
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::{CertifiedKey, RsaSigningKey, SigningKey};
use rustls::Certificate;
use rustls::{ClientHello, ResolvesServerCert};
use sodiumoxide::crypto::box_::{PrecomputedKey, PRECOMPUTEDKEYBYTES};
use thiserror::Error;
use tracing::{error, info, warn};
use url::Url;
pub struct ServerState {
@ -25,8 +27,10 @@ pub static PREVIOUSLY_PAUSED: AtomicBool = AtomicBool::new(false);
pub static PREVIOUSLY_COMPROMISED: AtomicBool = AtomicBool::new(false);
pub static TLS_PREVIOUSLY_CREATED: OnceCell<ArcSwap<String>> = OnceCell::new();
pub static TLS_SIGNING_KEY: OnceCell<ArcSwap<Box<dyn SigningKey>>> = OnceCell::new();
pub static TLS_CERTS: OnceCell<ArcSwap<Vec<Certificate>>> = OnceCell::new();
static TLS_SIGNING_KEY: OnceCell<ArcSwap<RsaSigningKey>> = OnceCell::new();
static TLS_CERTS: OnceCell<ArcSwap<Vec<Certificate>>> = OnceCell::new();
pub static CERTIFIED_KEY: ArcSwapOption<CertifiedKey> = ArcSwapOption::const_empty();
#[derive(Error, Debug)]
pub enum ServerInitError {
@ -45,18 +49,14 @@ pub enum ServerInitError {
}
impl ServerState {
pub async fn init(secret: &str, config: &CliArgs) -> Result<Self, ServerInitError> {
let resp = reqwest::Client::new()
pub async fn init(secret: &ClientSecret, config: &Config) -> Result<Self, ServerInitError> {
let resp = HTTP_CLIENT
.inner()
.post(CONTROL_CENTER_PING_URL)
.json(&Request::from((secret, config)))
.send()
.await;
if config.enable_server_string {
warn!("Client will send Server header in responses. This is not recommended!");
SEND_SERVER_VERSION.store(true, Ordering::Release);
}
match resp {
Ok(resp) => match resp.json::<Response>().await {
Ok(Response::Ok(mut resp)) => {
@ -64,15 +64,16 @@ impl ServerState {
.token_key
.ok_or(ServerInitError::MissingTokenKey)
.and_then(|key| {
if let Some(key) = base64::decode(&key)
base64::decode(&key)
.ok()
.and_then(|k| PrecomputedKey::from_slice(&k))
{
Ok(key)
} else {
error!("Failed to parse token key: got {}", key);
Err(ServerInitError::KeyParseError(key))
}
.map_or_else(
|| {
error!("Failed to parse token key: got {}", key);
Err(ServerInitError::KeyParseError(key))
},
Ok,
)
})?;
PREVIOUSLY_COMPROMISED.store(resp.compromised, Ordering::Release);
@ -88,26 +89,19 @@ impl ServerState {
if let Some(ref override_url) = config.override_upstream {
resp.image_server = override_url.clone();
warn!("Upstream URL overridden to: {}", resp.image_server);
} else {
}
info!("This client's URL has been set to {}", resp.url);
if config
.unstable_options
.contains(&UnstableOptions::DisableTokenValidation)
{
warn!("Token validation is explicitly disabled!");
} else {
if resp.force_tokens {
info!("This client will validate tokens.");
} else {
info!("This client will not validate tokens.");
}
VALIDATE_TOKENS.store(resp.force_tokens, Ordering::Release);
}
let tls = resp.tls.unwrap();
CERTIFIED_KEY.store(Some(Arc::new(CertifiedKey {
cert: tls.certs.clone(),
key: Arc::clone(&tls.priv_key) as Arc<dyn SigningKey>,
ocsp: None,
sct_list: None,
})));
std::mem::drop(
TLS_PREVIOUSLY_CREATED.set(ArcSwap::from_pointee(tls.created_at)),
);
@ -149,9 +143,10 @@ impl ServerState {
pub fn init_offline() -> Self {
assert!(OFFLINE_MODE.load(Ordering::Acquire));
Self {
precomputed_key: PrecomputedKey::from_slice(&[41; PRECOMPUTEDKEYBYTES]).unwrap(),
image_server: Url::from_file_path("/dev/null").unwrap(),
url: Url::from_str("http://localhost").unwrap(),
precomputed_key: PrecomputedKey::from_slice(&[41; PRECOMPUTEDKEYBYTES])
.expect("expect offline config to work"),
image_server: Url::from_file_path("/dev/null").expect("expect offline config to work"),
url: Url::from_str("http://localhost").expect("expect offline config to work"),
url_overridden: false,
}
}
@ -162,14 +157,9 @@ pub struct RwLockServerState(pub RwLock<ServerState>);
pub struct DynamicServerCert;
impl ResolvesServerCert for DynamicServerCert {
fn resolve(&self, _: ClientHello) -> Option<CertifiedKey> {
fn resolve(&self, _: ClientHello) -> Option<Arc<CertifiedKey>> {
// TODO: wait for actix-web to use a new version of rustls so we can
// remove cloning the certs all the time
Some(CertifiedKey {
cert: TLS_CERTS.get().unwrap().load().as_ref().clone(),
key: TLS_SIGNING_KEY.get().unwrap().load_full(),
ocsp: None,
sct_list: None,
})
CERTIFIED_KEY.load_full()
}
}

View File

@ -1,20 +1,24 @@
use log::{info, warn};
#![cfg(not(tarpaulin_include))]
use reqwest::StatusCode;
use serde::Serialize;
use tracing::{info, warn};
const CONTROL_CENTER_STOP_URL: &str = "https://api.mangadex.network/ping";
use crate::client::HTTP_CLIENT;
use crate::config::ClientSecret;
const CONTROL_CENTER_STOP_URL: &str = "https://api.mangadex.network/stop";
#[derive(Serialize)]
struct StopRequest<'a> {
secret: &'a str,
secret: &'a ClientSecret,
}
pub async fn send_stop(secret: &str) {
let request = StopRequest { secret };
let client = reqwest::Client::new();
match client
pub async fn send_stop(secret: &ClientSecret) {
match HTTP_CLIENT
.inner()
.post(CONTROL_CENTER_STOP_URL)
.json(&request)
.json(&StopRequest { secret })
.send()
.await
{
@ -28,3 +32,17 @@ pub async fn send_stop(secret: &str) {
Err(e) => warn!("Got error while sending stop message: {}", e),
}
}
#[cfg(test)]
mod stop {
use super::CONTROL_CENTER_STOP_URL;
#[test]
fn stop_url_does_not_have_ping_in_url() {
// This looks like a dumb test, yes, but it ensures that clients don't
// get marked compromised because apparently just sending a json obj
// with just the secret is acceptable to the ping endpoint, which messes
// up non-trivial client configs.
assert!(!CONTROL_CENTER_STOP_URL.contains("ping"))
}
}

99
src/units.rs Normal file
View File

@ -0,0 +1,99 @@
use std::fmt::Display;
use std::num::{NonZeroU16, NonZeroU64, ParseIntError};
use std::str::FromStr;
use serde::{Deserialize, Serialize};
/// Wrapper type for a port number.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct Port(NonZeroU16);
impl Port {
pub const fn get(self) -> u16 {
self.0.get()
}
pub fn new(amt: u16) -> Option<Self> {
NonZeroU16::new(amt).map(Self)
}
}
impl Default for Port {
fn default() -> Self {
Self(unsafe { NonZeroU16::new_unchecked(443) })
}
}
impl FromStr for Port {
type Err = <NonZeroU16 as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
NonZeroU16::from_str(s).map(Self)
}
}
impl Display for Port {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
#[derive(Copy, Clone, Deserialize, Default, Debug, Hash, Eq, PartialEq)]
pub struct Mebibytes(usize);
impl Mebibytes {
#[cfg(test)]
pub fn new(size: usize) -> Self {
Self(size)
}
}
impl FromStr for Mebibytes {
type Err = ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.parse::<usize>().map(Self)
}
}
#[derive(Serialize, Debug)]
pub struct Bytes(pub usize);
impl Bytes {
pub const fn get(&self) -> usize {
self.0
}
}
impl From<Mebibytes> for Bytes {
fn from(mib: Mebibytes) -> Self {
Self(mib.0 << 20)
}
}
#[derive(Copy, Clone, Deserialize, Debug, Hash, Eq, PartialEq)]
pub struct KilobitsPerSecond(NonZeroU64);
impl KilobitsPerSecond {
#[cfg(test)]
pub fn new(size: u64) -> Option<Self> {
NonZeroU64::new(size).map(Self)
}
}
impl FromStr for KilobitsPerSecond {
type Err = ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.parse::<NonZeroU64>().map(Self)
}
}
#[derive(Copy, Clone, Serialize, Debug, Hash, Eq, PartialEq)]
pub struct BytesPerSecond(NonZeroU64);
impl From<KilobitsPerSecond> for BytesPerSecond {
fn from(kbps: KilobitsPerSecond) -> Self {
Self(unsafe { NonZeroU64::new_unchecked(kbps.0.get() * 125) })
}
}