master
Edward Shen 2022-06-02 22:46:10 -07:00
parent 90ff4461a6
commit f1d7797637
Signed by: edward
GPG Key ID: 19182661E818369F
7 changed files with 868 additions and 895 deletions

View File

@ -1,3 +0,0 @@
tab_spaces = 2
use_field_init_shorthand = true
max_width = 80

View File

@ -4,16 +4,16 @@ use std::path::PathBuf;
#[derive(Parser)] #[derive(Parser)]
#[clap(version = crate_version!(), author = crate_authors!())] #[clap(version = crate_version!(), author = crate_authors!())]
pub struct Opts { pub struct Opts {
/// Increases the log level to info, debug, and trace, respectively. /// Increases the log level to info, debug, and trace, respectively.
#[clap(short, long, parse(from_occurrences), conflicts_with("quiet"))] #[clap(short, long, parse(from_occurrences), conflicts_with("quiet"))]
pub verbose: u8, pub verbose: u8,
/// Decreases the log level to error or no logging at all, respectively. /// Decreases the log level to error or no logging at all, respectively.
#[clap(short, long, parse(from_occurrences), conflicts_with("verbose"))] #[clap(short, long, parse(from_occurrences), conflicts_with("verbose"))]
pub quiet: u8, pub quiet: u8,
/// Specify the location of the config file to read from. Needs read/write permissions. /// Specify the location of the config file to read from. Needs read/write permissions.
#[clap(short, long)] #[clap(short, long)]
pub config: Option<PathBuf>, pub config: Option<PathBuf>,
/// Allow config sizes larger than 100MB. /// Allow config sizes larger than 100MB.
#[clap(long)] #[clap(long)]
pub large_config: bool, pub large_config: bool,
} }

View File

@ -2,8 +2,8 @@ use crate::BunBunError;
use dirs::{config_dir, home_dir}; use dirs::{config_dir, home_dir};
use log::{debug, info, trace}; use log::{debug, info, trace};
use serde::{ use serde::{
de::{self, Deserializer, MapAccess, Unexpected, Visitor}, de::{self, Deserializer, MapAccess, Unexpected, Visitor},
Deserialize, Serialize, Deserialize, Serialize,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
@ -20,55 +20,55 @@ const LARGE_FILE_SIZE_THRESHOLD: u64 = 1_000_000;
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
pub struct Config { pub struct Config {
pub bind_address: String, pub bind_address: String,
pub public_address: String, pub public_address: String,
pub default_route: Option<String>, pub default_route: Option<String>,
pub groups: Vec<RouteGroup>, pub groups: Vec<RouteGroup>,
} }
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct RouteGroup { pub struct RouteGroup {
pub name: String, pub name: String,
pub description: Option<String>, pub description: Option<String>,
#[serde(default)] #[serde(default)]
pub hidden: bool, pub hidden: bool,
pub routes: HashMap<String, Route>, pub routes: HashMap<String, Route>,
} }
#[derive(Debug, PartialEq, Clone, Serialize)] #[derive(Debug, PartialEq, Clone, Serialize)]
pub struct Route { pub struct Route {
pub route_type: RouteType, pub route_type: RouteType,
pub path: String, pub path: String,
pub hidden: bool, pub hidden: bool,
pub description: Option<String>, pub description: Option<String>,
pub min_args: Option<usize>, pub min_args: Option<usize>,
pub max_args: Option<usize>, pub max_args: Option<usize>,
} }
impl From<String> for Route { impl From<String> for Route {
fn from(s: String) -> Self { fn from(s: String) -> Self {
Self { Self {
route_type: get_route_type(&s), route_type: get_route_type(&s),
path: s, path: s,
hidden: false, hidden: false,
description: None, description: None,
min_args: None, min_args: None,
max_args: None, max_args: None,
}
} }
}
} }
impl From<&'static str> for Route { impl From<&'static str> for Route {
fn from(s: &'static str) -> Self { fn from(s: &'static str) -> Self {
Self { Self {
route_type: get_route_type(s), route_type: get_route_type(s),
path: s.to_string(), path: s.to_string(),
hidden: false, hidden: false,
description: None, description: None,
min_args: None, min_args: None,
max_args: None, max_args: None,
}
} }
}
} }
/// Deserialization of the route string into the enum requires us to figure out /// Deserialization of the route string into the enum requires us to figure out
@ -77,149 +77,149 @@ impl From<&'static str> for Route {
/// web path. This incurs a disk check operation, but since users shouldn't be /// web path. This incurs a disk check operation, but since users shouldn't be
/// updating the config that frequently, it should be fine. /// updating the config that frequently, it should be fine.
impl<'de> Deserialize<'de> for Route { impl<'de> Deserialize<'de> for Route {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")] #[serde(field_identifier, rename_all = "snake_case")]
enum Field { enum Field {
Path, Path,
Hidden, Hidden,
Description, Description,
MinArgs, MinArgs,
MaxArgs, MaxArgs,
}
struct RouteVisitor;
impl<'de> Visitor<'de> for RouteVisitor {
type Value = Route;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string")
}
fn visit_str<E>(self, path: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Self::Value::from(path.to_owned()))
}
fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut path = None;
let mut hidden = None;
let mut description = None;
let mut min_args = None;
let mut max_args = None;
while let Some(key) = map.next_key()? {
match key {
Field::Path => {
if path.is_some() {
return Err(de::Error::duplicate_field("path"));
}
path = Some(map.next_value::<String>()?);
}
Field::Hidden => {
if hidden.is_some() {
return Err(de::Error::duplicate_field("hidden"));
}
hidden = map.next_value()?;
}
Field::Description => {
if description.is_some() {
return Err(de::Error::duplicate_field("description"));
}
description = Some(map.next_value()?);
}
Field::MinArgs => {
if min_args.is_some() {
return Err(de::Error::duplicate_field("min_args"));
}
min_args = Some(map.next_value()?);
}
Field::MaxArgs => {
if max_args.is_some() {
return Err(de::Error::duplicate_field("max_args"));
}
max_args = Some(map.next_value()?);
}
}
} }
if let (Some(min_args), Some(max_args)) = (min_args, max_args) { struct RouteVisitor;
if min_args > max_args {
impl<'de> Visitor<'de> for RouteVisitor {
type Value = Route;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string")
}
fn visit_str<E>(self, path: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{ {
return Err(de::Error::invalid_value( Ok(Self::Value::from(path.to_owned()))
Unexpected::Other(&format!( }
"argument count range {min_args} to {max_args}",
)), fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
&"a valid argument count range", where
)); M: MapAccess<'de>,
{
let mut path = None;
let mut hidden = None;
let mut description = None;
let mut min_args = None;
let mut max_args = None;
while let Some(key) = map.next_key()? {
match key {
Field::Path => {
if path.is_some() {
return Err(de::Error::duplicate_field("path"));
}
path = Some(map.next_value::<String>()?);
}
Field::Hidden => {
if hidden.is_some() {
return Err(de::Error::duplicate_field("hidden"));
}
hidden = map.next_value()?;
}
Field::Description => {
if description.is_some() {
return Err(de::Error::duplicate_field("description"));
}
description = Some(map.next_value()?);
}
Field::MinArgs => {
if min_args.is_some() {
return Err(de::Error::duplicate_field("min_args"));
}
min_args = Some(map.next_value()?);
}
Field::MaxArgs => {
if max_args.is_some() {
return Err(de::Error::duplicate_field("max_args"));
}
max_args = Some(map.next_value()?);
}
}
}
if let (Some(min_args), Some(max_args)) = (min_args, max_args) {
if min_args > max_args {
{
return Err(de::Error::invalid_value(
Unexpected::Other(&format!(
"argument count range {min_args} to {max_args}",
)),
&"a valid argument count range",
));
}
}
}
let path = path.ok_or_else(|| de::Error::missing_field("path"))?;
Ok(Route {
route_type: get_route_type(&path),
path,
hidden: hidden.unwrap_or_default(),
description,
min_args,
max_args,
})
} }
}
} }
let path = path.ok_or_else(|| de::Error::missing_field("path"))?; deserializer.deserialize_any(RouteVisitor)
Ok(Route {
route_type: get_route_type(&path),
path,
hidden: hidden.unwrap_or_default(),
description,
min_args,
max_args,
})
}
} }
deserializer.deserialize_any(RouteVisitor)
}
} }
impl std::fmt::Display for Route { impl std::fmt::Display for Route {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self { Self {
route_type: RouteType::External, route_type: RouteType::External,
path, path,
.. ..
} => write!(f, "raw ({path})"), } => write!(f, "raw ({path})"),
Self { Self {
route_type: RouteType::Internal, route_type: RouteType::Internal,
path, path,
.. ..
} => write!(f, "file ({path})"), } => write!(f, "file ({path})"),
}
} }
}
} }
/// Classifies the path depending on if the there exists a local file. /// Classifies the path depending on if the there exists a local file.
fn get_route_type(path: &str) -> RouteType { fn get_route_type(path: &str) -> RouteType {
if std::path::Path::new(path).exists() { if std::path::Path::new(path).exists() {
debug!("Parsed {path} as a valid local path."); debug!("Parsed {path} as a valid local path.");
RouteType::Internal RouteType::Internal
} else { } else {
debug!("{path} does not exist on disk, assuming web path."); debug!("{path} does not exist on disk, assuming web path.");
RouteType::External RouteType::External
} }
} }
/// There exists two route types: an external path (e.g. a URL) or an internal /// There exists two route types: an external path (e.g. a URL) or an internal
/// path (to a file). /// path (to a file).
#[derive(Debug, PartialEq, Clone, Serialize)] #[derive(Debug, PartialEq, Clone, Serialize)]
pub enum RouteType { pub enum RouteType {
External, External,
Internal, Internal,
} }
pub struct FileData { pub struct FileData {
pub path: PathBuf, pub path: PathBuf,
pub file: File, pub file: File,
} }
/// If a provided config path isn't found, this function checks known good /// If a provided config path isn't found, this function checks known good
@ -227,206 +227,200 @@ pub struct FileData {
/// system-wide config location (`/etc/`, in Linux), followed by the config /// system-wide config location (`/etc/`, in Linux), followed by the config
/// folder, followed by the user's home folder. /// folder, followed by the user's home folder.
pub fn get_config_data() -> Result<FileData, BunBunError> { pub fn get_config_data() -> Result<FileData, BunBunError> {
// Locations to check, with highest priority first // Locations to check, with highest priority first
let locations: Vec<_> = { let locations: Vec<_> = {
let mut folders = vec![PathBuf::from("/etc/")]; let mut folders = vec![PathBuf::from("/etc/")];
// Config folder // Config folder
if let Some(folder) = config_dir() { if let Some(folder) = config_dir() {
folders.push(folder); folders.push(folder);
}
// Home folder
if let Some(folder) = home_dir() {
folders.push(folder);
}
folders
.iter_mut()
.for_each(|folder| folder.push(CONFIG_FILENAME));
folders
};
debug!("Checking locations for config file: {:?}", &locations);
for location in &locations {
let file = OpenOptions::new().read(true).open(location);
match file {
Ok(file) => {
debug!("Found file at {location:?}.");
return Ok(FileData {
path: location.clone(),
file,
});
}
Err(e) => {
debug!("Tried to read '{location:?}' but failed due to error: {e}");
}
}
} }
// Home folder debug!("Failed to find any config. Now trying to find first writable path");
if let Some(folder) = home_dir() {
folders.push(folder); // If we got here, we failed to read any file paths, meaning no config exists
// yet. In that case, try to return the first location that we can write to,
// after writing the default config
for location in locations {
let file = OpenOptions::new()
.write(true)
.create_new(true)
.open(location.clone());
match file {
Ok(mut file) => {
info!("Creating new config file at {location:?}.");
file.write_all(DEFAULT_CONFIG)?;
let file = OpenOptions::new().read(true).open(location.clone())?;
return Ok(FileData {
path: location,
file,
});
}
Err(e) => {
debug!("Tried to open a new file at '{location:?}' but failed due to error: {e}",)
}
}
} }
folders Err(BunBunError::NoValidConfigPath)
.iter_mut()
.for_each(|folder| folder.push(CONFIG_FILENAME));
folders
};
debug!("Checking locations for config file: {:?}", &locations);
for location in &locations {
let file = OpenOptions::new().read(true).open(location);
match file {
Ok(file) => {
debug!("Found file at {location:?}.");
return Ok(FileData {
path: location.clone(),
file,
});
}
Err(e) => {
debug!("Tried to read '{location:?}' but failed due to error: {e}");
}
}
}
debug!("Failed to find any config. Now trying to find first writable path");
// If we got here, we failed to read any file paths, meaning no config exists
// yet. In that case, try to return the first location that we can write to,
// after writing the default config
for location in locations {
let file = OpenOptions::new()
.write(true)
.create_new(true)
.open(location.clone());
match file {
Ok(mut file) => {
info!("Creating new config file at {location:?}.");
file.write_all(DEFAULT_CONFIG)?;
let file = OpenOptions::new().read(true).open(location.clone())?;
return Ok(FileData {
path: location,
file,
});
}
Err(e) => debug!(
"Tried to open a new file at '{location:?}' but failed due to error: {e}",
),
}
}
Err(BunBunError::NoValidConfigPath)
} }
/// Assumes that the user knows what they're talking about and will only try /// Assumes that the user knows what they're talking about and will only try
/// to load the config at the given path. /// to load the config at the given path.
pub fn load_custom_file( pub fn load_custom_file(path: impl Into<PathBuf>) -> Result<FileData, BunBunError> {
path: impl Into<PathBuf>, let path = path.into();
) -> Result<FileData, BunBunError> { let file = OpenOptions::new()
let path = path.into(); .read(true)
let file = OpenOptions::new() .open(&path)
.read(true) .map_err(|e| BunBunError::InvalidConfigPath(path.clone(), e))?;
.open(&path)
.map_err(|e| BunBunError::InvalidConfigPath(path.clone(), e))?;
Ok(FileData { path, file }) Ok(FileData { path, file })
} }
pub fn load_file( pub fn load_file(mut config_file: File, large_config: bool) -> Result<Config, BunBunError> {
mut config_file: File, trace!("Loading config file.");
large_config: bool, let file_size = config_file.metadata()?.len();
) -> Result<Config, BunBunError> {
trace!("Loading config file.");
let file_size = config_file.metadata()?.len();
// 100 MB // 100 MB
if file_size > LARGE_FILE_SIZE_THRESHOLD && !large_config { if file_size > LARGE_FILE_SIZE_THRESHOLD && !large_config {
return Err(BunBunError::ConfigTooLarge(file_size)); return Err(BunBunError::ConfigTooLarge(file_size));
} }
if file_size == 0 { if file_size == 0 {
return Err(BunBunError::ZeroByteConfig); return Err(BunBunError::ZeroByteConfig);
} }
let mut config_data = String::new(); let mut config_data = String::new();
config_file.read_to_string(&mut config_data)?; config_file.read_to_string(&mut config_data)?;
// Reading from memory is faster than reading directly from a reader for some // Reading from memory is faster than reading directly from a reader for some
// reason; see https://github.com/serde-rs/json/issues/160 // reason; see https://github.com/serde-rs/json/issues/160
Ok(serde_yaml::from_str(&config_data)?) Ok(serde_yaml::from_str(&config_data)?)
} }
#[cfg(test)] #[cfg(test)]
mod route { mod route {
use super::*; use super::*;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use serde_yaml::{from_str, to_string}; use serde_yaml::{from_str, to_string};
use std::path::Path; use std::path::Path;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
#[test] #[test]
fn deserialize_relative_path() -> Result<()> { fn deserialize_relative_path() -> Result<()> {
let tmpfile = NamedTempFile::new_in(".")?; let tmpfile = NamedTempFile::new_in(".")?;
let path = tmpfile.path().display().to_string(); let path = tmpfile.path().display().to_string();
let path = path let path = path
.get(path.rfind(".").context("While finding .")?..) .get(path.rfind(".").context("While finding .")?..)
.context("While getting the path")?; .context("While getting the path")?;
let path = Path::new(path); let path = Path::new(path);
assert!(path.is_relative()); assert!(path.is_relative());
let path = path.to_str().context("While stringifying path")?; let path = path.to_str().context("While stringifying path")?;
assert_eq!(from_str::<Route>(path)?, Route::from(path.to_owned())); assert_eq!(from_str::<Route>(path)?, Route::from(path.to_owned()));
Ok(()) Ok(())
} }
#[test] #[test]
fn deserialize_absolute_path() -> Result<()> { fn deserialize_absolute_path() -> Result<()> {
let tmpfile = NamedTempFile::new()?; let tmpfile = NamedTempFile::new()?;
let path = format!("{}", tmpfile.path().display()); let path = format!("{}", tmpfile.path().display());
assert!(tmpfile.path().is_absolute()); assert!(tmpfile.path().is_absolute());
assert_eq!(from_str::<Route>(&path)?, Route::from(path)); assert_eq!(from_str::<Route>(&path)?, Route::from(path));
Ok(()) Ok(())
} }
#[test] #[test]
fn deserialize_http_path() -> Result<()> { fn deserialize_http_path() -> Result<()> {
assert_eq!( assert_eq!(
from_str::<Route>("http://google.com")?, from_str::<Route>("http://google.com")?,
Route::from("http://google.com") Route::from("http://google.com")
); );
Ok(()) Ok(())
} }
#[test] #[test]
fn deserialize_https_path() -> Result<()> { fn deserialize_https_path() -> Result<()> {
assert_eq!( assert_eq!(
from_str::<Route>("https://google.com")?, from_str::<Route>("https://google.com")?,
Route::from("https://google.com") Route::from("https://google.com")
); );
Ok(()) Ok(())
} }
#[test] #[test]
fn serialize() -> Result<()> { fn serialize() -> Result<()> {
assert_eq!( assert_eq!(
&to_string(&Route::from("hello world"))?, &to_string(&Route::from("hello world"))?,
"---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n" "---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n"
); );
Ok(()) Ok(())
} }
} }
#[cfg(test)] #[cfg(test)]
mod read_config { mod read_config {
use super::*; use super::*;
use anyhow::Result; use anyhow::Result;
#[test] #[test]
fn empty_file() -> Result<()> { fn empty_file() -> Result<()> {
let config_file = tempfile::tempfile()?; let config_file = tempfile::tempfile()?;
assert!(matches!( assert!(matches!(
load_file(config_file, false), load_file(config_file, false),
Err(BunBunError::ZeroByteConfig) Err(BunBunError::ZeroByteConfig)
)); ));
Ok(()) Ok(())
}
#[test]
fn config_too_large() -> Result<()> {
let mut config_file = tempfile::tempfile()?;
let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize;
config_file.write(&[0].repeat(size_to_write))?;
match load_file(config_file, false) {
Err(BunBunError::ConfigTooLarge(size))
if size as usize == size_to_write => {}
Err(BunBunError::ConfigTooLarge(size)) => {
panic!("Mismatched size: {size} != {size_to_write}")
}
res => panic!("Wrong result, got {res:#?}"),
} }
Ok(())
}
#[test] #[test]
fn valid_config() -> Result<()> { fn config_too_large() -> Result<()> {
assert!(load_file(File::open("bunbun.default.yaml")?, false).is_ok()); let mut config_file = tempfile::tempfile()?;
Ok(()) let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize;
} config_file.write(&[0].repeat(size_to_write))?;
match load_file(config_file, false) {
Err(BunBunError::ConfigTooLarge(size)) if size as usize == size_to_write => {}
Err(BunBunError::ConfigTooLarge(size)) => {
panic!("Mismatched size: {size} != {size_to_write}")
}
res => panic!("Wrong result, got {res:#?}"),
}
Ok(())
}
#[test]
fn valid_config() -> Result<()> {
assert!(load_file(File::open("bunbun.default.yaml")?, false).is_ok());
Ok(())
}
} }

View File

@ -4,49 +4,49 @@ use std::fmt;
#[derive(Debug)] #[derive(Debug)]
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
pub enum BunBunError { pub enum BunBunError {
Io(std::io::Error), Io(std::io::Error),
Parse(serde_yaml::Error), Parse(serde_yaml::Error),
Watch(hotwatch::Error), Watch(hotwatch::Error),
LoggerInit(log::SetLoggerError), LoggerInit(log::SetLoggerError),
CustomProgram(String), CustomProgram(String),
NoValidConfigPath, NoValidConfigPath,
InvalidConfigPath(std::path::PathBuf, std::io::Error), InvalidConfigPath(std::path::PathBuf, std::io::Error),
ConfigTooLarge(u64), ConfigTooLarge(u64),
ZeroByteConfig, ZeroByteConfig,
JsonParse(serde_json::Error), JsonParse(serde_json::Error),
} }
impl Error for BunBunError {} impl Error for BunBunError {}
impl fmt::Display for BunBunError { impl fmt::Display for BunBunError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
Self::Io(e) => e.fmt(f), Self::Io(e) => e.fmt(f),
Self::Parse(e) => e.fmt(f), Self::Parse(e) => e.fmt(f),
Self::Watch(e) => e.fmt(f), Self::Watch(e) => e.fmt(f),
Self::LoggerInit(e) => e.fmt(f), Self::LoggerInit(e) => e.fmt(f),
Self::CustomProgram(msg) => msg.fmt(f), Self::CustomProgram(msg) => msg.fmt(f),
Self::NoValidConfigPath => write!(f, "No valid config path was found!"), Self::NoValidConfigPath => write!(f, "No valid config path was found!"),
Self::InvalidConfigPath(path, reason) => { Self::InvalidConfigPath(path, reason) => {
write!(f, "Failed to access {path:?}: {reason}") write!(f, "Failed to access {path:?}: {reason}")
} }
Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({size} bytes)! Pass in --large-config to bypass this check."), Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({size} bytes)! Pass in --large-config to bypass this check."),
Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"), Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"),
Self::JsonParse(e) => e.fmt(f), Self::JsonParse(e) => e.fmt(f),
}
} }
}
} }
/// Generates a from implementation from the specified type to the provided /// Generates a from implementation from the specified type to the provided
/// bunbun error. /// bunbun error.
macro_rules! from_error { macro_rules! from_error {
($from:ty, $to:ident) => { ($from:ty, $to:ident) => {
impl From<$from> for BunBunError { impl From<$from> for BunBunError {
fn from(e: $from) -> Self { fn from(e: $from) -> Self {
Self::$to(e) Self::$to(e)
} }
} }
}; };
} }
from_error!(std::io::Error, Io); from_error!(std::io::Error, Io);

View File

@ -6,9 +6,7 @@
//! search engine and quick-jump tool in one small binary. For information on //! search engine and quick-jump tool in one small binary. For information on
//! usage, please take a look at the readme. //! usage, please take a look at the readme.
use crate::config::{ use crate::config::{get_config_data, load_custom_file, load_file, FileData, Route, RouteGroup};
get_config_data, load_custom_file, load_file, FileData, Route, RouteGroup,
};
use anyhow::Result; use anyhow::Result;
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use axum::routing::get; use axum::routing::get;
@ -35,49 +33,49 @@ mod template_args;
/// Dynamic variables that either need to be present at runtime, or can be /// Dynamic variables that either need to be present at runtime, or can be
/// changed during runtime. /// changed during runtime.
pub struct State { pub struct State {
public_address: String, public_address: String,
default_route: Option<String>, default_route: Option<String>,
groups: Vec<RouteGroup>, groups: Vec<RouteGroup>,
/// Cached, flattened mapping of all routes and their destinations. /// Cached, flattened mapping of all routes and their destinations.
routes: HashMap<String, Route>, routes: HashMap<String, Route>,
} }
#[tokio::main] #[tokio::main]
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
async fn main() -> Result<()> { async fn main() -> Result<()> {
let opts = cli::Opts::parse(); let opts = cli::Opts::parse();
init_logger(opts.verbose, opts.quiet)?; init_logger(opts.verbose, opts.quiet)?;
let conf_data = match opts.config { let conf_data = match opts.config {
Some(file_name) => load_custom_file(file_name), Some(file_name) => load_custom_file(file_name),
None => get_config_data(), None => get_config_data(),
}?; }?;
let conf = load_file(conf_data.file.try_clone()?, opts.large_config)?; let conf = load_file(conf_data.file.try_clone()?, opts.large_config)?;
let state = Arc::from(ArcSwap::from_pointee(State { let state = Arc::from(ArcSwap::from_pointee(State {
public_address: conf.public_address, public_address: conf.public_address,
default_route: conf.default_route, default_route: conf.default_route,
routes: cache_routes(conf.groups.clone()), routes: cache_routes(conf.groups.clone()),
groups: conf.groups, groups: conf.groups,
})); }));
// Cannot be named _ or Rust will immediately drop it. // Cannot be named _ or Rust will immediately drop it.
let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config); let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config);
let app = Router::new() let app = Router::new()
.route("/", get(routes::index)) .route("/", get(routes::index))
.route("/bunbunsearch.xml", get(routes::opensearch)) .route("/bunbunsearch.xml", get(routes::opensearch))
.route("/ls", get(routes::list)) .route("/ls", get(routes::list))
.route("/hop", get(routes::hop)) .route("/hop", get(routes::hop))
.layer(Extension(compile_templates()?)) .layer(Extension(compile_templates()?))
.layer(Extension(state)); .layer(Extension(state));
axum::Server::bind(&conf.bind_address.parse()?) axum::Server::bind(&conf.bind_address.parse()?)
.serve(app.into_make_service()) .serve(app.into_make_service())
.await?; .await?;
Ok(()) Ok(())
} }
/// Initializes the logger based on the number of quiet and verbose flags passed /// Initializes the logger based on the number of quiet and verbose flags passed
@ -85,52 +83,51 @@ async fn main() -> Result<()> {
/// verbose flags is non-zero then the quiet flag is zero, and vice versa. /// verbose flags is non-zero then the quiet flag is zero, and vice versa.
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
fn init_logger(num_verbose_flags: u8, num_quiet_flags: u8) -> Result<()> { fn init_logger(num_verbose_flags: u8, num_quiet_flags: u8) -> Result<()> {
let log_level = let log_level = match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 {
match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 { -2 => None,
-2 => None, -1 => Some(log::LevelFilter::Error),
-1 => Some(log::LevelFilter::Error), 0 => Some(log::LevelFilter::Warn),
0 => Some(log::LevelFilter::Warn), 1 => Some(log::LevelFilter::Info),
1 => Some(log::LevelFilter::Info), 2 => Some(log::LevelFilter::Debug),
2 => Some(log::LevelFilter::Debug), 3 => Some(log::LevelFilter::Trace),
3 => Some(log::LevelFilter::Trace), _ => unreachable!(), // values are clamped to [0, 3] - [0, 2]
_ => unreachable!(), // values are clamped to [0, 3] - [0, 2]
}; };
if let Some(level) = log_level { if let Some(level) = log_level {
SimpleLogger::new().with_level(level).init()?; SimpleLogger::new().with_level(level).init()?;
} }
Ok(()) Ok(())
} }
/// Generates a hashmap of routes from the data structure created by the config /// Generates a hashmap of routes from the data structure created by the config
/// file. This should improve runtime performance and is a better solution than /// file. This should improve runtime performance and is a better solution than
/// just iterating over the config object for every hop resolution. /// just iterating over the config object for every hop resolution.
fn cache_routes(groups: Vec<RouteGroup>) -> HashMap<String, Route> { fn cache_routes(groups: Vec<RouteGroup>) -> HashMap<String, Route> {
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for group in groups { for group in groups {
for (kw, dest) in group.routes { for (kw, dest) in group.routes {
// This function isn't called often enough to not be a performance issue. // This function isn't called often enough to not be a performance issue.
match mapping.insert(kw.clone(), dest.clone()) { match mapping.insert(kw.clone(), dest.clone()) {
None => trace!("Inserting {kw} into mapping."), None => trace!("Inserting {kw} into mapping."),
Some(old_value) => { Some(old_value) => {
trace!("Overriding {kw} route from {old_value} to {dest}."); trace!("Overriding {kw} route from {old_value} to {dest}.");
}
}
} }
}
} }
} mapping
mapping
} }
/// Returns an instance with all pre-generated templates included into the /// Returns an instance with all pre-generated templates included into the
/// binary. This allows for users to have a portable binary without needed the /// binary. This allows for users to have a portable binary without needed the
/// templates at runtime. /// templates at runtime.
fn compile_templates() -> Result<Handlebars<'static>> { fn compile_templates() -> Result<Handlebars<'static>> {
let mut handlebars = Handlebars::new(); let mut handlebars = Handlebars::new();
handlebars.set_strict_mode(true); handlebars.set_strict_mode(true);
handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?; handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?;
handlebars.register_partial("bunbun_src", env!("CARGO_PKG_REPOSITORY"))?; handlebars.register_partial("bunbun_src", env!("CARGO_PKG_REPOSITORY"))?;
macro_rules! register_template { macro_rules! register_template {
[ $( $template:expr ),* ] => { [ $( $template:expr ),* ] => {
$( $(
handlebars handlebars
@ -143,8 +140,8 @@ fn compile_templates() -> Result<Handlebars<'static>> {
)* )*
}; };
} }
register_template!["index", "list", "opensearch"]; register_template!["index", "list", "opensearch"];
Ok(handlebars) Ok(handlebars)
} }
/// Starts the watch on a file, if possible. This will only return an Error if /// Starts the watch on a file, if possible. This will only return an Error if
@ -158,179 +155,170 @@ fn compile_templates() -> Result<Handlebars<'static>> {
/// watches. /// watches.
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
fn start_watch( fn start_watch(
state: Arc<ArcSwap<State>>, state: Arc<ArcSwap<State>>,
config_data: FileData, config_data: FileData,
large_config: bool, large_config: bool,
) -> Result<Hotwatch> { ) -> Result<Hotwatch> {
let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?; let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?;
let FileData { path, mut file } = config_data; let FileData { path, mut file } = config_data;
let watch_result = watch.watch(&path, move |e: Event| { let watch_result = watch.watch(&path, move |e: Event| {
if let Event::Create(ref path) = e { if let Event::Create(ref path) = e {
file = load_custom_file(path).expect("file to exist at path").file; file = load_custom_file(path).expect("file to exist at path").file;
trace!("Getting new file handler as file was recreated."); trace!("Getting new file handler as file was recreated.");
}
match e {
Event::Write(_) | Event::Create(_) => {
trace!("Grabbing writer lock on state...");
trace!("Obtained writer lock on state!");
match load_file(
file.try_clone().expect("Failed to clone file handle"),
large_config,
) {
Ok(conf) => {
state.store(Arc::new(State {
public_address: conf.public_address,
default_route: conf.default_route,
routes: cache_routes(conf.groups.clone()),
groups: conf.groups,
}));
info!("Successfully updated active state");
}
Err(e) => warn!("Failed to update config file: {e}"),
} }
}
_ => debug!("Saw event {e:#?} but ignored it"),
}
});
match watch_result { match e {
Ok(_) => info!("Watcher is now watching {path:?}"), Event::Write(_) | Event::Create(_) => {
Err(e) => { trace!("Grabbing writer lock on state...");
warn!( trace!("Obtained writer lock on state!");
"Couldn't watch {path:?}: {e}. Changes to this file won't be seen!" match load_file(
); file.try_clone().expect("Failed to clone file handle"),
} large_config,
} ) {
Ok(conf) => {
state.store(Arc::new(State {
public_address: conf.public_address,
default_route: conf.default_route,
routes: cache_routes(conf.groups.clone()),
groups: conf.groups,
}));
info!("Successfully updated active state");
}
Err(e) => warn!("Failed to update config file: {e}"),
}
}
_ => debug!("Saw event {e:#?} but ignored it"),
}
});
Ok(watch) match watch_result {
Ok(_) => info!("Watcher is now watching {path:?}"),
Err(e) => {
warn!("Couldn't watch {path:?}: {e}. Changes to this file won't be seen!");
}
}
Ok(watch)
} }
#[cfg(test)] #[cfg(test)]
mod init_logger { mod init_logger {
use super::*; use super::*;
use anyhow::Result; use anyhow::Result;
#[test] #[test]
fn defaults_to_warn() -> Result<()> { fn defaults_to_warn() -> Result<()> {
init_logger(0, 0)?; init_logger(0, 0)?;
assert_eq!(log::max_level(), log::Level::Warn); assert_eq!(log::max_level(), log::Level::Warn);
Ok(()) Ok(())
} }
// The following tests work but because the log crate is global, initializing // The following tests work but because the log crate is global, initializing
// the logger more than once (read: testing it more than once) leads to a // the logger more than once (read: testing it more than once) leads to a
// panic. These ignored tests must be manually tested. // panic. These ignored tests must be manually tested.
#[test] #[test]
#[ignore] #[ignore]
fn caps_to_2_when_log_level_is_lt_2() -> Result<()> { fn caps_to_2_when_log_level_is_lt_2() -> Result<()> {
init_logger(0, 3)?; init_logger(0, 3)?;
assert_eq!(log::max_level(), log::LevelFilter::Off); assert_eq!(log::max_level(), log::LevelFilter::Off);
Ok(()) Ok(())
} }
#[test] #[test]
#[ignore] #[ignore]
fn caps_to_3_when_log_level_is_gt_3() -> Result<()> { fn caps_to_3_when_log_level_is_gt_3() -> Result<()> {
init_logger(4, 0)?; init_logger(4, 0)?;
assert_eq!(log::max_level(), log::Level::Trace); assert_eq!(log::max_level(), log::Level::Trace);
Ok(()) Ok(())
} }
} }
#[cfg(test)] #[cfg(test)]
mod cache_routes { mod cache_routes {
use super::*; use super::*;
use std::iter::FromIterator; use std::iter::FromIterator;
fn generate_external_routes( fn generate_external_routes(routes: &[(&'static str, &'static str)]) -> HashMap<String, Route> {
routes: &[(&'static str, &'static str)], HashMap::from_iter(
) -> HashMap<String, Route> { routes
HashMap::from_iter( .into_iter()
routes .map(|(key, value)| ((*key).to_owned(), Route::from(*value))),
.into_iter() )
.map(|(key, value)| ((*key).to_owned(), Route::from(*value))), }
)
}
#[test] #[test]
fn empty_groups_yield_empty_routes() { fn empty_groups_yield_empty_routes() {
assert_eq!(cache_routes(Vec::new()), HashMap::new()); assert_eq!(cache_routes(Vec::new()), HashMap::new());
} }
#[test] #[test]
fn disjoint_groups_yield_summed_routes() { fn disjoint_groups_yield_summed_routes() {
let group1 = RouteGroup { let group1 = RouteGroup {
name: String::from("x"), name: String::from("x"),
description: Some(String::from("y")), description: Some(String::from("y")),
routes: generate_external_routes(&[("a", "b"), ("c", "d")]), routes: generate_external_routes(&[("a", "b"), ("c", "d")]),
hidden: false, hidden: false,
}; };
let group2 = RouteGroup { let group2 = RouteGroup {
name: String::from("5"), name: String::from("5"),
description: Some(String::from("6")), description: Some(String::from("6")),
routes: generate_external_routes(&[("1", "2"), ("3", "4")]), routes: generate_external_routes(&[("1", "2"), ("3", "4")]),
hidden: false, hidden: false,
}; };
assert_eq!( assert_eq!(
cache_routes(vec![group1, group2]), cache_routes(vec![group1, group2]),
generate_external_routes(&[ generate_external_routes(&[("a", "b"), ("c", "d"), ("1", "2"), ("3", "4")])
("a", "b"), );
("c", "d"), }
("1", "2"),
("3", "4")
])
);
}
#[test] #[test]
fn overlapping_groups_use_latter_routes() { fn overlapping_groups_use_latter_routes() {
let group1 = RouteGroup { let group1 = RouteGroup {
name: String::from("x"), name: String::from("x"),
description: Some(String::from("y")), description: Some(String::from("y")),
routes: generate_external_routes(&[("a", "b"), ("c", "d")]), routes: generate_external_routes(&[("a", "b"), ("c", "d")]),
hidden: false, hidden: false,
}; };
let group2 = RouteGroup { let group2 = RouteGroup {
name: String::from("5"), name: String::from("5"),
description: Some(String::from("6")), description: Some(String::from("6")),
routes: generate_external_routes(&[("a", "1"), ("c", "2")]), routes: generate_external_routes(&[("a", "1"), ("c", "2")]),
hidden: false, hidden: false,
}; };
assert_eq!( assert_eq!(
cache_routes(vec![group1.clone(), group2]), cache_routes(vec![group1.clone(), group2]),
generate_external_routes(&[("a", "1"), ("c", "2")]) generate_external_routes(&[("a", "1"), ("c", "2")])
); );
let group3 = RouteGroup { let group3 = RouteGroup {
name: String::from("5"), name: String::from("5"),
description: Some(String::from("6")), description: Some(String::from("6")),
routes: generate_external_routes(&[("a", "1"), ("b", "2")]), routes: generate_external_routes(&[("a", "1"), ("b", "2")]),
hidden: false, hidden: false,
}; };
assert_eq!( assert_eq!(
cache_routes(vec![group1, group3]), cache_routes(vec![group1, group3]),
generate_external_routes(&[("a", "1"), ("b", "2"), ("c", "d")]) generate_external_routes(&[("a", "1"), ("b", "2"), ("c", "d")])
); );
} }
} }
#[cfg(test)] #[cfg(test)]
mod compile_templates { mod compile_templates {
use super::compile_templates; use super::compile_templates;
/// Successful compilation of the binary guarantees that the templates will be /// Successful compilation of the binary guarantees that the templates will be
/// present to be registered to. Thus, we only really need to see that /// present to be registered to. Thus, we only really need to see that
/// compilation of the templates don't panic, which is just making sure that /// compilation of the templates don't panic, which is just making sure that
/// the function can be successfully called. /// the function can be successfully called.
#[test] #[test]
fn templates_compile() { fn templates_compile() {
let _ = compile_templates(); let _ = compile_templates();
} }
} }

View File

@ -17,130 +17,127 @@ use std::sync::Arc;
// https://url.spec.whatwg.org/#fragment-percent-encode-set // https://url.spec.whatwg.org/#fragment-percent-encode-set
const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
.add(b' ') .add(b' ')
.add(b'"') .add(b'"')
.add(b'<') .add(b'<')
.add(b'>') .add(b'>')
.add(b'`') .add(b'`')
.add(b'+') .add(b'+')
.add(b'&') // Interpreted as a GET query .add(b'&') // Interpreted as a GET query
.add(b'#') // Interpreted as a hyperlink section target .add(b'#') // Interpreted as a hyperlink section target
.add(b'\''); .add(b'\'');
#[allow(clippy::unused_async)] #[allow(clippy::unused_async)]
pub async fn index( pub async fn index(
Extension(data): Extension<Arc<ArcSwap<State>>>, Extension(data): Extension<Arc<ArcSwap<State>>>,
Extension(handlebars): Extension<Handlebars<'static>>, Extension(handlebars): Extension<Handlebars<'static>>,
) -> impl IntoResponse { ) -> impl IntoResponse {
handlebars handlebars
.render( .render(
"index", "index",
&template_args::hostname(&data.load().public_address), &template_args::hostname(&data.load().public_address),
) )
.map(Html) .map(Html)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
#[allow(clippy::unused_async)] #[allow(clippy::unused_async)]
pub async fn opensearch( pub async fn opensearch(
Extension(data): Extension<Arc<ArcSwap<State>>>, Extension(data): Extension<Arc<ArcSwap<State>>>,
Extension(handlebars): Extension<Handlebars<'static>>, Extension(handlebars): Extension<Handlebars<'static>>,
) -> impl IntoResponse { ) -> impl IntoResponse {
handlebars handlebars
.render( .render(
"opensearch", "opensearch",
&template_args::hostname(&data.load().public_address), &template_args::hostname(&data.load().public_address),
) )
.map(|body| { .map(|body| {
( (
StatusCode::OK, StatusCode::OK,
[( [(
header::CONTENT_TYPE, header::CONTENT_TYPE,
"application/opensearchdescription+xml", "application/opensearchdescription+xml",
)], )],
body, body,
) )
}) })
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
#[allow(clippy::unused_async)] #[allow(clippy::unused_async)]
pub async fn list( pub async fn list(
Extension(data): Extension<Arc<ArcSwap<State>>>, Extension(data): Extension<Arc<ArcSwap<State>>>,
Extension(handlebars): Extension<Handlebars<'static>>, Extension(handlebars): Extension<Handlebars<'static>>,
) -> impl IntoResponse { ) -> impl IntoResponse {
handlebars handlebars
.render("list", &data.load().groups) .render("list", &data.load().groups)
.map(Html) .map(Html)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct SearchQuery { pub struct SearchQuery {
to: String, to: String,
} }
#[allow(clippy::unused_async)] #[allow(clippy::unused_async)]
pub async fn hop( pub async fn hop(
Extension(data): Extension<Arc<ArcSwap<State>>>, Extension(data): Extension<Arc<ArcSwap<State>>>,
Extension(handlebars): Extension<Handlebars<'static>>, Extension(handlebars): Extension<Handlebars<'static>>,
Query(query): Query<SearchQuery>, Query(query): Query<SearchQuery>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let data = data.load(); let data = data.load();
match resolve_hop(&query.to, &data.routes, &data.default_route) { match resolve_hop(&query.to, &data.routes, &data.default_route) {
RouteResolution::Resolved { route: path, args } => { RouteResolution::Resolved { route: path, args } => {
let resolved_template = match path { let resolved_template = match path {
ConfigRoute { ConfigRoute {
route_type: RouteType::Internal, route_type: RouteType::Internal,
path, path,
.. ..
} => resolve_path(Path::new(path), &args), } => resolve_path(Path::new(path), &args),
ConfigRoute { ConfigRoute {
route_type: RouteType::External, route_type: RouteType::External,
path, path,
.. ..
} => Ok(HopAction::Redirect(path.clone())), } => Ok(HopAction::Redirect(path.clone())),
}; };
match resolved_template { match resolved_template {
Ok(HopAction::Redirect(path)) => { Ok(HopAction::Redirect(path)) => {
let rendered = handlebars let rendered = handlebars
.render_template( .render_template(
&path, &path,
&template_args::query(utf8_percent_encode( &template_args::query(utf8_percent_encode(&args, FRAGMENT_ENCODE_SET)),
&args, )
FRAGMENT_ENCODE_SET, .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
)), Response::builder()
) .status(StatusCode::FOUND)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .header(header::LOCATION, &path)
Response::builder() .body(boxed(Full::from(rendered)))
.status(StatusCode::FOUND) }
.header(header::LOCATION, &path) Ok(HopAction::Body(body)) => Response::builder()
.body(boxed(Full::from(rendered))) .status(StatusCode::OK)
.body(boxed(Full::new(Bytes::from(body)))),
Err(e) => {
error!("Failed to redirect user for {path}: {e}");
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(boxed(Full::from("Something went wrong :(\n")))
}
}
} }
Ok(HopAction::Body(body)) => Response::builder() RouteResolution::Unresolved => Response::builder()
.status(StatusCode::OK) .status(StatusCode::NOT_FOUND)
.body(boxed(Full::new(Bytes::from(body)))), .body(boxed(Full::from("not found\n"))),
Err(e) => {
error!("Failed to redirect user for {path}: {e}");
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(boxed(Full::from("Something went wrong :(\n")))
}
}
} }
RouteResolution::Unresolved => Response::builder() .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.status(StatusCode::NOT_FOUND)
.body(boxed(Full::from("not found\n"))),
}
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
enum RouteResolution<'a> { enum RouteResolution<'a> {
Resolved { route: &'a Route, args: String }, Resolved { route: &'a Route, args: String },
Unresolved, Unresolved,
} }
/// Attempts to resolve the provided string into its route and its arguments. /// Attempts to resolve the provided string into its route and its arguments.
@ -150,71 +147,71 @@ enum RouteResolution<'a> {
/// The first element in the tuple describes the route, while the second element /// The first element in the tuple describes the route, while the second element
/// returns the remaining arguments. If none remain, an empty string is given. /// returns the remaining arguments. If none remain, an empty string is given.
fn resolve_hop<'a>( fn resolve_hop<'a>(
query: &str, query: &str,
routes: &'a HashMap<String, Route>, routes: &'a HashMap<String, Route>,
default_route: &Option<String>, default_route: &Option<String>,
) -> RouteResolution<'a> { ) -> RouteResolution<'a> {
let mut split_args = query.split_ascii_whitespace().peekable(); let mut split_args = query.split_ascii_whitespace().peekable();
let maybe_route = { let maybe_route = {
match split_args.peek() { match split_args.peek() {
Some(command) => routes.get(*command), Some(command) => routes.get(*command),
None => { None => {
debug!("Found empty query, returning no route."); debug!("Found empty query, returning no route.");
return RouteResolution::Unresolved; return RouteResolution::Unresolved;
} }
}
};
let args = split_args.collect::<Vec<_>>();
let arg_count = args.len();
// Try resolving with a matched command
if let Some(route) = maybe_route {
let args = if args.is_empty() { &[] } else { &args[1..] }.join(" ");
let arg_count = arg_count - 1;
if check_route(route, arg_count) {
debug!("Resolved {route} with args {args}");
return RouteResolution::Resolved { route, args };
}
} }
};
let args = split_args.collect::<Vec<_>>(); // Try resolving with the default route, if it exists
let arg_count = args.len(); if let Some(route) = default_route {
if let Some(route) = routes.get(route) {
// Try resolving with a matched command if check_route(route, arg_count) {
if let Some(route) = maybe_route { let args = args.join(" ");
let args = if args.is_empty() { &[] } else { &args[1..] }.join(" "); debug!("Using default route {route} with args {args}");
let arg_count = arg_count - 1; return RouteResolution::Resolved { route, args };
if check_route(route, arg_count) { }
debug!("Resolved {route} with args {args}"); }
return RouteResolution::Resolved { route, args };
} }
}
// Try resolving with the default route, if it exists RouteResolution::Unresolved
if let Some(route) = default_route {
if let Some(route) = routes.get(route) {
if check_route(route, arg_count) {
let args = args.join(" ");
debug!("Using default route {route} with args {args}");
return RouteResolution::Resolved { route, args };
}
}
}
RouteResolution::Unresolved
} }
/// Checks if the user provided string has the correct properties required by /// Checks if the user provided string has the correct properties required by
/// the route to be successfully matched. /// the route to be successfully matched.
const fn check_route(route: &Route, arg_count: usize) -> bool { const fn check_route(route: &Route, arg_count: usize) -> bool {
if let Some(min_args) = route.min_args { if let Some(min_args) = route.min_args {
if arg_count < min_args { if arg_count < min_args {
return false; return false;
}
} }
}
if let Some(max_args) = route.max_args { if let Some(max_args) = route.max_args {
if arg_count > max_args { if arg_count > max_args {
return false; return false;
}
} }
}
true true
} }
#[derive(Deserialize, Debug, PartialEq, Eq)] #[derive(Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
enum HopAction { enum HopAction {
Redirect(String), Redirect(String),
Body(String), Body(String),
} }
/// Runs the executable with the user's input as a single argument. Returns Ok /// Runs the executable with the user's input as a single argument. Returns Ok
@ -222,203 +219,200 @@ enum HopAction {
/// file doesn't exist or bunbun did not have permission to read and execute the /// file doesn't exist or bunbun did not have permission to read and execute the
/// file. /// file.
fn resolve_path(path: &Path, args: &str) -> Result<HopAction, BunBunError> { fn resolve_path(path: &Path, args: &str) -> Result<HopAction, BunBunError> {
let output = Command::new(path.canonicalize()?) let output = Command::new(path.canonicalize()?)
.args(args.split(' ')) .args(args.split(' '))
.output()?; .output()?;
if output.status.success() { if output.status.success() {
Ok(serde_json::from_slice(&output.stdout[..])?) Ok(serde_json::from_slice(&output.stdout[..])?)
} else { } else {
error!( error!(
"Program exit code for {} was not 0! Dumping standard error!", "Program exit code for {} was not 0! Dumping standard error!",
path.display(), path.display(),
); );
let error = String::from_utf8_lossy(&output.stderr); let error = String::from_utf8_lossy(&output.stderr);
Err(BunBunError::CustomProgram(error.to_string())) Err(BunBunError::CustomProgram(error.to_string()))
} }
} }
#[cfg(test)] #[cfg(test)]
mod resolve_hop { mod resolve_hop {
use super::*; use super::*;
use anyhow::Result; use anyhow::Result;
fn generate_route_result<'a>( fn generate_route_result<'a>(keyword: &'a Route, args: &str) -> RouteResolution<'a> {
keyword: &'a Route, RouteResolution::Resolved {
args: &str, route: keyword,
) -> RouteResolution<'a> { args: String::from(args),
RouteResolution::Resolved { }
route: keyword,
args: String::from(args),
} }
}
#[test] #[test]
fn empty_routes_no_default_yields_failed_hop() { fn empty_routes_no_default_yields_failed_hop() {
assert_eq!( assert_eq!(
resolve_hop("hello world", &HashMap::new(), &None), resolve_hop("hello world", &HashMap::new(), &None),
RouteResolution::Unresolved RouteResolution::Unresolved
); );
} }
#[test] #[test]
fn empty_routes_some_default_yields_failed_hop() { fn empty_routes_some_default_yields_failed_hop() {
assert_eq!( assert_eq!(
resolve_hop( resolve_hop(
"hello world", "hello world",
&HashMap::new(), &HashMap::new(),
&Some(String::from("google")) &Some(String::from("google"))
), ),
RouteResolution::Unresolved RouteResolution::Unresolved
); );
} }
#[test] #[test]
fn only_default_routes_some_default_yields_default_hop() -> Result<()> { fn only_default_routes_some_default_yields_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert("google".into(), Route::from("https://example.com")); map.insert("google".into(), Route::from("https://example.com"));
assert_eq!( assert_eq!(
resolve_hop("hello world", &map, &Some(String::from("google"))), resolve_hop("hello world", &map, &Some(String::from("google"))),
generate_route_result(&Route::from("https://example.com"), "hello world"), generate_route_result(&Route::from("https://example.com"), "hello world"),
); );
Ok(()) Ok(())
} }
#[test] #[test]
fn non_default_routes_some_default_yields_non_default_hop() -> Result<()> { fn non_default_routes_some_default_yields_non_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert("google".into(), Route::from("https://example.com")); map.insert("google".into(), Route::from("https://example.com"));
assert_eq!( assert_eq!(
resolve_hop("google hello world", &map, &Some(String::from("a"))), resolve_hop("google hello world", &map, &Some(String::from("a"))),
generate_route_result(&Route::from("https://example.com"), "hello world"), generate_route_result(&Route::from("https://example.com"), "hello world"),
); );
Ok(()) Ok(())
} }
#[test] #[test]
fn non_default_routes_no_default_yields_non_default_hop() -> Result<()> { fn non_default_routes_no_default_yields_non_default_hop() -> Result<()> {
let mut map: HashMap<String, Route> = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert("google".into(), Route::from("https://example.com")); map.insert("google".into(), Route::from("https://example.com"));
assert_eq!( assert_eq!(
resolve_hop("google hello world", &map, &None), resolve_hop("google hello world", &map, &None),
generate_route_result(&Route::from("https://example.com"), "hello world"), generate_route_result(&Route::from("https://example.com"), "hello world"),
); );
Ok(()) Ok(())
} }
} }
#[cfg(test)] #[cfg(test)]
mod check_route { mod check_route {
use super::*; use super::*;
fn create_route( fn create_route(
min_args: impl Into<Option<usize>>, min_args: impl Into<Option<usize>>,
max_args: impl Into<Option<usize>>, max_args: impl Into<Option<usize>>,
) -> Route { ) -> Route {
Route { Route {
description: None, description: None,
hidden: false, hidden: false,
max_args: max_args.into(), max_args: max_args.into(),
min_args: min_args.into(), min_args: min_args.into(),
path: String::new(), path: String::new(),
route_type: RouteType::External, route_type: RouteType::External,
}
} }
}
#[test] #[test]
fn no_min_arg_no_max_arg_counts() { fn no_min_arg_no_max_arg_counts() {
assert!(check_route(&create_route(None, None), 0)); assert!(check_route(&create_route(None, None), 0));
assert!(check_route(&create_route(None, None), usize::MAX)); assert!(check_route(&create_route(None, None), usize::MAX));
} }
#[test] #[test]
fn min_arg_no_max_arg_counts() { fn min_arg_no_max_arg_counts() {
assert!(!check_route(&create_route(3, None), 0)); assert!(!check_route(&create_route(3, None), 0));
assert!(!check_route(&create_route(3, None), 2)); assert!(!check_route(&create_route(3, None), 2));
assert!(check_route(&create_route(3, None), 3)); assert!(check_route(&create_route(3, None), 3));
assert!(check_route(&create_route(3, None), 4)); assert!(check_route(&create_route(3, None), 4));
assert!(check_route(&create_route(3, None), usize::MAX)); assert!(check_route(&create_route(3, None), usize::MAX));
} }
#[test] #[test]
fn no_min_arg_max_arg_counts() { fn no_min_arg_max_arg_counts() {
assert!(check_route(&create_route(None, 3), 0)); assert!(check_route(&create_route(None, 3), 0));
assert!(check_route(&create_route(None, 3), 2)); assert!(check_route(&create_route(None, 3), 2));
assert!(check_route(&create_route(None, 3), 3)); assert!(check_route(&create_route(None, 3), 3));
assert!(!check_route(&create_route(None, 3), 4)); assert!(!check_route(&create_route(None, 3), 4));
assert!(!check_route(&create_route(None, 3), usize::MAX)); assert!(!check_route(&create_route(None, 3), usize::MAX));
} }
#[test] #[test]
fn min_arg_max_arg_counts() { fn min_arg_max_arg_counts() {
assert!(!check_route(&create_route(2, 3), 1)); assert!(!check_route(&create_route(2, 3), 1));
assert!(check_route(&create_route(2, 3), 2)); assert!(check_route(&create_route(2, 3), 2));
assert!(check_route(&create_route(2, 3), 3)); assert!(check_route(&create_route(2, 3), 3));
assert!(!check_route(&create_route(2, 3), 4)); assert!(!check_route(&create_route(2, 3), 4));
} }
} }
#[cfg(test)] #[cfg(test)]
mod resolve_path { mod resolve_path {
use crate::error::BunBunError; use crate::error::BunBunError;
use super::{resolve_path, HopAction}; use super::{resolve_path, HopAction};
use anyhow::Result; use anyhow::Result;
use std::env::current_dir; use std::env::current_dir;
use std::io::ErrorKind; use std::io::ErrorKind;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
#[test] #[test]
fn invalid_path_returns_err() { fn invalid_path_returns_err() {
assert!(resolve_path(&Path::new("/bin/aaaa"), "aaaa").is_err()); assert!(resolve_path(&Path::new("/bin/aaaa"), "aaaa").is_err());
} }
#[test] #[test]
fn valid_path_returns_ok() { fn valid_path_returns_ok() {
assert!(resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#).is_ok()); assert!(resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#).is_ok());
} }
#[test] #[test]
fn relative_path_returns_ok() -> Result<()> { fn relative_path_returns_ok() -> Result<()> {
// How many ".." needed to get to / // How many ".." needed to get to /
let nest_level = current_dir()?.ancestors().count() - 1; let nest_level = current_dir()?.ancestors().count() - 1;
let mut rel_path = PathBuf::from("../".repeat(nest_level)); let mut rel_path = PathBuf::from("../".repeat(nest_level));
rel_path.push("./bin/echo"); rel_path.push("./bin/echo");
assert!(resolve_path(&rel_path, r#"{"body": "a"}"#).is_ok()); assert!(resolve_path(&rel_path, r#"{"body": "a"}"#).is_ok());
Ok(()) Ok(())
} }
#[test] #[test]
fn no_permissions_returns_err() { fn no_permissions_returns_err() {
let result = match resolve_path(&Path::new("/root/some_exec"), "") { let result = match resolve_path(&Path::new("/root/some_exec"), "") {
Err(BunBunError::Io(e)) => e.kind() == ErrorKind::PermissionDenied, Err(BunBunError::Io(e)) => e.kind() == ErrorKind::PermissionDenied,
_ => false, _ => false,
}; };
assert!(result); assert!(result);
} }
#[test] #[test]
fn non_success_exit_code_yields_err() { fn non_success_exit_code_yields_err() {
// cat-ing a folder always returns exit code 1 // cat-ing a folder always returns exit code 1
assert!(resolve_path(&Path::new("/bin/cat"), "/").is_err()); assert!(resolve_path(&Path::new("/bin/cat"), "/").is_err());
} }
#[test] #[test]
fn return_body() -> Result<()> { fn return_body() -> Result<()> {
assert_eq!( assert_eq!(
resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#)?, resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#)?,
HopAction::Body("a".to_string()) HopAction::Body("a".to_string())
); );
Ok(()) Ok(())
} }
#[test] #[test]
fn return_redirect() -> Result<()> { fn return_redirect() -> Result<()> {
assert_eq!( assert_eq!(
resolve_path(&Path::new("/bin/echo"), r#"{"redirect": "a"}"#)?, resolve_path(&Path::new("/bin/echo"), r#"{"redirect": "a"}"#)?,
HopAction::Redirect("a".to_string()) HopAction::Redirect("a".to_string())
); );
Ok(()) Ok(())
} }
} }

View File

@ -4,19 +4,19 @@ use percent_encoding::PercentEncode;
use serde::Serialize; use serde::Serialize;
pub fn query(query: PercentEncode<'_>) -> impl Serialize + '_ { pub fn query(query: PercentEncode<'_>) -> impl Serialize + '_ {
#[derive(Serialize)] #[derive(Serialize)]
struct TemplateArgs<'a> { struct TemplateArgs<'a> {
query: Cow<'a, str>, query: Cow<'a, str>,
} }
TemplateArgs { TemplateArgs {
query: query.into(), query: query.into(),
} }
} }
pub fn hostname(hostname: &'_ str) -> impl Serialize + '_ { pub fn hostname(hostname: &'_ str) -> impl Serialize + '_ {
#[derive(Serialize)] #[derive(Serialize)]
pub struct TemplateArgs<'a> { pub struct TemplateArgs<'a> {
pub hostname: &'a str, pub hostname: &'a str,
} }
TemplateArgs { hostname } TemplateArgs { hostname }
} }