Compare commits

...

12 Commits

Author SHA1 Message Date
Edward Shen 72e672bc73
Use tracing 2022-06-02 23:04:59 -07:00
Edward Shen 3561f488c1
release config optimizations 2022-06-02 22:50:21 -07:00
Edward Shen 4055b9dee4
Remove unused tokio features 2022-06-02 22:50:11 -07:00
Edward Shen f1d7797637
Reformat 2022-06-02 22:46:10 -07:00
Edward Shen 90ff4461a6
2021 idioms 2022-06-02 22:42:19 -07:00
Edward Shen ce592985ce
Remove all unwraps 2022-06-02 22:39:35 -07:00
Edward Shen 0132d32507
Remove unwraps 2022-06-02 22:29:09 -07:00
Edward Shen dc216a80d5
Remove clone 2022-06-02 22:24:46 -07:00
Edward Shen 531a7da636
Clippy 2022-06-02 22:23:35 -07:00
Edward Shen ce68f4dd42
Migrate to axum 2022-06-02 21:58:56 -07:00
Edward Shen 411854385c
Percent-escape single quote 2022-06-02 20:12:07 -07:00
Edward Shen 49e1c8ce0c
Partial dependency update 2022-06-02 20:08:04 -07:00
9 changed files with 1379 additions and 2405 deletions

1897
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@
name = "bunbun"
version = "0.8.0"
authors = ["Edward Shen <code@eddie.sh>"]
edition = "2018"
edition = "2021"
description = "Re-implementation of bunny1 in Rust"
license = "AGPL-3.0"
readme = "README.md"
@ -10,20 +10,25 @@ repository = "https://github.com/edward-shen/bunbun"
exclude = ["/aux/"]
[dependencies]
actix-web = "3"
clap = { version = "3.0.0-beta.2", features = ["wrap_help"] }
dirs = "3"
handlebars = "3"
anyhow = "1"
arc-swap = "1"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
axum = "0.5"
clap = { version = "3", features = ["wrap_help", "derive", "cargo"] }
dirs = "4"
handlebars = "4"
hotwatch = "0.4"
log = "0.4"
percent-encoding = "2"
serde = "1"
serde = { version = "1", features = ["derive"] }
serde_yaml = "0.8"
serde_json = "1"
simple_logger = "1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
tempfile = "3"
[profile.release]
lto = true
codegen-units = 1
strip = true

View File

@ -1,3 +0,0 @@
tab_spaces = 2
use_field_init_shorthand = true
max_width = 80

View File

@ -1,15 +1,13 @@
use clap::{crate_authors, crate_version, Clap};
use clap::{crate_authors, crate_version, Parser};
use std::path::PathBuf;
use tracing_subscriber::filter::Directive;
#[derive(Clap)]
#[derive(Parser)]
#[clap(version = crate_version!(), author = crate_authors!())]
pub struct Opts {
/// Increases the log level to info, debug, and trace, respectively.
#[clap(short, long, parse(from_occurrences), conflicts_with("quiet"))]
pub verbose: u8,
/// Decreases the log level to error or no logging at all, respectively.
#[clap(short, long, parse(from_occurrences), conflicts_with("verbose"))]
pub quiet: u8,
/// Set the logging directives
#[clap(long, default_value = "info")]
pub log: Vec<Directive>,
/// Specify the location of the config file to read from. Needs read/write permissions.
#[clap(short, long)]
pub config: Option<PathBuf>,

View File

@ -1,6 +1,5 @@
use crate::BunBunError;
use dirs::{config_dir, home_dir};
use log::{debug, info, trace};
use serde::{
de::{self, Deserializer, MapAccess, Unexpected, Visitor},
Deserialize, Serialize,
@ -10,7 +9,7 @@ use std::fmt;
use std::fs::{File, OpenOptions};
use std::io::{Read, Write};
use std::path::PathBuf;
use std::str::FromStr;
use tracing::{debug, info, trace};
const CONFIG_FILENAME: &str = "bunbun.yaml";
const DEFAULT_CONFIG: &[u8] = include_bytes!("../bunbun.default.yaml");
@ -46,17 +45,29 @@ pub struct Route {
pub max_args: Option<usize>,
}
impl FromStr for Route {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self {
impl From<String> for Route {
fn from(s: String) -> Self {
Self {
route_type: get_route_type(&s),
path: s,
hidden: false,
description: None,
min_args: None,
max_args: None,
}
}
}
impl From<&'static str> for Route {
fn from(s: &'static str) -> Self {
Self {
route_type: get_route_type(s),
path: s.to_string(),
hidden: false,
description: None,
min_args: None,
max_args: None,
})
}
}
}
@ -66,7 +77,7 @@ impl FromStr for Route {
/// web path. This incurs a disk check operation, but since users shouldn't be
/// updating the config that frequently, it should be fine.
impl<'de> Deserialize<'de> for Route {
fn deserialize<D>(deserializer: D) -> Result<Route, D::Error>
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
@ -93,8 +104,7 @@ impl<'de> Deserialize<'de> for Route {
where
E: serde::de::Error,
{
// This is infallible
Ok(Self::Value::from_str(path).unwrap())
Ok(Self::Value::from(path.to_owned()))
}
fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
@ -147,8 +157,7 @@ impl<'de> Deserialize<'de> for Route {
{
return Err(de::Error::invalid_value(
Unexpected::Other(&format!(
"argument count range {} to {}",
min_args, max_args
"argument count range {min_args} to {max_args}",
)),
&"a valid argument count range",
));
@ -179,12 +188,12 @@ impl std::fmt::Display for Route {
route_type: RouteType::External,
path,
..
} => write!(f, "raw ({})", path),
} => write!(f, "raw ({path})"),
Self {
route_type: RouteType::Internal,
path,
..
} => write!(f, "file ({})", path),
} => write!(f, "file ({path})"),
}
}
}
@ -192,10 +201,10 @@ impl std::fmt::Display for Route {
/// Classifies the path depending on if the there exists a local file.
fn get_route_type(path: &str) -> RouteType {
if std::path::Path::new(path).exists() {
debug!("Parsed {} as a valid local path.", path);
debug!("Parsed {path} as a valid local path.");
RouteType::Internal
} else {
debug!("{} does not exist on disk, assuming web path.", path);
debug!("{path} does not exist on disk, assuming web path.");
RouteType::External
}
}
@ -208,7 +217,7 @@ pub enum RouteType {
Internal,
}
pub struct ConfigData {
pub struct FileData {
pub path: PathBuf,
pub file: File,
}
@ -217,19 +226,19 @@ pub struct ConfigData {
/// locations for a place to write a config file to. In order, it checks the
/// system-wide config location (`/etc/`, in Linux), followed by the config
/// folder, followed by the user's home folder.
pub fn get_config_data() -> Result<ConfigData, BunBunError> {
pub fn get_config_data() -> Result<FileData, BunBunError> {
// Locations to check, with highest priority first
let locations: Vec<_> = {
let mut folders = vec![PathBuf::from("/etc/")];
// Config folder
if let Some(folder) = config_dir() {
folders.push(folder)
folders.push(folder);
}
// Home folder
if let Some(folder) = home_dir() {
folders.push(folder)
folders.push(folder);
}
folders
@ -242,19 +251,18 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
debug!("Checking locations for config file: {:?}", &locations);
for location in &locations {
let file = OpenOptions::new().read(true).open(location.clone());
let file = OpenOptions::new().read(true).open(location);
match file {
Ok(file) => {
debug!("Found file at {:?}.", location);
return Ok(ConfigData {
debug!("Found file at {location:?}.");
return Ok(FileData {
path: location.clone(),
file,
});
}
Err(e) => debug!(
"Tried to read '{:?}' but failed due to error: {}",
location, e
),
Err(e) => {
debug!("Tried to read '{location:?}' but failed due to error: {e}");
}
}
}
@ -270,19 +278,18 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
.open(location.clone());
match file {
Ok(mut file) => {
info!("Creating new config file at {:?}.", location);
info!("Creating new config file at {location:?}.");
file.write_all(DEFAULT_CONFIG)?;
let file = OpenOptions::new().read(true).open(location.clone())?;
return Ok(ConfigData {
return Ok(FileData {
path: location,
file,
});
}
Err(e) => debug!(
"Tried to open a new file at '{:?}' but failed due to error: {}",
location, e
),
Err(e) => {
debug!("Tried to open a new file at '{location:?}' but failed due to error: {e}",)
}
}
}
@ -291,22 +298,17 @@ pub fn get_config_data() -> Result<ConfigData, BunBunError> {
/// Assumes that the user knows what they're talking about and will only try
/// to load the config at the given path.
pub fn load_custom_path_config(
path: impl Into<PathBuf>,
) -> Result<ConfigData, BunBunError> {
pub fn load_custom_file(path: impl Into<PathBuf>) -> Result<FileData, BunBunError> {
let path = path.into();
let file = OpenOptions::new()
.read(true)
.open(&path)
.map_err(|e| BunBunError::InvalidConfigPath(path.clone(), e))?;
Ok(ConfigData { file, path })
Ok(FileData { path, file })
}
pub fn read_config(
mut config_file: File,
large_config: bool,
) -> Result<Config, BunBunError> {
pub fn load_file(mut config_file: File, large_config: bool) -> Result<Config, BunBunError> {
trace!("Loading config file.");
let file_size = config_file.metadata()?.len();
@ -329,90 +331,96 @@ pub fn read_config(
#[cfg(test)]
mod route {
use super::*;
use anyhow::{Context, Result};
use serde_yaml::{from_str, to_string};
use std::path::Path;
use tempfile::NamedTempFile;
#[test]
fn deserialize_relative_path() {
let tmpfile = NamedTempFile::new_in(".").unwrap();
let path = format!("{}", tmpfile.path().display());
let path = path.get(path.rfind(".").unwrap()..).unwrap();
let path = std::path::Path::new(path);
fn deserialize_relative_path() -> Result<()> {
let tmpfile = NamedTempFile::new_in(".")?;
let path = tmpfile.path().display().to_string();
let path = path
.get(path.rfind(".").context("While finding .")?..)
.context("While getting the path")?;
let path = Path::new(path);
assert!(path.is_relative());
let path = path.to_str().unwrap();
assert_eq!(
from_str::<Route>(path).unwrap(),
Route::from_str(path).unwrap()
);
let path = path.to_str().context("While stringifying path")?;
assert_eq!(from_str::<Route>(path)?, Route::from(path.to_owned()));
Ok(())
}
#[test]
fn deserialize_absolute_path() {
let tmpfile = NamedTempFile::new().unwrap();
fn deserialize_absolute_path() -> Result<()> {
let tmpfile = NamedTempFile::new()?;
let path = format!("{}", tmpfile.path().display());
assert!(tmpfile.path().is_absolute());
assert_eq!(
from_str::<Route>(&path).unwrap(),
Route::from_str(&path).unwrap()
);
assert_eq!(from_str::<Route>(&path)?, Route::from(path));
Ok(())
}
#[test]
fn deserialize_http_path() {
fn deserialize_http_path() -> Result<()> {
assert_eq!(
from_str::<Route>("http://google.com").unwrap(),
Route::from_str("http://google.com").unwrap()
from_str::<Route>("http://google.com")?,
Route::from("http://google.com")
);
Ok(())
}
#[test]
fn deserialize_https_path() {
fn deserialize_https_path() -> Result<()> {
assert_eq!(
from_str::<Route>("https://google.com").unwrap(),
Route::from_str("https://google.com").unwrap()
from_str::<Route>("https://google.com")?,
Route::from("https://google.com")
);
Ok(())
}
#[test]
fn serialize() {
fn serialize() -> Result<()> {
assert_eq!(
&to_string(&Route::from_str("hello world").unwrap()).unwrap(),
&to_string(&Route::from("hello world"))?,
"---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n"
);
Ok(())
}
}
#[cfg(test)]
mod read_config {
use super::*;
use anyhow::Result;
#[test]
fn empty_file() {
let config_file = tempfile::tempfile().unwrap();
fn empty_file() -> Result<()> {
let config_file = tempfile::tempfile()?;
assert!(matches!(
read_config(config_file, false),
load_file(config_file, false),
Err(BunBunError::ZeroByteConfig)
));
Ok(())
}
#[test]
fn config_too_large() {
let mut config_file = tempfile::tempfile().unwrap();
fn config_too_large() -> Result<()> {
let mut config_file = tempfile::tempfile()?;
let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize;
config_file.write(&[0].repeat(size_to_write)).unwrap();
match read_config(config_file, false) {
Err(BunBunError::ConfigTooLarge(size))
if size as usize == size_to_write => {}
config_file.write(&[0].repeat(size_to_write))?;
match load_file(config_file, false) {
Err(BunBunError::ConfigTooLarge(size)) if size as usize == size_to_write => {}
Err(BunBunError::ConfigTooLarge(size)) => {
panic!("Mismatched size: {} != {}", size, size_to_write)
panic!("Mismatched size: {size} != {size_to_write}")
}
res => panic!("Wrong result, got {:#?}", res),
res => panic!("Wrong result, got {res:#?}"),
}
Ok(())
}
#[test]
fn valid_config() {
let config_file = File::open("bunbun.default.yaml").unwrap();
assert!(read_config(config_file, false).is_ok());
fn valid_config() -> Result<()> {
assert!(load_file(File::open("bunbun.default.yaml")?, false).is_ok());
Ok(())
}
}

View File

@ -2,11 +2,11 @@ use std::error::Error;
use std::fmt;
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub enum BunBunError {
Io(std::io::Error),
Parse(serde_yaml::Error),
Watch(hotwatch::Error),
LoggerInit(log::SetLoggerError),
CustomProgram(String),
NoValidConfigPath,
InvalidConfigPath(std::path::PathBuf, std::io::Error),
@ -23,13 +23,12 @@ impl fmt::Display for BunBunError {
Self::Io(e) => e.fmt(f),
Self::Parse(e) => e.fmt(f),
Self::Watch(e) => e.fmt(f),
Self::LoggerInit(e) => e.fmt(f),
Self::CustomProgram(msg) => write!(f, "{}", msg),
Self::CustomProgram(msg) => msg.fmt(f),
Self::NoValidConfigPath => write!(f, "No valid config path was found!"),
Self::InvalidConfigPath(path, reason) => {
write!(f, "Failed to access {:?}: {}", path, reason)
write!(f, "Failed to access {path:?}: {reason}")
}
Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({} bytes)! Pass in --large-config to bypass this check.", size),
Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({size} bytes)! Pass in --large-config to bypass this check."),
Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"),
Self::JsonParse(e) => e.fmt(f),
}
@ -51,5 +50,4 @@ macro_rules! from_error {
from_error!(std::io::Error, Io);
from_error!(serde_yaml::Error, Parse);
from_error!(hotwatch::Error, Watch);
from_error!(log::SetLoggerError, LoggerInit);
from_error!(serde_json::Error, JsonParse);

View File

@ -1,25 +1,26 @@
#![forbid(unsafe_code)]
#![deny(missing_docs)]
#![warn(clippy::nursery, clippy::pedantic)]
//! Bunbun is a pure-Rust implementation of bunny1 that provides a customizable
//! search engine and quick-jump tool in one small binary. For information on
//! usage, please take a look at the readme.
use crate::config::{
get_config_data, load_custom_path_config, read_config, ConfigData, Route,
RouteGroup,
};
use actix_web::{middleware::Logger, App, HttpServer};
use clap::Clap;
use crate::config::{get_config_data, load_custom_file, load_file, FileData, Route, RouteGroup};
use anyhow::Result;
use arc_swap::ArcSwap;
use axum::routing::get;
use axum::{Extension, Router};
use clap::Parser;
use error::BunBunError;
use handlebars::{Handlebars, TemplateError};
use handlebars::Handlebars;
use hotwatch::{Event, Hotwatch};
use log::{debug, error, info, trace, warn};
use simple_logger::SimpleLogger;
use std::cmp::min;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, trace, warn};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
mod cli;
mod config;
@ -39,100 +40,70 @@ pub struct State {
routes: HashMap<String, Route>,
}
#[actix_web::main]
#[tokio::main]
#[cfg(not(tarpaulin_include))]
async fn main() {
std::process::exit(match run().await {
Ok(_) => 0,
Err(e) => {
error!("{}", e);
1
}
})
}
async fn main() -> Result<()> {
use tracing_subscriber::EnvFilter;
#[cfg(not(tarpaulin_include))]
async fn run() -> Result<(), BunBunError> {
let opts = cli::Opts::parse();
init_logger(opts.verbose, opts.quiet)?;
let mut env_filter = EnvFilter::from_default_env();
for directive in opts.log {
env_filter = env_filter.add_directive(directive);
}
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(env_filter)
.init();
let conf_data = match opts.config {
Some(file_name) => load_custom_path_config(file_name),
Some(file_name) => load_custom_file(file_name),
None => get_config_data(),
}?;
let conf = read_config(conf_data.file.try_clone()?, opts.large_config)?;
let state = Arc::from(RwLock::new(State {
let conf = load_file(conf_data.file.try_clone()?, opts.large_config)?;
let state = Arc::from(ArcSwap::from_pointee(State {
public_address: conf.public_address,
default_route: conf.default_route,
routes: cache_routes(&conf.groups),
routes: cache_routes(conf.groups.clone()),
groups: conf.groups,
}));
// Cannot be named _ or Rust will immediately drop it.
let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config)?;
let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config);
HttpServer::new(move || {
let templates = match compile_templates() {
Ok(templates) => templates,
// This implies a template error, which should be a compile time error. If
// we reach here then the release is very broken.
Err(e) => unreachable!("Failed to compile templates: {}", e),
};
App::new()
.data(Arc::clone(&state))
.app_data(templates)
.wrap(Logger::default())
.service(routes::hop)
.service(routes::list)
.service(routes::index)
.service(routes::opensearch)
})
.bind(&conf.bind_address)?
.run()
let app = Router::new()
.route("/", get(routes::index))
.route("/bunbunsearch.xml", get(routes::opensearch))
.route("/ls", get(routes::list))
.route("/hop", get(routes::hop))
.layer(Extension(compile_templates()?))
.layer(Extension(state));
let bind_addr = conf.bind_address.parse()?;
info!("Starting server at {bind_addr}");
axum::Server::bind(&bind_addr)
.serve(app.into_make_service())
.await?;
Ok(())
}
/// Initializes the logger based on the number of quiet and verbose flags passed
/// in. Usually, these values are mutually exclusive, that is, if the number of
/// verbose flags is non-zero then the quiet flag is zero, and vice versa.
#[cfg(not(tarpaulin_include))]
fn init_logger(
num_verbose_flags: u8,
num_quiet_flags: u8,
) -> Result<(), BunBunError> {
let log_level =
match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 {
-2 => None,
-1 => Some(log::LevelFilter::Error),
0 => Some(log::LevelFilter::Warn),
1 => Some(log::LevelFilter::Info),
2 => Some(log::LevelFilter::Debug),
3 => Some(log::LevelFilter::Trace),
_ => unreachable!(), // values are clamped to [0, 3] - [0, 2]
};
if let Some(level) = log_level {
SimpleLogger::new().with_level(level).init()?;
}
Ok(())
}
/// Generates a hashmap of routes from the data structure created by the config
/// file. This should improve runtime performance and is a better solution than
/// just iterating over the config object for every hop resolution.
fn cache_routes(groups: &[RouteGroup]) -> HashMap<String, Route> {
fn cache_routes(groups: Vec<RouteGroup>) -> HashMap<String, Route> {
let mut mapping = HashMap::new();
for group in groups {
for (kw, dest) in &group.routes {
for (kw, dest) in group.routes {
// This function isn't called often enough to not be a performance issue.
match mapping.insert(kw.clone(), dest.clone()) {
None => trace!("Inserting {} into mapping.", kw),
None => trace!("Inserting {kw} into mapping."),
Some(old_value) => {
trace!("Overriding {} route from {} to {}.", kw, old_value, dest)
trace!("Overriding {kw} route from {old_value} to {dest}.");
}
}
}
@ -143,7 +114,7 @@ fn cache_routes(groups: &[RouteGroup]) -> HashMap<String, Route> {
/// Returns an instance with all pre-generated templates included into the
/// binary. This allows for users to have a portable binary without needed the
/// templates at runtime.
fn compile_templates() -> Result<Handlebars<'static>, TemplateError> {
fn compile_templates() -> Result<Handlebars<'static>> {
let mut handlebars = Handlebars::new();
handlebars.set_strict_mode(true);
handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?;
@ -176,106 +147,68 @@ fn compile_templates() -> Result<Handlebars<'static>, TemplateError> {
/// watches.
#[cfg(not(tarpaulin_include))]
fn start_watch(
state: Arc<RwLock<State>>,
config_data: ConfigData,
state: Arc<ArcSwap<State>>,
config_data: FileData,
large_config: bool,
) -> Result<Hotwatch, BunBunError> {
) -> Result<Hotwatch> {
let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?;
let ConfigData { path, mut file } = config_data;
let FileData { path, mut file } = config_data;
let watch_result = watch.watch(&path, move |e: Event| {
if let Event::Create(ref path) = e {
file = load_custom_path_config(path)
.expect("file to exist at path")
.file;
file = load_custom_file(path).expect("file to exist at path").file;
trace!("Getting new file handler as file was recreated.");
}
match e {
Event::Write(_) | Event::Create(_) => {
trace!("Grabbing writer lock on state...");
let mut state =
state.write().expect("Failed to get write lock on state");
trace!("Obtained writer lock on state!");
match read_config(
match load_file(
file.try_clone().expect("Failed to clone file handle"),
large_config,
) {
Ok(conf) => {
state.public_address = conf.public_address;
state.default_route = conf.default_route;
state.routes = cache_routes(&conf.groups);
state.groups = conf.groups;
state.store(Arc::new(State {
public_address: conf.public_address,
default_route: conf.default_route,
routes: cache_routes(conf.groups.clone()),
groups: conf.groups,
}));
info!("Successfully updated active state");
}
Err(e) => warn!("Failed to update config file: {}", e),
Err(e) => warn!("Failed to update config file: {e}"),
}
}
_ => debug!("Saw event {:#?} but ignored it", e),
_ => debug!("Saw event {e:#?} but ignored it"),
}
});
match watch_result {
Ok(_) => info!("Watcher is now watching {:?}", &path),
Err(e) => warn!(
"Couldn't watch {:?}: {}. Changes to this file won't be seen!",
&path, e
),
Ok(_) => info!("Watcher is now watching {path:?}"),
Err(e) => {
warn!("Couldn't watch {path:?}: {e}. Changes to this file won't be seen!");
}
}
Ok(watch)
}
#[cfg(test)]
mod init_logger {
use super::*;
#[test]
fn defaults_to_warn() -> Result<(), BunBunError> {
init_logger(0, 0)?;
assert_eq!(log::max_level(), log::Level::Warn);
Ok(())
}
// The following tests work but because the log crate is global, initializing
// the logger more than once (read: testing it more than once) leads to a
// panic. These ignored tests must be manually tested.
#[test]
#[ignore]
fn caps_to_2_when_log_level_is_lt_2() -> Result<(), BunBunError> {
init_logger(0, 3)?;
assert_eq!(log::max_level(), log::LevelFilter::Off);
Ok(())
}
#[test]
#[ignore]
fn caps_to_3_when_log_level_is_gt_3() -> Result<(), BunBunError> {
init_logger(4, 0)?;
assert_eq!(log::max_level(), log::Level::Trace);
Ok(())
}
}
#[cfg(test)]
mod cache_routes {
use super::*;
use std::iter::FromIterator;
use std::str::FromStr;
fn generate_external_routes(
routes: &[(&str, &str)],
) -> HashMap<String, Route> {
fn generate_external_routes(routes: &[(&'static str, &'static str)]) -> HashMap<String, Route> {
HashMap::from_iter(
routes
.into_iter()
.map(|kv| (kv.0.into(), Route::from_str(kv.1).unwrap())),
.map(|(key, value)| ((*key).to_owned(), Route::from(*value))),
)
}
#[test]
fn empty_groups_yield_empty_routes() {
assert_eq!(cache_routes(&[]), HashMap::new());
assert_eq!(cache_routes(Vec::new()), HashMap::new());
}
#[test]
@ -295,13 +228,8 @@ mod cache_routes {
};
assert_eq!(
cache_routes(&[group1, group2]),
generate_external_routes(&[
("a", "b"),
("c", "d"),
("1", "2"),
("3", "4")
])
cache_routes(vec![group1, group2]),
generate_external_routes(&[("a", "b"), ("c", "d"), ("1", "2"), ("3", "4")])
);
}
@ -322,7 +250,7 @@ mod cache_routes {
};
assert_eq!(
cache_routes(&[group1.clone(), group2]),
cache_routes(vec![group1.clone(), group2]),
generate_external_routes(&[("a", "1"), ("c", "2")])
);
@ -334,7 +262,7 @@ mod cache_routes {
};
assert_eq!(
cache_routes(&[group1, group3]),
cache_routes(vec![group1, group3]),
generate_external_routes(&[("a", "1"), ("b", "2"), ("c", "d")])
);
}

View File

@ -1,18 +1,21 @@
use crate::config::{Route as ConfigRoute, RouteType};
use crate::{template_args, BunBunError, Route, State};
use actix_web::web::{Data, Query};
use actix_web::{get, http::header};
use actix_web::{HttpRequest, HttpResponse, Responder};
use arc_swap::ArcSwap;
use axum::body::{boxed, Bytes, Full};
use axum::extract::Query;
use axum::http::{header, StatusCode};
use axum::response::{Html, IntoResponse, Response};
use axum::Extension;
use handlebars::Handlebars;
use log::{debug, error};
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
use serde::Deserialize;
use std::collections::HashMap;
use std::path::PathBuf;
use std::path::Path;
use std::process::Command;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use tracing::{debug, error};
/// https://url.spec.whatwg.org/#fragment-percent-encode-set
// https://url.spec.whatwg.org/#fragment-percent-encode-set
const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
.add(b' ')
.add(b'"')
@ -21,73 +24,69 @@ const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
.add(b'`')
.add(b'+')
.add(b'&') // Interpreted as a GET query
.add(b'#'); // Interpreted as a hyperlink section target
.add(b'#') // Interpreted as a hyperlink section target
.add(b'\'');
type StateData = Data<Arc<RwLock<State>>>;
#[get("/")]
pub async fn index(data: StateData, req: HttpRequest) -> impl Responder {
let data = data.read().unwrap();
HttpResponse::Ok()
.set_header(header::CONTENT_TYPE, "text/html; charset=utf-8")
.body(
req
.app_data::<Handlebars>()
.unwrap()
#[allow(clippy::unused_async)]
pub async fn index(
Extension(data): Extension<Arc<ArcSwap<State>>>,
Extension(handlebars): Extension<Handlebars<'static>>,
) -> impl IntoResponse {
handlebars
.render(
"index",
&template_args::hostname(data.public_address.clone()),
)
.unwrap(),
&template_args::hostname(&data.load().public_address),
)
.map(Html)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
#[get("/bunbunsearch.xml")]
pub async fn opensearch(data: StateData, req: HttpRequest) -> impl Responder {
let data = data.read().unwrap();
HttpResponse::Ok()
.header(
header::CONTENT_TYPE,
"application/opensearchdescription+xml",
)
.body(
req
.app_data::<Handlebars>()
.unwrap()
#[allow(clippy::unused_async)]
pub async fn opensearch(
Extension(data): Extension<Arc<ArcSwap<State>>>,
Extension(handlebars): Extension<Handlebars<'static>>,
) -> impl IntoResponse {
handlebars
.render(
"opensearch",
&template_args::hostname(data.public_address.clone()),
&template_args::hostname(&data.load().public_address),
)
.unwrap(),
.map(|body| {
(
StatusCode::OK,
[(
header::CONTENT_TYPE,
"application/opensearchdescription+xml",
)],
body,
)
})
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
#[get("/ls")]
pub async fn list(data: StateData, req: HttpRequest) -> impl Responder {
let data = data.read().unwrap();
HttpResponse::Ok()
.set_header(header::CONTENT_TYPE, "text/html; charset=utf-8")
.body(
req
.app_data::<Handlebars>()
.unwrap()
.render("list", &data.groups)
.unwrap(),
)
#[allow(clippy::unused_async)]
pub async fn list(
Extension(data): Extension<Arc<ArcSwap<State>>>,
Extension(handlebars): Extension<Handlebars<'static>>,
) -> impl IntoResponse {
handlebars
.render("list", &data.load().groups)
.map(Html)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct SearchQuery {
to: String,
}
#[get("/hop")]
#[allow(clippy::unused_async)]
pub async fn hop(
data: StateData,
req: HttpRequest,
query: Query<SearchQuery>,
) -> impl Responder {
let data = data.read().unwrap();
Extension(data): Extension<Arc<ArcSwap<State>>>,
Extension(handlebars): Extension<Handlebars<'static>>,
Query(query): Query<SearchQuery>,
) -> impl IntoResponse {
let data = data.load();
match resolve_hop(&query.to, &data.routes, &data.default_route) {
RouteResolution::Resolved { route: path, args } => {
@ -96,7 +95,7 @@ pub async fn hop(
route_type: RouteType::Internal,
path,
..
} => resolve_path(PathBuf::from(path), &args),
} => resolve_path(Path::new(path), &args),
ConfigRoute {
route_type: RouteType::External,
path,
@ -105,30 +104,34 @@ pub async fn hop(
};
match resolved_template {
Ok(HopAction::Redirect(path)) => HttpResponse::Found()
.header(
header::LOCATION,
req
.app_data::<Handlebars>()
.unwrap()
Ok(HopAction::Redirect(path)) => {
let rendered = handlebars
.render_template(
std::str::from_utf8(path.as_bytes()).unwrap(),
&template_args::query(
utf8_percent_encode(&args, FRAGMENT_ENCODE_SET).to_string(),
),
&path,
&template_args::query(utf8_percent_encode(&args, FRAGMENT_ENCODE_SET)),
)
.unwrap(),
)
.finish(),
Ok(HopAction::Body(body)) => HttpResponse::Ok().body(body),
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Response::builder()
.status(StatusCode::FOUND)
.header(header::LOCATION, &path)
.body(boxed(Full::from(rendered)))
}
Ok(HopAction::Body(body)) => Response::builder()
.status(StatusCode::OK)
.body(boxed(Full::new(Bytes::from(body)))),
Err(e) => {
error!("Failed to redirect user for {}: {}", path, e);
HttpResponse::InternalServerError().body("Something went wrong :(\n")
error!("Failed to redirect user for {path}: {e}");
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(boxed(Full::from("Something went wrong :(\n")))
}
}
}
RouteResolution::Unresolved => HttpResponse::NotFound().body("not found"),
RouteResolution::Unresolved => Response::builder()
.status(StatusCode::NOT_FOUND)
.body(boxed(Full::from("not found\n"))),
}
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
#[derive(Debug, PartialEq)]
@ -167,7 +170,7 @@ fn resolve_hop<'a>(
let args = if args.is_empty() { &[] } else { &args[1..] }.join(" ");
let arg_count = arg_count - 1;
if check_route(route, arg_count) {
debug!("Resolved {} with args {}", route, args);
debug!("Resolved {route} with args {args}");
return RouteResolution::Resolved { route, args };
}
}
@ -177,7 +180,7 @@ fn resolve_hop<'a>(
if let Some(route) = routes.get(route) {
if check_route(route, arg_count) {
let args = args.join(" ");
debug!("Using default route {} with args {}", route, args);
debug!("Using default route {route} with args {args}");
return RouteResolution::Resolved { route, args };
}
}
@ -188,7 +191,7 @@ fn resolve_hop<'a>(
/// Checks if the user provided string has the correct properties required by
/// the route to be successfully matched.
fn check_route(route: &Route, arg_count: usize) -> bool {
const fn check_route(route: &Route, arg_count: usize) -> bool {
if let Some(min_args) = route.min_args {
if arg_count < min_args {
return false;
@ -215,7 +218,7 @@ enum HopAction {
/// so long as the executable was successfully executed. Returns an Error if the
/// file doesn't exist or bunbun did not have permission to read and execute the
/// file.
fn resolve_path(path: PathBuf, args: &str) -> Result<HopAction, BunBunError> {
fn resolve_path(path: &Path, args: &str) -> Result<HopAction, BunBunError> {
let output = Command::new(path.canonicalize()?)
.args(args.split(' '))
.output()?;
@ -235,12 +238,9 @@ fn resolve_path(path: PathBuf, args: &str) -> Result<HopAction, BunBunError> {
#[cfg(test)]
mod resolve_hop {
use super::*;
use std::str::FromStr;
use anyhow::Result;
fn generate_route_result<'a>(
keyword: &'a Route,
args: &str,
) -> RouteResolution<'a> {
fn generate_route_result<'a>(keyword: &'a Route, args: &str) -> RouteResolution<'a> {
RouteResolution::Resolved {
route: keyword,
args: String::from(args),
@ -268,51 +268,36 @@ mod resolve_hop {
}
#[test]
fn only_default_routes_some_default_yields_default_hop() {
fn only_default_routes_some_default_yields_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new();
map.insert(
"google".into(),
Route::from_str("https://example.com").unwrap(),
);
map.insert("google".into(), Route::from("https://example.com"));
assert_eq!(
resolve_hop("hello world", &map, &Some(String::from("google"))),
generate_route_result(
&Route::from_str("https://example.com").unwrap(),
"hello world"
),
generate_route_result(&Route::from("https://example.com"), "hello world"),
);
Ok(())
}
#[test]
fn non_default_routes_some_default_yields_non_default_hop() {
fn non_default_routes_some_default_yields_non_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new();
map.insert(
"google".into(),
Route::from_str("https://example.com").unwrap(),
);
map.insert("google".into(), Route::from("https://example.com"));
assert_eq!(
resolve_hop("google hello world", &map, &Some(String::from("a"))),
generate_route_result(
&Route::from_str("https://example.com").unwrap(),
"hello world"
),
generate_route_result(&Route::from("https://example.com"), "hello world"),
);
Ok(())
}
#[test]
fn non_default_routes_no_default_yields_non_default_hop() {
fn non_default_routes_no_default_yields_non_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new();
map.insert(
"google".into(),
Route::from_str("https://example.com").unwrap(),
);
map.insert("google".into(), Route::from("https://example.com"));
assert_eq!(
resolve_hop("google hello world", &map, &None),
generate_route_result(
&Route::from_str("https://example.com").unwrap(),
"hello world"
),
generate_route_result(&Route::from("https://example.com"), "hello world"),
);
Ok(())
}
}
@ -369,62 +354,65 @@ mod check_route {
#[cfg(test)]
mod resolve_path {
use crate::error::BunBunError;
use super::{resolve_path, HopAction};
use anyhow::Result;
use std::env::current_dir;
use std::path::PathBuf;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};
#[test]
fn invalid_path_returns_err() {
assert!(resolve_path(PathBuf::from("/bin/aaaa"), "aaaa").is_err());
assert!(resolve_path(&Path::new("/bin/aaaa"), "aaaa").is_err());
}
#[test]
fn valid_path_returns_ok() {
assert!(
resolve_path(PathBuf::from("/bin/echo"), r#"{"body": "a"}"#).is_ok()
);
assert!(resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#).is_ok());
}
#[test]
fn relative_path_returns_ok() {
fn relative_path_returns_ok() -> Result<()> {
// How many ".." needed to get to /
let nest_level = current_dir().unwrap().ancestors().count() - 1;
let nest_level = current_dir()?.ancestors().count() - 1;
let mut rel_path = PathBuf::from("../".repeat(nest_level));
rel_path.push("./bin/echo");
assert!(resolve_path(rel_path, r#"{"body": "a"}"#).is_ok());
assert!(resolve_path(&rel_path, r#"{"body": "a"}"#).is_ok());
Ok(())
}
#[test]
fn no_permissions_returns_err() {
assert!(
// Trying to run a command without permission
format!(
"{}",
resolve_path(PathBuf::from("/root/some_exec"), "").unwrap_err()
)
.contains("Permission denied")
);
let result = match resolve_path(&Path::new("/root/some_exec"), "") {
Err(BunBunError::Io(e)) => e.kind() == ErrorKind::PermissionDenied,
_ => false,
};
assert!(result);
}
#[test]
fn non_success_exit_code_yields_err() {
// cat-ing a folder always returns exit code 1
assert!(resolve_path(PathBuf::from("/bin/cat"), "/").is_err());
assert!(resolve_path(&Path::new("/bin/cat"), "/").is_err());
}
#[test]
fn return_body() {
fn return_body() -> Result<()> {
assert_eq!(
resolve_path(PathBuf::from("/bin/echo"), r#"{"body": "a"}"#).unwrap(),
resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#)?,
HopAction::Body("a".to_string())
);
Ok(())
}
#[test]
fn return_redirect() {
fn return_redirect() -> Result<()> {
assert_eq!(
resolve_path(PathBuf::from("/bin/echo"), r#"{"redirect": "a"}"#).unwrap(),
resolve_path(&Path::new("/bin/echo"), r#"{"redirect": "a"}"#)?,
HopAction::Redirect("a".to_string())
);
Ok(())
}
}

View File

@ -1,17 +1,22 @@
use std::borrow::Cow;
use percent_encoding::PercentEncode;
use serde::Serialize;
pub fn query(query: String) -> impl Serialize {
pub fn query(query: PercentEncode<'_>) -> impl Serialize + '_ {
#[derive(Serialize)]
struct TemplateArgs {
query: String,
struct TemplateArgs<'a> {
query: Cow<'a, str>,
}
TemplateArgs {
query: query.into(),
}
TemplateArgs { query }
}
pub fn hostname(hostname: String) -> impl Serialize {
pub fn hostname(hostname: &'_ str) -> impl Serialize + '_ {
#[derive(Serialize)]
pub struct TemplateArgs {
pub hostname: String,
pub struct TemplateArgs<'a> {
pub hostname: &'a str,
}
TemplateArgs { hostname }
}