Migrate to axum
This commit is contained in:
parent
411854385c
commit
ce68f4dd42
7 changed files with 497 additions and 1338 deletions
1387
Cargo.lock
generated
1387
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -10,14 +10,17 @@ repository = "https://github.com/edward-shen/bunbun"
|
||||||
exclude = ["/aux/"]
|
exclude = ["/aux/"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
actix-web = "3"
|
anyhow = "1"
|
||||||
|
arc-swap = "1"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
axum = "0.5"
|
||||||
clap = { version = "3", features = ["wrap_help", "derive", "cargo"] }
|
clap = { version = "3", features = ["wrap_help", "derive", "cargo"] }
|
||||||
dirs = "4"
|
dirs = "4"
|
||||||
handlebars = "4"
|
handlebars = "4"
|
||||||
hotwatch = "0.4"
|
hotwatch = "0.4"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
percent-encoding = "2"
|
percent-encoding = "2"
|
||||||
serde = "1"
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_yaml = "0.8"
|
serde_yaml = "0.8"
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
simple_logger = "2"
|
simple_logger = "2"
|
||||||
|
|
119
src/config.rs
119
src/config.rs
|
@ -60,6 +60,32 @@ impl FromStr for Route {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<String> for Route {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self {
|
||||||
|
route_type: get_route_type(&s),
|
||||||
|
path: s,
|
||||||
|
hidden: false,
|
||||||
|
description: None,
|
||||||
|
min_args: None,
|
||||||
|
max_args: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&'static str> for Route {
|
||||||
|
fn from(s: &'static str) -> Self {
|
||||||
|
Self {
|
||||||
|
route_type: get_route_type(s),
|
||||||
|
path: s.to_string(),
|
||||||
|
hidden: false,
|
||||||
|
description: None,
|
||||||
|
min_args: None,
|
||||||
|
max_args: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Deserialization of the route string into the enum requires us to figure out
|
/// Deserialization of the route string into the enum requires us to figure out
|
||||||
/// whether or not the string is valid to run as an executable or not. To
|
/// whether or not the string is valid to run as an executable or not. To
|
||||||
/// determine this, we simply check if it exists on disk or assume that it's a
|
/// determine this, we simply check if it exists on disk or assume that it's a
|
||||||
|
@ -147,8 +173,7 @@ impl<'de> Deserialize<'de> for Route {
|
||||||
{
|
{
|
||||||
return Err(de::Error::invalid_value(
|
return Err(de::Error::invalid_value(
|
||||||
Unexpected::Other(&format!(
|
Unexpected::Other(&format!(
|
||||||
"argument count range {} to {}",
|
"argument count range {min_args} to {max_args}",
|
||||||
min_args, max_args
|
|
||||||
)),
|
)),
|
||||||
&"a valid argument count range",
|
&"a valid argument count range",
|
||||||
));
|
));
|
||||||
|
@ -179,12 +204,12 @@ impl std::fmt::Display for Route {
|
||||||
route_type: RouteType::External,
|
route_type: RouteType::External,
|
||||||
path,
|
path,
|
||||||
..
|
..
|
||||||
} => write!(f, "raw ({})", path),
|
} => write!(f, "raw ({path})"),
|
||||||
Self {
|
Self {
|
||||||
route_type: RouteType::Internal,
|
route_type: RouteType::Internal,
|
||||||
path,
|
path,
|
||||||
..
|
..
|
||||||
} => write!(f, "file ({})", path),
|
} => write!(f, "file ({path})"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -192,10 +217,10 @@ impl std::fmt::Display for Route {
|
||||||
/// Classifies the path depending on if the there exists a local file.
|
/// Classifies the path depending on if the there exists a local file.
|
||||||
fn get_route_type(path: &str) -> RouteType {
|
fn get_route_type(path: &str) -> RouteType {
|
||||||
if std::path::Path::new(path).exists() {
|
if std::path::Path::new(path).exists() {
|
||||||
debug!("Parsed {} as a valid local path.", path);
|
debug!("Parsed {path} as a valid local path.");
|
||||||
RouteType::Internal
|
RouteType::Internal
|
||||||
} else {
|
} else {
|
||||||
debug!("{} does not exist on disk, assuming web path.", path);
|
debug!("{path} does not exist on disk, assuming web path.");
|
||||||
RouteType::External
|
RouteType::External
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -245,16 +270,15 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
|
||||||
let file = OpenOptions::new().read(true).open(location.clone());
|
let file = OpenOptions::new().read(true).open(location.clone());
|
||||||
match file {
|
match file {
|
||||||
Ok(file) => {
|
Ok(file) => {
|
||||||
debug!("Found file at {:?}.", location);
|
debug!("Found file at {location:?}.");
|
||||||
return Ok(ConfigData {
|
return Ok(ConfigData {
|
||||||
path: location.clone(),
|
path: location.clone(),
|
||||||
file,
|
file,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Err(e) => debug!(
|
Err(e) => {
|
||||||
"Tried to read '{:?}' but failed due to error: {}",
|
debug!("Tried to read '{location:?}' but failed due to error: {e}",)
|
||||||
location, e
|
}
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -270,7 +294,7 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
|
||||||
.open(location.clone());
|
.open(location.clone());
|
||||||
match file {
|
match file {
|
||||||
Ok(mut file) => {
|
Ok(mut file) => {
|
||||||
info!("Creating new config file at {:?}.", location);
|
info!("Creating new config file at {location:?}.");
|
||||||
file.write_all(DEFAULT_CONFIG)?;
|
file.write_all(DEFAULT_CONFIG)?;
|
||||||
|
|
||||||
let file = OpenOptions::new().read(true).open(location.clone())?;
|
let file = OpenOptions::new().read(true).open(location.clone())?;
|
||||||
|
@ -280,8 +304,7 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Err(e) => debug!(
|
Err(e) => debug!(
|
||||||
"Tried to open a new file at '{:?}' but failed due to error: {}",
|
"Tried to open a new file at '{location:?}' but failed due to error: {e}",
|
||||||
location, e
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -329,90 +352,96 @@ pub fn read_config(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod route {
|
mod route {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use anyhow::{Context, Result};
|
||||||
use serde_yaml::{from_str, to_string};
|
use serde_yaml::{from_str, to_string};
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn deserialize_relative_path() {
|
fn deserialize_relative_path() -> Result<()> {
|
||||||
let tmpfile = NamedTempFile::new_in(".").unwrap();
|
let tmpfile = NamedTempFile::new_in(".")?;
|
||||||
let path = format!("{}", tmpfile.path().display());
|
let path = format!("{}", tmpfile.path().display());
|
||||||
let path = path.get(path.rfind(".").unwrap()..).unwrap();
|
let path = path
|
||||||
|
.get(path.rfind(".").context("While finding .")?..)
|
||||||
|
.context("While getting the path")?;
|
||||||
let path = std::path::Path::new(path);
|
let path = std::path::Path::new(path);
|
||||||
assert!(path.is_relative());
|
assert!(path.is_relative());
|
||||||
let path = path.to_str().unwrap();
|
let path = path.to_str().context("While stringifying path")?;
|
||||||
assert_eq!(
|
assert_eq!(from_str::<Route>(path)?, Route::from_str(path)?);
|
||||||
from_str::<Route>(path).unwrap(),
|
Ok(())
|
||||||
Route::from_str(path).unwrap()
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn deserialize_absolute_path() {
|
fn deserialize_absolute_path() -> Result<()> {
|
||||||
let tmpfile = NamedTempFile::new().unwrap();
|
let tmpfile = NamedTempFile::new()?;
|
||||||
let path = format!("{}", tmpfile.path().display());
|
let path = format!("{}", tmpfile.path().display());
|
||||||
assert!(tmpfile.path().is_absolute());
|
assert!(tmpfile.path().is_absolute());
|
||||||
assert_eq!(
|
assert_eq!(from_str::<Route>(&path)?, Route::from_str(&path)?);
|
||||||
from_str::<Route>(&path).unwrap(),
|
|
||||||
Route::from_str(&path).unwrap()
|
Ok(())
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn deserialize_http_path() {
|
fn deserialize_http_path() -> Result<()> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
from_str::<Route>("http://google.com").unwrap(),
|
from_str::<Route>("http://google.com")?,
|
||||||
Route::from_str("http://google.com").unwrap()
|
Route::from_str("http://google.com")?
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn deserialize_https_path() {
|
fn deserialize_https_path() -> Result<()> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
from_str::<Route>("https://google.com").unwrap(),
|
from_str::<Route>("https://google.com")?,
|
||||||
Route::from_str("https://google.com").unwrap()
|
Route::from_str("https://google.com")?
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serialize() {
|
fn serialize() -> Result<()> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
&to_string(&Route::from_str("hello world").unwrap()).unwrap(),
|
&to_string(&Route::from_str("hello world")?)?,
|
||||||
"---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n"
|
"---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n"
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod read_config {
|
mod read_config {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn empty_file() {
|
fn empty_file() -> Result<()> {
|
||||||
let config_file = tempfile::tempfile().unwrap();
|
let config_file = tempfile::tempfile()?;
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
read_config(config_file, false),
|
read_config(config_file, false),
|
||||||
Err(BunBunError::ZeroByteConfig)
|
Err(BunBunError::ZeroByteConfig)
|
||||||
));
|
));
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn config_too_large() {
|
fn config_too_large() -> Result<()> {
|
||||||
let mut config_file = tempfile::tempfile().unwrap();
|
let mut config_file = tempfile::tempfile()?;
|
||||||
let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize;
|
let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize;
|
||||||
config_file.write(&[0].repeat(size_to_write)).unwrap();
|
config_file.write(&[0].repeat(size_to_write))?;
|
||||||
match read_config(config_file, false) {
|
match read_config(config_file, false) {
|
||||||
Err(BunBunError::ConfigTooLarge(size))
|
Err(BunBunError::ConfigTooLarge(size))
|
||||||
if size as usize == size_to_write => {}
|
if size as usize == size_to_write => {}
|
||||||
Err(BunBunError::ConfigTooLarge(size)) => {
|
Err(BunBunError::ConfigTooLarge(size)) => {
|
||||||
panic!("Mismatched size: {} != {}", size, size_to_write)
|
panic!("Mismatched size: {} != {}", size, size_to_write)
|
||||||
}
|
}
|
||||||
res => panic!("Wrong result, got {:#?}", res),
|
res => panic!("Wrong result, got {res:#?}"),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn valid_config() {
|
fn valid_config() -> Result<()> {
|
||||||
let config_file = File::open("bunbun.default.yaml").unwrap();
|
assert!(read_config(File::open("bunbun.default.yaml")?, false).is_ok());
|
||||||
assert!(read_config(config_file, false).is_ok());
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,9 +27,9 @@ impl fmt::Display for BunBunError {
|
||||||
Self::CustomProgram(msg) => write!(f, "{}", msg),
|
Self::CustomProgram(msg) => write!(f, "{}", msg),
|
||||||
Self::NoValidConfigPath => write!(f, "No valid config path was found!"),
|
Self::NoValidConfigPath => write!(f, "No valid config path was found!"),
|
||||||
Self::InvalidConfigPath(path, reason) => {
|
Self::InvalidConfigPath(path, reason) => {
|
||||||
write!(f, "Failed to access {:?}: {}", path, reason)
|
write!(f, "Failed to access {path:?}: {reason}")
|
||||||
}
|
}
|
||||||
Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({} bytes)! Pass in --large-config to bypass this check.", size),
|
Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({size} bytes)! Pass in --large-config to bypass this check."),
|
||||||
Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"),
|
Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"),
|
||||||
Self::JsonParse(e) => e.fmt(f),
|
Self::JsonParse(e) => e.fmt(f),
|
||||||
}
|
}
|
||||||
|
|
108
src/main.rs
108
src/main.rs
|
@ -9,16 +9,19 @@ use crate::config::{
|
||||||
get_config_data, load_custom_path_config, read_config, ConfigData, Route,
|
get_config_data, load_custom_path_config, read_config, ConfigData, Route,
|
||||||
RouteGroup,
|
RouteGroup,
|
||||||
};
|
};
|
||||||
use actix_web::{middleware::Logger, App, HttpServer};
|
use anyhow::Result;
|
||||||
|
use arc_swap::ArcSwap;
|
||||||
|
use axum::routing::get;
|
||||||
|
use axum::{Extension, Router};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use error::BunBunError;
|
use error::BunBunError;
|
||||||
use handlebars::{Handlebars, TemplateError};
|
use handlebars::Handlebars;
|
||||||
use hotwatch::{Event, Hotwatch};
|
use hotwatch::{Event, Hotwatch};
|
||||||
use log::{debug, error, info, trace, warn};
|
use log::{debug, info, trace, warn};
|
||||||
use simple_logger::SimpleLogger;
|
use simple_logger::SimpleLogger;
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
mod cli;
|
mod cli;
|
||||||
|
@ -39,20 +42,9 @@ pub struct State {
|
||||||
routes: HashMap<String, Route>,
|
routes: HashMap<String, Route>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::main]
|
#[tokio::main]
|
||||||
#[cfg(not(tarpaulin_include))]
|
#[cfg(not(tarpaulin_include))]
|
||||||
async fn main() {
|
async fn main() -> Result<()> {
|
||||||
std::process::exit(match run().await {
|
|
||||||
Ok(_) => 0,
|
|
||||||
Err(e) => {
|
|
||||||
error!("{}", e);
|
|
||||||
1
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(tarpaulin_include))]
|
|
||||||
async fn run() -> Result<(), BunBunError> {
|
|
||||||
let opts = cli::Opts::parse();
|
let opts = cli::Opts::parse();
|
||||||
|
|
||||||
init_logger(opts.verbose, opts.quiet)?;
|
init_logger(opts.verbose, opts.quiet)?;
|
||||||
|
@ -63,7 +55,7 @@ async fn run() -> Result<(), BunBunError> {
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
let conf = read_config(conf_data.file.try_clone()?, opts.large_config)?;
|
let conf = read_config(conf_data.file.try_clone()?, opts.large_config)?;
|
||||||
let state = Arc::from(RwLock::new(State {
|
let state = Arc::from(ArcSwap::from_pointee(State {
|
||||||
public_address: conf.public_address,
|
public_address: conf.public_address,
|
||||||
default_route: conf.default_route,
|
default_route: conf.default_route,
|
||||||
routes: cache_routes(&conf.groups),
|
routes: cache_routes(&conf.groups),
|
||||||
|
@ -71,27 +63,19 @@ async fn run() -> Result<(), BunBunError> {
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Cannot be named _ or Rust will immediately drop it.
|
// Cannot be named _ or Rust will immediately drop it.
|
||||||
let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config)?;
|
let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config);
|
||||||
|
|
||||||
HttpServer::new(move || {
|
let app = Router::new()
|
||||||
let templates = match compile_templates() {
|
.route("/", get(routes::index))
|
||||||
Ok(templates) => templates,
|
.route("/bunbunsearch.xml", get(routes::opensearch))
|
||||||
// This implies a template error, which should be a compile time error. If
|
.route("/ls", get(routes::list))
|
||||||
// we reach here then the release is very broken.
|
.route("/hop", get(routes::hop))
|
||||||
Err(e) => unreachable!("Failed to compile templates: {}", e),
|
.layer(Extension(compile_templates()?))
|
||||||
};
|
.layer(Extension(state));
|
||||||
App::new()
|
|
||||||
.data(Arc::clone(&state))
|
axum::Server::bind(&conf.bind_address.parse()?)
|
||||||
.app_data(templates)
|
.serve(app.into_make_service())
|
||||||
.wrap(Logger::default())
|
.await?;
|
||||||
.service(routes::hop)
|
|
||||||
.service(routes::list)
|
|
||||||
.service(routes::index)
|
|
||||||
.service(routes::opensearch)
|
|
||||||
})
|
|
||||||
.bind(&conf.bind_address)?
|
|
||||||
.run()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -100,10 +84,7 @@ async fn run() -> Result<(), BunBunError> {
|
||||||
/// in. Usually, these values are mutually exclusive, that is, if the number of
|
/// in. Usually, these values are mutually exclusive, that is, if the number of
|
||||||
/// verbose flags is non-zero then the quiet flag is zero, and vice versa.
|
/// verbose flags is non-zero then the quiet flag is zero, and vice versa.
|
||||||
#[cfg(not(tarpaulin_include))]
|
#[cfg(not(tarpaulin_include))]
|
||||||
fn init_logger(
|
fn init_logger(num_verbose_flags: u8, num_quiet_flags: u8) -> Result<()> {
|
||||||
num_verbose_flags: u8,
|
|
||||||
num_quiet_flags: u8,
|
|
||||||
) -> Result<(), BunBunError> {
|
|
||||||
let log_level =
|
let log_level =
|
||||||
match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 {
|
match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 {
|
||||||
-2 => None,
|
-2 => None,
|
||||||
|
@ -143,7 +124,7 @@ fn cache_routes(groups: &[RouteGroup]) -> HashMap<String, Route> {
|
||||||
/// Returns an instance with all pre-generated templates included into the
|
/// Returns an instance with all pre-generated templates included into the
|
||||||
/// binary. This allows for users to have a portable binary without needed the
|
/// binary. This allows for users to have a portable binary without needed the
|
||||||
/// templates at runtime.
|
/// templates at runtime.
|
||||||
fn compile_templates() -> Result<Handlebars<'static>, TemplateError> {
|
fn compile_templates() -> Result<Handlebars<'static>> {
|
||||||
let mut handlebars = Handlebars::new();
|
let mut handlebars = Handlebars::new();
|
||||||
handlebars.set_strict_mode(true);
|
handlebars.set_strict_mode(true);
|
||||||
handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?;
|
handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?;
|
||||||
|
@ -176,10 +157,10 @@ fn compile_templates() -> Result<Handlebars<'static>, TemplateError> {
|
||||||
/// watches.
|
/// watches.
|
||||||
#[cfg(not(tarpaulin_include))]
|
#[cfg(not(tarpaulin_include))]
|
||||||
fn start_watch(
|
fn start_watch(
|
||||||
state: Arc<RwLock<State>>,
|
state: Arc<ArcSwap<State>>,
|
||||||
config_data: ConfigData,
|
config_data: ConfigData,
|
||||||
large_config: bool,
|
large_config: bool,
|
||||||
) -> Result<Hotwatch, BunBunError> {
|
) -> Result<Hotwatch> {
|
||||||
let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?;
|
let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?;
|
||||||
let ConfigData { path, mut file } = config_data;
|
let ConfigData { path, mut file } = config_data;
|
||||||
let watch_result = watch.watch(&path, move |e: Event| {
|
let watch_result = watch.watch(&path, move |e: Event| {
|
||||||
|
@ -193,33 +174,32 @@ fn start_watch(
|
||||||
match e {
|
match e {
|
||||||
Event::Write(_) | Event::Create(_) => {
|
Event::Write(_) | Event::Create(_) => {
|
||||||
trace!("Grabbing writer lock on state...");
|
trace!("Grabbing writer lock on state...");
|
||||||
let mut state =
|
|
||||||
state.write().expect("Failed to get write lock on state");
|
|
||||||
trace!("Obtained writer lock on state!");
|
trace!("Obtained writer lock on state!");
|
||||||
match read_config(
|
match read_config(
|
||||||
file.try_clone().expect("Failed to clone file handle"),
|
file.try_clone().expect("Failed to clone file handle"),
|
||||||
large_config,
|
large_config,
|
||||||
) {
|
) {
|
||||||
Ok(conf) => {
|
Ok(conf) => {
|
||||||
state.public_address = conf.public_address;
|
state.store(Arc::new(State {
|
||||||
state.default_route = conf.default_route;
|
public_address: conf.public_address,
|
||||||
state.routes = cache_routes(&conf.groups);
|
default_route: conf.default_route,
|
||||||
state.groups = conf.groups;
|
routes: cache_routes(&conf.groups),
|
||||||
|
groups: conf.groups,
|
||||||
|
}));
|
||||||
info!("Successfully updated active state");
|
info!("Successfully updated active state");
|
||||||
}
|
}
|
||||||
Err(e) => warn!("Failed to update config file: {}", e),
|
Err(e) => warn!("Failed to update config file: {e}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => debug!("Saw event {:#?} but ignored it", e),
|
_ => debug!("Saw event {e:#?} but ignored it"),
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
match watch_result {
|
match watch_result {
|
||||||
Ok(_) => info!("Watcher is now watching {:?}", &path),
|
Ok(_) => info!("Watcher is now watching {path:?}"),
|
||||||
Err(e) => warn!(
|
Err(e) => {
|
||||||
"Couldn't watch {:?}: {}. Changes to this file won't be seen!",
|
warn!("Couldn't watch {path:?}: {e}. Changes to this file won't be seen!",)
|
||||||
&path, e
|
}
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(watch)
|
Ok(watch)
|
||||||
|
@ -228,9 +208,10 @@ fn start_watch(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod init_logger {
|
mod init_logger {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn defaults_to_warn() -> Result<(), BunBunError> {
|
fn defaults_to_warn() -> Result<()> {
|
||||||
init_logger(0, 0)?;
|
init_logger(0, 0)?;
|
||||||
assert_eq!(log::max_level(), log::Level::Warn);
|
assert_eq!(log::max_level(), log::Level::Warn);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -242,7 +223,7 @@ mod init_logger {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
fn caps_to_2_when_log_level_is_lt_2() -> Result<(), BunBunError> {
|
fn caps_to_2_when_log_level_is_lt_2() -> Result<()> {
|
||||||
init_logger(0, 3)?;
|
init_logger(0, 3)?;
|
||||||
assert_eq!(log::max_level(), log::LevelFilter::Off);
|
assert_eq!(log::max_level(), log::LevelFilter::Off);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -250,7 +231,7 @@ mod init_logger {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
fn caps_to_3_when_log_level_is_gt_3() -> Result<(), BunBunError> {
|
fn caps_to_3_when_log_level_is_gt_3() -> Result<()> {
|
||||||
init_logger(4, 0)?;
|
init_logger(4, 0)?;
|
||||||
assert_eq!(log::max_level(), log::Level::Trace);
|
assert_eq!(log::max_level(), log::Level::Trace);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -261,15 +242,14 @@ mod init_logger {
|
||||||
mod cache_routes {
|
mod cache_routes {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
fn generate_external_routes(
|
fn generate_external_routes(
|
||||||
routes: &[(&str, &str)],
|
routes: &[(&'static str, &'static str)],
|
||||||
) -> HashMap<String, Route> {
|
) -> HashMap<String, Route> {
|
||||||
HashMap::from_iter(
|
HashMap::from_iter(
|
||||||
routes
|
routes
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|kv| (kv.0.into(), Route::from_str(kv.1).unwrap())),
|
.map(|(key, value)| ((*key).to_owned(), Route::from(*value))),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
191
src/routes.rs
191
src/routes.rs
|
@ -1,8 +1,11 @@
|
||||||
use crate::config::{Route as ConfigRoute, RouteType};
|
use crate::config::{Route as ConfigRoute, RouteType};
|
||||||
use crate::{template_args, BunBunError, Route, State};
|
use crate::{template_args, BunBunError, Route, State};
|
||||||
use actix_web::web::{Data, Query};
|
use arc_swap::ArcSwap;
|
||||||
use actix_web::{get, http::header};
|
use axum::body::{boxed, Bytes, Full};
|
||||||
use actix_web::{HttpRequest, HttpResponse, Responder};
|
use axum::extract::Query;
|
||||||
|
use axum::http::{header, StatusCode};
|
||||||
|
use axum::response::{Html, IntoResponse, Response};
|
||||||
|
use axum::Extension;
|
||||||
use handlebars::Handlebars;
|
use handlebars::Handlebars;
|
||||||
use log::{debug, error};
|
use log::{debug, error};
|
||||||
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
|
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
|
||||||
|
@ -10,7 +13,7 @@ use serde::Deserialize;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// https://url.spec.whatwg.org/#fragment-percent-encode-set
|
/// https://url.spec.whatwg.org/#fragment-percent-encode-set
|
||||||
const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
|
const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
|
||||||
|
@ -24,71 +27,62 @@ const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
|
||||||
.add(b'#') // Interpreted as a hyperlink section target
|
.add(b'#') // Interpreted as a hyperlink section target
|
||||||
.add(b'\'');
|
.add(b'\'');
|
||||||
|
|
||||||
type StateData = Data<Arc<RwLock<State>>>;
|
pub async fn index(
|
||||||
|
Extension(data): Extension<Arc<ArcSwap<State>>>,
|
||||||
#[get("/")]
|
Extension(handlebars): Extension<Handlebars<'static>>,
|
||||||
pub async fn index(data: StateData, req: HttpRequest) -> impl Responder {
|
) -> impl IntoResponse {
|
||||||
let data = data.read().unwrap();
|
handlebars
|
||||||
HttpResponse::Ok()
|
.render(
|
||||||
.set_header(header::CONTENT_TYPE, "text/html; charset=utf-8")
|
"index",
|
||||||
.body(
|
&template_args::hostname(&data.load().public_address),
|
||||||
req
|
|
||||||
.app_data::<Handlebars>()
|
|
||||||
.unwrap()
|
|
||||||
.render(
|
|
||||||
"index",
|
|
||||||
&template_args::hostname(data.public_address.clone()),
|
|
||||||
)
|
|
||||||
.unwrap(),
|
|
||||||
)
|
)
|
||||||
|
.map(Html)
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/bunbunsearch.xml")]
|
pub async fn opensearch(
|
||||||
pub async fn opensearch(data: StateData, req: HttpRequest) -> impl Responder {
|
Extension(data): Extension<Arc<ArcSwap<State>>>,
|
||||||
let data = data.read().unwrap();
|
Extension(handlebars): Extension<Handlebars<'static>>,
|
||||||
HttpResponse::Ok()
|
) -> impl IntoResponse {
|
||||||
.header(
|
handlebars
|
||||||
header::CONTENT_TYPE,
|
.render(
|
||||||
"application/opensearchdescription+xml",
|
"opensearch",
|
||||||
)
|
&template_args::hostname(&data.load().public_address),
|
||||||
.body(
|
|
||||||
req
|
|
||||||
.app_data::<Handlebars>()
|
|
||||||
.unwrap()
|
|
||||||
.render(
|
|
||||||
"opensearch",
|
|
||||||
&template_args::hostname(data.public_address.clone()),
|
|
||||||
)
|
|
||||||
.unwrap(),
|
|
||||||
)
|
)
|
||||||
|
.map(|body| {
|
||||||
|
(
|
||||||
|
StatusCode::OK,
|
||||||
|
[(
|
||||||
|
header::CONTENT_TYPE,
|
||||||
|
"application/opensearchdescription+xml",
|
||||||
|
)],
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/ls")]
|
pub async fn list(
|
||||||
pub async fn list(data: StateData, req: HttpRequest) -> impl Responder {
|
Extension(data): Extension<Arc<ArcSwap<State>>>,
|
||||||
let data = data.read().unwrap();
|
Extension(handlebars): Extension<Handlebars<'static>>,
|
||||||
HttpResponse::Ok()
|
) -> impl IntoResponse {
|
||||||
.set_header(header::CONTENT_TYPE, "text/html; charset=utf-8")
|
handlebars
|
||||||
.body(
|
.render("list", &data.load().groups)
|
||||||
req
|
.map(Html)
|
||||||
.app_data::<Handlebars>()
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
.unwrap()
|
|
||||||
.render("list", &data.groups)
|
|
||||||
.unwrap(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Debug)]
|
||||||
pub struct SearchQuery {
|
pub struct SearchQuery {
|
||||||
to: String,
|
to: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/hop")]
|
|
||||||
pub async fn hop(
|
pub async fn hop(
|
||||||
data: StateData,
|
Extension(data): Extension<Arc<ArcSwap<State>>>,
|
||||||
req: HttpRequest,
|
Extension(handlebars): Extension<Handlebars<'static>>,
|
||||||
query: Query<SearchQuery>,
|
Query(query): Query<SearchQuery>,
|
||||||
) -> impl Responder {
|
) -> impl IntoResponse {
|
||||||
let data = data.read().unwrap();
|
let data = data.load();
|
||||||
|
|
||||||
match resolve_hop(&query.to, &data.routes, &data.default_route) {
|
match resolve_hop(&query.to, &data.routes, &data.default_route) {
|
||||||
RouteResolution::Resolved { route: path, args } => {
|
RouteResolution::Resolved { route: path, args } => {
|
||||||
|
@ -106,29 +100,36 @@ pub async fn hop(
|
||||||
};
|
};
|
||||||
|
|
||||||
match resolved_template {
|
match resolved_template {
|
||||||
Ok(HopAction::Redirect(path)) => HttpResponse::Found()
|
Ok(HopAction::Redirect(path)) => Response::builder()
|
||||||
.header(
|
.status(StatusCode::FOUND)
|
||||||
header::LOCATION,
|
.header(header::LOCATION, &path)
|
||||||
req
|
.body(boxed(Full::from(
|
||||||
.app_data::<Handlebars>()
|
handlebars
|
||||||
.unwrap()
|
|
||||||
.render_template(
|
.render_template(
|
||||||
std::str::from_utf8(path.as_bytes()).unwrap(),
|
&path,
|
||||||
&template_args::query(
|
&template_args::query(utf8_percent_encode(
|
||||||
utf8_percent_encode(&args, FRAGMENT_ENCODE_SET).to_string(),
|
&args,
|
||||||
),
|
FRAGMENT_ENCODE_SET,
|
||||||
|
)),
|
||||||
)
|
)
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
)
|
))),
|
||||||
.finish(),
|
Ok(HopAction::Body(body)) => Response::builder()
|
||||||
Ok(HopAction::Body(body)) => HttpResponse::Ok().body(body),
|
.status(StatusCode::OK)
|
||||||
|
.body(boxed(Full::new(Bytes::from(body)))),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to redirect user for {}: {}", path, e);
|
error!("Failed to redirect user for {}: {}", path, e);
|
||||||
HttpResponse::InternalServerError().body("Something went wrong :(\n")
|
Response::builder()
|
||||||
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
|
.body(boxed(Full::from("Something went wrong :(\n")))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
.unwrap()
|
||||||
}
|
}
|
||||||
RouteResolution::Unresolved => HttpResponse::NotFound().body("not found"),
|
RouteResolution::Unresolved => Response::builder()
|
||||||
|
.status(StatusCode::NOT_FOUND)
|
||||||
|
.body(boxed(Full::from("not found\n")))
|
||||||
|
.unwrap(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,6 +237,7 @@ fn resolve_path(path: PathBuf, args: &str) -> Result<HopAction, BunBunError> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod resolve_hop {
|
mod resolve_hop {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use anyhow::Result;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
fn generate_route_result<'a>(
|
fn generate_route_result<'a>(
|
||||||
|
@ -269,51 +271,45 @@ mod resolve_hop {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn only_default_routes_some_default_yields_default_hop() {
|
fn only_default_routes_some_default_yields_default_hop() -> Result<()> {
|
||||||
let mut map: HashMap<String, Route> = HashMap::new();
|
let mut map: HashMap<String, Route> = HashMap::new();
|
||||||
map.insert(
|
map.insert("google".into(), Route::from_str("https://example.com")?);
|
||||||
"google".into(),
|
|
||||||
Route::from_str("https://example.com").unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
resolve_hop("hello world", &map, &Some(String::from("google"))),
|
resolve_hop("hello world", &map, &Some(String::from("google"))),
|
||||||
generate_route_result(
|
generate_route_result(
|
||||||
&Route::from_str("https://example.com").unwrap(),
|
&Route::from_str("https://example.com")?,
|
||||||
"hello world"
|
"hello world"
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn non_default_routes_some_default_yields_non_default_hop() {
|
fn non_default_routes_some_default_yields_non_default_hop() -> Result<()> {
|
||||||
let mut map: HashMap<String, Route> = HashMap::new();
|
let mut map: HashMap<String, Route> = HashMap::new();
|
||||||
map.insert(
|
map.insert("google".into(), Route::from_str("https://example.com")?);
|
||||||
"google".into(),
|
|
||||||
Route::from_str("https://example.com").unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
resolve_hop("google hello world", &map, &Some(String::from("a"))),
|
resolve_hop("google hello world", &map, &Some(String::from("a"))),
|
||||||
generate_route_result(
|
generate_route_result(
|
||||||
&Route::from_str("https://example.com").unwrap(),
|
&Route::from_str("https://example.com")?,
|
||||||
"hello world"
|
"hello world"
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn non_default_routes_no_default_yields_non_default_hop() {
|
fn non_default_routes_no_default_yields_non_default_hop() -> Result<()> {
|
||||||
let mut map: HashMap<String, Route> = HashMap::new();
|
let mut map: HashMap<String, Route> = HashMap::new();
|
||||||
map.insert(
|
map.insert("google".into(), Route::from_str("https://example.com")?);
|
||||||
"google".into(),
|
|
||||||
Route::from_str("https://example.com").unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
resolve_hop("google hello world", &map, &None),
|
resolve_hop("google hello world", &map, &None),
|
||||||
generate_route_result(
|
generate_route_result(
|
||||||
&Route::from_str("https://example.com").unwrap(),
|
&Route::from_str("https://example.com")?,
|
||||||
"hello world"
|
"hello world"
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -371,6 +367,7 @@ mod check_route {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod resolve_path {
|
mod resolve_path {
|
||||||
use super::{resolve_path, HopAction};
|
use super::{resolve_path, HopAction};
|
||||||
|
use anyhow::Result;
|
||||||
use std::env::current_dir;
|
use std::env::current_dir;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
@ -387,12 +384,13 @@ mod resolve_path {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn relative_path_returns_ok() {
|
fn relative_path_returns_ok() -> Result<()> {
|
||||||
// How many ".." needed to get to /
|
// How many ".." needed to get to /
|
||||||
let nest_level = current_dir().unwrap().ancestors().count() - 1;
|
let nest_level = current_dir()?.ancestors().count() - 1;
|
||||||
let mut rel_path = PathBuf::from("../".repeat(nest_level));
|
let mut rel_path = PathBuf::from("../".repeat(nest_level));
|
||||||
rel_path.push("./bin/echo");
|
rel_path.push("./bin/echo");
|
||||||
assert!(resolve_path(rel_path, r#"{"body": "a"}"#).is_ok());
|
assert!(resolve_path(rel_path, r#"{"body": "a"}"#).is_ok());
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -414,18 +412,21 @@ mod resolve_path {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn return_body() {
|
fn return_body() -> Result<()> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
resolve_path(PathBuf::from("/bin/echo"), r#"{"body": "a"}"#).unwrap(),
|
resolve_path(PathBuf::from("/bin/echo"), r#"{"body": "a"}"#)?,
|
||||||
HopAction::Body("a".to_string())
|
HopAction::Body("a".to_string())
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn return_redirect() {
|
fn return_redirect() -> Result<()> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
resolve_path(PathBuf::from("/bin/echo"), r#"{"redirect": "a"}"#).unwrap(),
|
resolve_path(PathBuf::from("/bin/echo"), r#"{"redirect": "a"}"#)?,
|
||||||
HopAction::Redirect("a".to_string())
|
HopAction::Redirect("a".to_string())
|
||||||
);
|
);
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,17 +1,22 @@
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
use percent_encoding::PercentEncode;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
pub fn query(query: String) -> impl Serialize {
|
pub fn query<'a>(query: PercentEncode<'a>) -> impl Serialize + 'a {
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct TemplateArgs {
|
struct TemplateArgs<'a> {
|
||||||
query: String,
|
query: Cow<'a, str>,
|
||||||
|
}
|
||||||
|
TemplateArgs {
|
||||||
|
query: query.into(),
|
||||||
}
|
}
|
||||||
TemplateArgs { query }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn hostname(hostname: String) -> impl Serialize {
|
pub fn hostname<'a>(hostname: &'a str) -> impl Serialize + 'a {
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub struct TemplateArgs {
|
pub struct TemplateArgs<'a> {
|
||||||
pub hostname: String,
|
pub hostname: &'a str,
|
||||||
}
|
}
|
||||||
TemplateArgs { hostname }
|
TemplateArgs { hostname }
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue