Migrate to axum

This commit is contained in:
Edward Shen 2022-06-02 21:58:56 -07:00
parent 411854385c
commit ce68f4dd42
Signed by: edward
GPG key ID: 19182661E818369F
7 changed files with 497 additions and 1338 deletions

1387
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -10,14 +10,17 @@ repository = "https://github.com/edward-shen/bunbun"
exclude = ["/aux/"] exclude = ["/aux/"]
[dependencies] [dependencies]
actix-web = "3" anyhow = "1"
arc-swap = "1"
tokio = { version = "1", features = ["full"] }
axum = "0.5"
clap = { version = "3", features = ["wrap_help", "derive", "cargo"] } clap = { version = "3", features = ["wrap_help", "derive", "cargo"] }
dirs = "4" dirs = "4"
handlebars = "4" handlebars = "4"
hotwatch = "0.4" hotwatch = "0.4"
log = "0.4" log = "0.4"
percent-encoding = "2" percent-encoding = "2"
serde = "1" serde = { version = "1", features = ["derive"] }
serde_yaml = "0.8" serde_yaml = "0.8"
serde_json = "1" serde_json = "1"
simple_logger = "2" simple_logger = "2"

View file

@ -60,6 +60,32 @@ impl FromStr for Route {
} }
} }
impl From<String> for Route {
fn from(s: String) -> Self {
Self {
route_type: get_route_type(&s),
path: s,
hidden: false,
description: None,
min_args: None,
max_args: None,
}
}
}
impl From<&'static str> for Route {
fn from(s: &'static str) -> Self {
Self {
route_type: get_route_type(s),
path: s.to_string(),
hidden: false,
description: None,
min_args: None,
max_args: None,
}
}
}
/// Deserialization of the route string into the enum requires us to figure out /// Deserialization of the route string into the enum requires us to figure out
/// whether or not the string is valid to run as an executable or not. To /// whether or not the string is valid to run as an executable or not. To
/// determine this, we simply check if it exists on disk or assume that it's a /// determine this, we simply check if it exists on disk or assume that it's a
@ -147,8 +173,7 @@ impl<'de> Deserialize<'de> for Route {
{ {
return Err(de::Error::invalid_value( return Err(de::Error::invalid_value(
Unexpected::Other(&format!( Unexpected::Other(&format!(
"argument count range {} to {}", "argument count range {min_args} to {max_args}",
min_args, max_args
)), )),
&"a valid argument count range", &"a valid argument count range",
)); ));
@ -179,12 +204,12 @@ impl std::fmt::Display for Route {
route_type: RouteType::External, route_type: RouteType::External,
path, path,
.. ..
} => write!(f, "raw ({})", path), } => write!(f, "raw ({path})"),
Self { Self {
route_type: RouteType::Internal, route_type: RouteType::Internal,
path, path,
.. ..
} => write!(f, "file ({})", path), } => write!(f, "file ({path})"),
} }
} }
} }
@ -192,10 +217,10 @@ impl std::fmt::Display for Route {
/// Classifies the path depending on if the there exists a local file. /// Classifies the path depending on if the there exists a local file.
fn get_route_type(path: &str) -> RouteType { fn get_route_type(path: &str) -> RouteType {
if std::path::Path::new(path).exists() { if std::path::Path::new(path).exists() {
debug!("Parsed {} as a valid local path.", path); debug!("Parsed {path} as a valid local path.");
RouteType::Internal RouteType::Internal
} else { } else {
debug!("{} does not exist on disk, assuming web path.", path); debug!("{path} does not exist on disk, assuming web path.");
RouteType::External RouteType::External
} }
} }
@ -245,16 +270,15 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
let file = OpenOptions::new().read(true).open(location.clone()); let file = OpenOptions::new().read(true).open(location.clone());
match file { match file {
Ok(file) => { Ok(file) => {
debug!("Found file at {:?}.", location); debug!("Found file at {location:?}.");
return Ok(ConfigData { return Ok(ConfigData {
path: location.clone(), path: location.clone(),
file, file,
}); });
} }
Err(e) => debug!( Err(e) => {
"Tried to read '{:?}' but failed due to error: {}", debug!("Tried to read '{location:?}' but failed due to error: {e}",)
location, e }
),
} }
} }
@ -270,7 +294,7 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
.open(location.clone()); .open(location.clone());
match file { match file {
Ok(mut file) => { Ok(mut file) => {
info!("Creating new config file at {:?}.", location); info!("Creating new config file at {location:?}.");
file.write_all(DEFAULT_CONFIG)?; file.write_all(DEFAULT_CONFIG)?;
let file = OpenOptions::new().read(true).open(location.clone())?; let file = OpenOptions::new().read(true).open(location.clone())?;
@ -280,8 +304,7 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
}); });
} }
Err(e) => debug!( Err(e) => debug!(
"Tried to open a new file at '{:?}' but failed due to error: {}", "Tried to open a new file at '{location:?}' but failed due to error: {e}",
location, e
), ),
} }
} }
@ -329,90 +352,96 @@ pub fn read_config(
#[cfg(test)] #[cfg(test)]
mod route { mod route {
use super::*; use super::*;
use anyhow::{Context, Result};
use serde_yaml::{from_str, to_string}; use serde_yaml::{from_str, to_string};
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
#[test] #[test]
fn deserialize_relative_path() { fn deserialize_relative_path() -> Result<()> {
let tmpfile = NamedTempFile::new_in(".").unwrap(); let tmpfile = NamedTempFile::new_in(".")?;
let path = format!("{}", tmpfile.path().display()); let path = format!("{}", tmpfile.path().display());
let path = path.get(path.rfind(".").unwrap()..).unwrap(); let path = path
.get(path.rfind(".").context("While finding .")?..)
.context("While getting the path")?;
let path = std::path::Path::new(path); let path = std::path::Path::new(path);
assert!(path.is_relative()); assert!(path.is_relative());
let path = path.to_str().unwrap(); let path = path.to_str().context("While stringifying path")?;
assert_eq!( assert_eq!(from_str::<Route>(path)?, Route::from_str(path)?);
from_str::<Route>(path).unwrap(), Ok(())
Route::from_str(path).unwrap()
);
} }
#[test] #[test]
fn deserialize_absolute_path() { fn deserialize_absolute_path() -> Result<()> {
let tmpfile = NamedTempFile::new().unwrap(); let tmpfile = NamedTempFile::new()?;
let path = format!("{}", tmpfile.path().display()); let path = format!("{}", tmpfile.path().display());
assert!(tmpfile.path().is_absolute()); assert!(tmpfile.path().is_absolute());
assert_eq!( assert_eq!(from_str::<Route>(&path)?, Route::from_str(&path)?);
from_str::<Route>(&path).unwrap(),
Route::from_str(&path).unwrap() Ok(())
);
} }
#[test] #[test]
fn deserialize_http_path() { fn deserialize_http_path() -> Result<()> {
assert_eq!( assert_eq!(
from_str::<Route>("http://google.com").unwrap(), from_str::<Route>("http://google.com")?,
Route::from_str("http://google.com").unwrap() Route::from_str("http://google.com")?
); );
Ok(())
} }
#[test] #[test]
fn deserialize_https_path() { fn deserialize_https_path() -> Result<()> {
assert_eq!( assert_eq!(
from_str::<Route>("https://google.com").unwrap(), from_str::<Route>("https://google.com")?,
Route::from_str("https://google.com").unwrap() Route::from_str("https://google.com")?
); );
Ok(())
} }
#[test] #[test]
fn serialize() { fn serialize() -> Result<()> {
assert_eq!( assert_eq!(
&to_string(&Route::from_str("hello world").unwrap()).unwrap(), &to_string(&Route::from_str("hello world")?)?,
"---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n" "---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n"
); );
Ok(())
} }
} }
#[cfg(test)] #[cfg(test)]
mod read_config { mod read_config {
use super::*; use super::*;
use anyhow::Result;
#[test] #[test]
fn empty_file() { fn empty_file() -> Result<()> {
let config_file = tempfile::tempfile().unwrap(); let config_file = tempfile::tempfile()?;
assert!(matches!( assert!(matches!(
read_config(config_file, false), read_config(config_file, false),
Err(BunBunError::ZeroByteConfig) Err(BunBunError::ZeroByteConfig)
)); ));
Ok(())
} }
#[test] #[test]
fn config_too_large() { fn config_too_large() -> Result<()> {
let mut config_file = tempfile::tempfile().unwrap(); let mut config_file = tempfile::tempfile()?;
let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize; let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize;
config_file.write(&[0].repeat(size_to_write)).unwrap(); config_file.write(&[0].repeat(size_to_write))?;
match read_config(config_file, false) { match read_config(config_file, false) {
Err(BunBunError::ConfigTooLarge(size)) Err(BunBunError::ConfigTooLarge(size))
if size as usize == size_to_write => {} if size as usize == size_to_write => {}
Err(BunBunError::ConfigTooLarge(size)) => { Err(BunBunError::ConfigTooLarge(size)) => {
panic!("Mismatched size: {} != {}", size, size_to_write) panic!("Mismatched size: {} != {}", size, size_to_write)
} }
res => panic!("Wrong result, got {:#?}", res), res => panic!("Wrong result, got {res:#?}"),
} }
Ok(())
} }
#[test] #[test]
fn valid_config() { fn valid_config() -> Result<()> {
let config_file = File::open("bunbun.default.yaml").unwrap(); assert!(read_config(File::open("bunbun.default.yaml")?, false).is_ok());
assert!(read_config(config_file, false).is_ok()); Ok(())
} }
} }

View file

@ -27,9 +27,9 @@ impl fmt::Display for BunBunError {
Self::CustomProgram(msg) => write!(f, "{}", msg), Self::CustomProgram(msg) => write!(f, "{}", msg),
Self::NoValidConfigPath => write!(f, "No valid config path was found!"), Self::NoValidConfigPath => write!(f, "No valid config path was found!"),
Self::InvalidConfigPath(path, reason) => { Self::InvalidConfigPath(path, reason) => {
write!(f, "Failed to access {:?}: {}", path, reason) write!(f, "Failed to access {path:?}: {reason}")
} }
Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({} bytes)! Pass in --large-config to bypass this check.", size), Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({size} bytes)! Pass in --large-config to bypass this check."),
Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"), Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"),
Self::JsonParse(e) => e.fmt(f), Self::JsonParse(e) => e.fmt(f),
} }

View file

@ -9,16 +9,19 @@ use crate::config::{
get_config_data, load_custom_path_config, read_config, ConfigData, Route, get_config_data, load_custom_path_config, read_config, ConfigData, Route,
RouteGroup, RouteGroup,
}; };
use actix_web::{middleware::Logger, App, HttpServer}; use anyhow::Result;
use arc_swap::ArcSwap;
use axum::routing::get;
use axum::{Extension, Router};
use clap::Parser; use clap::Parser;
use error::BunBunError; use error::BunBunError;
use handlebars::{Handlebars, TemplateError}; use handlebars::Handlebars;
use hotwatch::{Event, Hotwatch}; use hotwatch::{Event, Hotwatch};
use log::{debug, error, info, trace, warn}; use log::{debug, info, trace, warn};
use simple_logger::SimpleLogger; use simple_logger::SimpleLogger;
use std::cmp::min; use std::cmp::min;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
mod cli; mod cli;
@ -39,20 +42,9 @@ pub struct State {
routes: HashMap<String, Route>, routes: HashMap<String, Route>,
} }
#[actix_web::main] #[tokio::main]
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
async fn main() { async fn main() -> Result<()> {
std::process::exit(match run().await {
Ok(_) => 0,
Err(e) => {
error!("{}", e);
1
}
})
}
#[cfg(not(tarpaulin_include))]
async fn run() -> Result<(), BunBunError> {
let opts = cli::Opts::parse(); let opts = cli::Opts::parse();
init_logger(opts.verbose, opts.quiet)?; init_logger(opts.verbose, opts.quiet)?;
@ -63,7 +55,7 @@ async fn run() -> Result<(), BunBunError> {
}?; }?;
let conf = read_config(conf_data.file.try_clone()?, opts.large_config)?; let conf = read_config(conf_data.file.try_clone()?, opts.large_config)?;
let state = Arc::from(RwLock::new(State { let state = Arc::from(ArcSwap::from_pointee(State {
public_address: conf.public_address, public_address: conf.public_address,
default_route: conf.default_route, default_route: conf.default_route,
routes: cache_routes(&conf.groups), routes: cache_routes(&conf.groups),
@ -71,26 +63,18 @@ async fn run() -> Result<(), BunBunError> {
})); }));
// Cannot be named _ or Rust will immediately drop it. // Cannot be named _ or Rust will immediately drop it.
let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config)?; let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config);
HttpServer::new(move || { let app = Router::new()
let templates = match compile_templates() { .route("/", get(routes::index))
Ok(templates) => templates, .route("/bunbunsearch.xml", get(routes::opensearch))
// This implies a template error, which should be a compile time error. If .route("/ls", get(routes::list))
// we reach here then the release is very broken. .route("/hop", get(routes::hop))
Err(e) => unreachable!("Failed to compile templates: {}", e), .layer(Extension(compile_templates()?))
}; .layer(Extension(state));
App::new()
.data(Arc::clone(&state)) axum::Server::bind(&conf.bind_address.parse()?)
.app_data(templates) .serve(app.into_make_service())
.wrap(Logger::default())
.service(routes::hop)
.service(routes::list)
.service(routes::index)
.service(routes::opensearch)
})
.bind(&conf.bind_address)?
.run()
.await?; .await?;
Ok(()) Ok(())
@ -100,10 +84,7 @@ async fn run() -> Result<(), BunBunError> {
/// in. Usually, these values are mutually exclusive, that is, if the number of /// in. Usually, these values are mutually exclusive, that is, if the number of
/// verbose flags is non-zero then the quiet flag is zero, and vice versa. /// verbose flags is non-zero then the quiet flag is zero, and vice versa.
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
fn init_logger( fn init_logger(num_verbose_flags: u8, num_quiet_flags: u8) -> Result<()> {
num_verbose_flags: u8,
num_quiet_flags: u8,
) -> Result<(), BunBunError> {
let log_level = let log_level =
match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 { match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 {
-2 => None, -2 => None,
@ -143,7 +124,7 @@ fn cache_routes(groups: &[RouteGroup]) -> HashMap<String, Route> {
/// Returns an instance with all pre-generated templates included into the /// Returns an instance with all pre-generated templates included into the
/// binary. This allows for users to have a portable binary without needed the /// binary. This allows for users to have a portable binary without needed the
/// templates at runtime. /// templates at runtime.
fn compile_templates() -> Result<Handlebars<'static>, TemplateError> { fn compile_templates() -> Result<Handlebars<'static>> {
let mut handlebars = Handlebars::new(); let mut handlebars = Handlebars::new();
handlebars.set_strict_mode(true); handlebars.set_strict_mode(true);
handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?; handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?;
@ -176,10 +157,10 @@ fn compile_templates() -> Result<Handlebars<'static>, TemplateError> {
/// watches. /// watches.
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
fn start_watch( fn start_watch(
state: Arc<RwLock<State>>, state: Arc<ArcSwap<State>>,
config_data: ConfigData, config_data: ConfigData,
large_config: bool, large_config: bool,
) -> Result<Hotwatch, BunBunError> { ) -> Result<Hotwatch> {
let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?; let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?;
let ConfigData { path, mut file } = config_data; let ConfigData { path, mut file } = config_data;
let watch_result = watch.watch(&path, move |e: Event| { let watch_result = watch.watch(&path, move |e: Event| {
@ -193,33 +174,32 @@ fn start_watch(
match e { match e {
Event::Write(_) | Event::Create(_) => { Event::Write(_) | Event::Create(_) => {
trace!("Grabbing writer lock on state..."); trace!("Grabbing writer lock on state...");
let mut state =
state.write().expect("Failed to get write lock on state");
trace!("Obtained writer lock on state!"); trace!("Obtained writer lock on state!");
match read_config( match read_config(
file.try_clone().expect("Failed to clone file handle"), file.try_clone().expect("Failed to clone file handle"),
large_config, large_config,
) { ) {
Ok(conf) => { Ok(conf) => {
state.public_address = conf.public_address; state.store(Arc::new(State {
state.default_route = conf.default_route; public_address: conf.public_address,
state.routes = cache_routes(&conf.groups); default_route: conf.default_route,
state.groups = conf.groups; routes: cache_routes(&conf.groups),
groups: conf.groups,
}));
info!("Successfully updated active state"); info!("Successfully updated active state");
} }
Err(e) => warn!("Failed to update config file: {}", e), Err(e) => warn!("Failed to update config file: {e}"),
} }
} }
_ => debug!("Saw event {:#?} but ignored it", e), _ => debug!("Saw event {e:#?} but ignored it"),
} }
}); });
match watch_result { match watch_result {
Ok(_) => info!("Watcher is now watching {:?}", &path), Ok(_) => info!("Watcher is now watching {path:?}"),
Err(e) => warn!( Err(e) => {
"Couldn't watch {:?}: {}. Changes to this file won't be seen!", warn!("Couldn't watch {path:?}: {e}. Changes to this file won't be seen!",)
&path, e }
),
} }
Ok(watch) Ok(watch)
@ -228,9 +208,10 @@ fn start_watch(
#[cfg(test)] #[cfg(test)]
mod init_logger { mod init_logger {
use super::*; use super::*;
use anyhow::Result;
#[test] #[test]
fn defaults_to_warn() -> Result<(), BunBunError> { fn defaults_to_warn() -> Result<()> {
init_logger(0, 0)?; init_logger(0, 0)?;
assert_eq!(log::max_level(), log::Level::Warn); assert_eq!(log::max_level(), log::Level::Warn);
Ok(()) Ok(())
@ -242,7 +223,7 @@ mod init_logger {
#[test] #[test]
#[ignore] #[ignore]
fn caps_to_2_when_log_level_is_lt_2() -> Result<(), BunBunError> { fn caps_to_2_when_log_level_is_lt_2() -> Result<()> {
init_logger(0, 3)?; init_logger(0, 3)?;
assert_eq!(log::max_level(), log::LevelFilter::Off); assert_eq!(log::max_level(), log::LevelFilter::Off);
Ok(()) Ok(())
@ -250,7 +231,7 @@ mod init_logger {
#[test] #[test]
#[ignore] #[ignore]
fn caps_to_3_when_log_level_is_gt_3() -> Result<(), BunBunError> { fn caps_to_3_when_log_level_is_gt_3() -> Result<()> {
init_logger(4, 0)?; init_logger(4, 0)?;
assert_eq!(log::max_level(), log::Level::Trace); assert_eq!(log::max_level(), log::Level::Trace);
Ok(()) Ok(())
@ -261,15 +242,14 @@ mod init_logger {
mod cache_routes { mod cache_routes {
use super::*; use super::*;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::str::FromStr;
fn generate_external_routes( fn generate_external_routes(
routes: &[(&str, &str)], routes: &[(&'static str, &'static str)],
) -> HashMap<String, Route> { ) -> HashMap<String, Route> {
HashMap::from_iter( HashMap::from_iter(
routes routes
.into_iter() .into_iter()
.map(|kv| (kv.0.into(), Route::from_str(kv.1).unwrap())), .map(|(key, value)| ((*key).to_owned(), Route::from(*value))),
) )
} }

View file

@ -1,8 +1,11 @@
use crate::config::{Route as ConfigRoute, RouteType}; use crate::config::{Route as ConfigRoute, RouteType};
use crate::{template_args, BunBunError, Route, State}; use crate::{template_args, BunBunError, Route, State};
use actix_web::web::{Data, Query}; use arc_swap::ArcSwap;
use actix_web::{get, http::header}; use axum::body::{boxed, Bytes, Full};
use actix_web::{HttpRequest, HttpResponse, Responder}; use axum::extract::Query;
use axum::http::{header, StatusCode};
use axum::response::{Html, IntoResponse, Response};
use axum::Extension;
use handlebars::Handlebars; use handlebars::Handlebars;
use log::{debug, error}; use log::{debug, error};
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
@ -10,7 +13,7 @@ use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::Command; use std::process::Command;
use std::sync::{Arc, RwLock}; use std::sync::Arc;
/// https://url.spec.whatwg.org/#fragment-percent-encode-set /// https://url.spec.whatwg.org/#fragment-percent-encode-set
const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
@ -24,71 +27,62 @@ const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
.add(b'#') // Interpreted as a hyperlink section target .add(b'#') // Interpreted as a hyperlink section target
.add(b'\''); .add(b'\'');
type StateData = Data<Arc<RwLock<State>>>; pub async fn index(
Extension(data): Extension<Arc<ArcSwap<State>>>,
#[get("/")] Extension(handlebars): Extension<Handlebars<'static>>,
pub async fn index(data: StateData, req: HttpRequest) -> impl Responder { ) -> impl IntoResponse {
let data = data.read().unwrap(); handlebars
HttpResponse::Ok()
.set_header(header::CONTENT_TYPE, "text/html; charset=utf-8")
.body(
req
.app_data::<Handlebars>()
.unwrap()
.render( .render(
"index", "index",
&template_args::hostname(data.public_address.clone()), &template_args::hostname(&data.load().public_address),
)
.unwrap(),
) )
.map(Html)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
#[get("/bunbunsearch.xml")] pub async fn opensearch(
pub async fn opensearch(data: StateData, req: HttpRequest) -> impl Responder { Extension(data): Extension<Arc<ArcSwap<State>>>,
let data = data.read().unwrap(); Extension(handlebars): Extension<Handlebars<'static>>,
HttpResponse::Ok() ) -> impl IntoResponse {
.header( handlebars
header::CONTENT_TYPE,
"application/opensearchdescription+xml",
)
.body(
req
.app_data::<Handlebars>()
.unwrap()
.render( .render(
"opensearch", "opensearch",
&template_args::hostname(data.public_address.clone()), &template_args::hostname(&data.load().public_address),
) )
.unwrap(), .map(|body| {
(
StatusCode::OK,
[(
header::CONTENT_TYPE,
"application/opensearchdescription+xml",
)],
body,
) )
})
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
#[get("/ls")] pub async fn list(
pub async fn list(data: StateData, req: HttpRequest) -> impl Responder { Extension(data): Extension<Arc<ArcSwap<State>>>,
let data = data.read().unwrap(); Extension(handlebars): Extension<Handlebars<'static>>,
HttpResponse::Ok() ) -> impl IntoResponse {
.set_header(header::CONTENT_TYPE, "text/html; charset=utf-8") handlebars
.body( .render("list", &data.load().groups)
req .map(Html)
.app_data::<Handlebars>() .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.unwrap()
.render("list", &data.groups)
.unwrap(),
)
} }
#[derive(Deserialize)] #[derive(Deserialize, Debug)]
pub struct SearchQuery { pub struct SearchQuery {
to: String, to: String,
} }
#[get("/hop")]
pub async fn hop( pub async fn hop(
data: StateData, Extension(data): Extension<Arc<ArcSwap<State>>>,
req: HttpRequest, Extension(handlebars): Extension<Handlebars<'static>>,
query: Query<SearchQuery>, Query(query): Query<SearchQuery>,
) -> impl Responder { ) -> impl IntoResponse {
let data = data.read().unwrap(); let data = data.load();
match resolve_hop(&query.to, &data.routes, &data.default_route) { match resolve_hop(&query.to, &data.routes, &data.default_route) {
RouteResolution::Resolved { route: path, args } => { RouteResolution::Resolved { route: path, args } => {
@ -106,29 +100,36 @@ pub async fn hop(
}; };
match resolved_template { match resolved_template {
Ok(HopAction::Redirect(path)) => HttpResponse::Found() Ok(HopAction::Redirect(path)) => Response::builder()
.header( .status(StatusCode::FOUND)
header::LOCATION, .header(header::LOCATION, &path)
req .body(boxed(Full::from(
.app_data::<Handlebars>() handlebars
.unwrap()
.render_template( .render_template(
std::str::from_utf8(path.as_bytes()).unwrap(), &path,
&template_args::query( &template_args::query(utf8_percent_encode(
utf8_percent_encode(&args, FRAGMENT_ENCODE_SET).to_string(), &args,
), FRAGMENT_ENCODE_SET,
)),
) )
.unwrap(), .unwrap(),
) ))),
.finish(), Ok(HopAction::Body(body)) => Response::builder()
Ok(HopAction::Body(body)) => HttpResponse::Ok().body(body), .status(StatusCode::OK)
.body(boxed(Full::new(Bytes::from(body)))),
Err(e) => { Err(e) => {
error!("Failed to redirect user for {}: {}", path, e); error!("Failed to redirect user for {}: {}", path, e);
HttpResponse::InternalServerError().body("Something went wrong :(\n") Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(boxed(Full::from("Something went wrong :(\n")))
} }
} }
.unwrap()
} }
RouteResolution::Unresolved => HttpResponse::NotFound().body("not found"), RouteResolution::Unresolved => Response::builder()
.status(StatusCode::NOT_FOUND)
.body(boxed(Full::from("not found\n")))
.unwrap(),
} }
} }
@ -236,6 +237,7 @@ fn resolve_path(path: PathBuf, args: &str) -> Result<HopAction, BunBunError> {
#[cfg(test)] #[cfg(test)]
mod resolve_hop { mod resolve_hop {
use super::*; use super::*;
use anyhow::Result;
use std::str::FromStr; use std::str::FromStr;
fn generate_route_result<'a>( fn generate_route_result<'a>(
@ -269,51 +271,45 @@ mod resolve_hop {
} }
#[test] #[test]
fn only_default_routes_some_default_yields_default_hop() { fn only_default_routes_some_default_yields_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert( map.insert("google".into(), Route::from_str("https://example.com")?);
"google".into(),
Route::from_str("https://example.com").unwrap(),
);
assert_eq!( assert_eq!(
resolve_hop("hello world", &map, &Some(String::from("google"))), resolve_hop("hello world", &map, &Some(String::from("google"))),
generate_route_result( generate_route_result(
&Route::from_str("https://example.com").unwrap(), &Route::from_str("https://example.com")?,
"hello world" "hello world"
), ),
); );
Ok(())
} }
#[test] #[test]
fn non_default_routes_some_default_yields_non_default_hop() { fn non_default_routes_some_default_yields_non_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert( map.insert("google".into(), Route::from_str("https://example.com")?);
"google".into(),
Route::from_str("https://example.com").unwrap(),
);
assert_eq!( assert_eq!(
resolve_hop("google hello world", &map, &Some(String::from("a"))), resolve_hop("google hello world", &map, &Some(String::from("a"))),
generate_route_result( generate_route_result(
&Route::from_str("https://example.com").unwrap(), &Route::from_str("https://example.com")?,
"hello world" "hello world"
), ),
); );
Ok(())
} }
#[test] #[test]
fn non_default_routes_no_default_yields_non_default_hop() { fn non_default_routes_no_default_yields_non_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert( map.insert("google".into(), Route::from_str("https://example.com")?);
"google".into(),
Route::from_str("https://example.com").unwrap(),
);
assert_eq!( assert_eq!(
resolve_hop("google hello world", &map, &None), resolve_hop("google hello world", &map, &None),
generate_route_result( generate_route_result(
&Route::from_str("https://example.com").unwrap(), &Route::from_str("https://example.com")?,
"hello world" "hello world"
), ),
); );
Ok(())
} }
} }
@ -371,6 +367,7 @@ mod check_route {
#[cfg(test)] #[cfg(test)]
mod resolve_path { mod resolve_path {
use super::{resolve_path, HopAction}; use super::{resolve_path, HopAction};
use anyhow::Result;
use std::env::current_dir; use std::env::current_dir;
use std::path::PathBuf; use std::path::PathBuf;
@ -387,12 +384,13 @@ mod resolve_path {
} }
#[test] #[test]
fn relative_path_returns_ok() { fn relative_path_returns_ok() -> Result<()> {
// How many ".." needed to get to / // How many ".." needed to get to /
let nest_level = current_dir().unwrap().ancestors().count() - 1; let nest_level = current_dir()?.ancestors().count() - 1;
let mut rel_path = PathBuf::from("../".repeat(nest_level)); let mut rel_path = PathBuf::from("../".repeat(nest_level));
rel_path.push("./bin/echo"); rel_path.push("./bin/echo");
assert!(resolve_path(rel_path, r#"{"body": "a"}"#).is_ok()); assert!(resolve_path(rel_path, r#"{"body": "a"}"#).is_ok());
Ok(())
} }
#[test] #[test]
@ -414,18 +412,21 @@ mod resolve_path {
} }
#[test] #[test]
fn return_body() { fn return_body() -> Result<()> {
assert_eq!( assert_eq!(
resolve_path(PathBuf::from("/bin/echo"), r#"{"body": "a"}"#).unwrap(), resolve_path(PathBuf::from("/bin/echo"), r#"{"body": "a"}"#)?,
HopAction::Body("a".to_string()) HopAction::Body("a".to_string())
); );
Ok(())
} }
#[test] #[test]
fn return_redirect() { fn return_redirect() -> Result<()> {
assert_eq!( assert_eq!(
resolve_path(PathBuf::from("/bin/echo"), r#"{"redirect": "a"}"#).unwrap(), resolve_path(PathBuf::from("/bin/echo"), r#"{"redirect": "a"}"#)?,
HopAction::Redirect("a".to_string()) HopAction::Redirect("a".to_string())
); );
Ok(())
} }
} }

View file

@ -1,17 +1,22 @@
use std::borrow::Cow;
use percent_encoding::PercentEncode;
use serde::Serialize; use serde::Serialize;
pub fn query(query: String) -> impl Serialize { pub fn query<'a>(query: PercentEncode<'a>) -> impl Serialize + 'a {
#[derive(Serialize)] #[derive(Serialize)]
struct TemplateArgs { struct TemplateArgs<'a> {
query: String, query: Cow<'a, str>,
}
TemplateArgs {
query: query.into(),
} }
TemplateArgs { query }
} }
pub fn hostname(hostname: String) -> impl Serialize { pub fn hostname<'a>(hostname: &'a str) -> impl Serialize + 'a {
#[derive(Serialize)] #[derive(Serialize)]
pub struct TemplateArgs { pub struct TemplateArgs<'a> {
pub hostname: String, pub hostname: &'a str,
} }
TemplateArgs { hostname } TemplateArgs { hostname }
} }