diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index 2489c66..0000000 --- a/rustfmt.toml +++ /dev/null @@ -1,3 +0,0 @@ -tab_spaces = 2 -use_field_init_shorthand = true -max_width = 80 diff --git a/src/cli.rs b/src/cli.rs index 13f9553..93ba71b 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -4,16 +4,16 @@ use std::path::PathBuf; #[derive(Parser)] #[clap(version = crate_version!(), author = crate_authors!())] pub struct Opts { - /// Increases the log level to info, debug, and trace, respectively. - #[clap(short, long, parse(from_occurrences), conflicts_with("quiet"))] - pub verbose: u8, - /// Decreases the log level to error or no logging at all, respectively. - #[clap(short, long, parse(from_occurrences), conflicts_with("verbose"))] - pub quiet: u8, - /// Specify the location of the config file to read from. Needs read/write permissions. - #[clap(short, long)] - pub config: Option, - /// Allow config sizes larger than 100MB. - #[clap(long)] - pub large_config: bool, + /// Increases the log level to info, debug, and trace, respectively. + #[clap(short, long, parse(from_occurrences), conflicts_with("quiet"))] + pub verbose: u8, + /// Decreases the log level to error or no logging at all, respectively. + #[clap(short, long, parse(from_occurrences), conflicts_with("verbose"))] + pub quiet: u8, + /// Specify the location of the config file to read from. Needs read/write permissions. + #[clap(short, long)] + pub config: Option, + /// Allow config sizes larger than 100MB. + #[clap(long)] + pub large_config: bool, } diff --git a/src/config.rs b/src/config.rs index ff4d7c3..0649992 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,8 +2,8 @@ use crate::BunBunError; use dirs::{config_dir, home_dir}; use log::{debug, info, trace}; use serde::{ - de::{self, Deserializer, MapAccess, Unexpected, Visitor}, - Deserialize, Serialize, + de::{self, Deserializer, MapAccess, Unexpected, Visitor}, + Deserialize, Serialize, }; use std::collections::HashMap; use std::fmt; @@ -20,55 +20,55 @@ const LARGE_FILE_SIZE_THRESHOLD: u64 = 1_000_000; #[derive(Deserialize, Debug, PartialEq)] pub struct Config { - pub bind_address: String, - pub public_address: String, - pub default_route: Option, - pub groups: Vec, + pub bind_address: String, + pub public_address: String, + pub default_route: Option, + pub groups: Vec, } #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct RouteGroup { - pub name: String, - pub description: Option, - #[serde(default)] - pub hidden: bool, - pub routes: HashMap, + pub name: String, + pub description: Option, + #[serde(default)] + pub hidden: bool, + pub routes: HashMap, } #[derive(Debug, PartialEq, Clone, Serialize)] pub struct Route { - pub route_type: RouteType, - pub path: String, - pub hidden: bool, - pub description: Option, - pub min_args: Option, - pub max_args: Option, + pub route_type: RouteType, + pub path: String, + pub hidden: bool, + pub description: Option, + pub min_args: Option, + pub max_args: Option, } impl From for Route { - fn from(s: String) -> Self { - Self { - route_type: get_route_type(&s), - path: s, - hidden: false, - description: None, - min_args: None, - max_args: None, + fn from(s: String) -> Self { + Self { + route_type: get_route_type(&s), + path: s, + hidden: false, + description: None, + min_args: None, + max_args: None, + } } - } } impl From<&'static str> for Route { - fn from(s: &'static str) -> Self { - Self { - route_type: get_route_type(s), - path: s.to_string(), - hidden: false, - description: None, - min_args: None, - max_args: None, + fn from(s: &'static str) -> Self { + Self { + route_type: get_route_type(s), + path: s.to_string(), + hidden: false, + description: None, + min_args: None, + max_args: None, + } } - } } /// Deserialization of the route string into the enum requires us to figure out @@ -77,149 +77,149 @@ impl From<&'static str> for Route { /// web path. This incurs a disk check operation, but since users shouldn't be /// updating the config that frequently, it should be fine. impl<'de> Deserialize<'de> for Route { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(field_identifier, rename_all = "snake_case")] - enum Field { - Path, - Hidden, - Description, - MinArgs, - MaxArgs, - } - - struct RouteVisitor; - - impl<'de> Visitor<'de> for RouteVisitor { - type Value = Route; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("string") - } - - fn visit_str(self, path: &str) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::from(path.to_owned())) - } - - fn visit_map(self, mut map: M) -> Result - where - M: MapAccess<'de>, - { - let mut path = None; - let mut hidden = None; - let mut description = None; - let mut min_args = None; - let mut max_args = None; - - while let Some(key) = map.next_key()? { - match key { - Field::Path => { - if path.is_some() { - return Err(de::Error::duplicate_field("path")); - } - path = Some(map.next_value::()?); - } - Field::Hidden => { - if hidden.is_some() { - return Err(de::Error::duplicate_field("hidden")); - } - hidden = map.next_value()?; - } - Field::Description => { - if description.is_some() { - return Err(de::Error::duplicate_field("description")); - } - description = Some(map.next_value()?); - } - Field::MinArgs => { - if min_args.is_some() { - return Err(de::Error::duplicate_field("min_args")); - } - min_args = Some(map.next_value()?); - } - Field::MaxArgs => { - if max_args.is_some() { - return Err(de::Error::duplicate_field("max_args")); - } - max_args = Some(map.next_value()?); - } - } + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "snake_case")] + enum Field { + Path, + Hidden, + Description, + MinArgs, + MaxArgs, } - if let (Some(min_args), Some(max_args)) = (min_args, max_args) { - if min_args > max_args { + struct RouteVisitor; + + impl<'de> Visitor<'de> for RouteVisitor { + type Value = Route; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("string") + } + + fn visit_str(self, path: &str) -> Result + where + E: serde::de::Error, { - return Err(de::Error::invalid_value( - Unexpected::Other(&format!( - "argument count range {min_args} to {max_args}", - )), - &"a valid argument count range", - )); + Ok(Self::Value::from(path.to_owned())) + } + + fn visit_map(self, mut map: M) -> Result + where + M: MapAccess<'de>, + { + let mut path = None; + let mut hidden = None; + let mut description = None; + let mut min_args = None; + let mut max_args = None; + + while let Some(key) = map.next_key()? { + match key { + Field::Path => { + if path.is_some() { + return Err(de::Error::duplicate_field("path")); + } + path = Some(map.next_value::()?); + } + Field::Hidden => { + if hidden.is_some() { + return Err(de::Error::duplicate_field("hidden")); + } + hidden = map.next_value()?; + } + Field::Description => { + if description.is_some() { + return Err(de::Error::duplicate_field("description")); + } + description = Some(map.next_value()?); + } + Field::MinArgs => { + if min_args.is_some() { + return Err(de::Error::duplicate_field("min_args")); + } + min_args = Some(map.next_value()?); + } + Field::MaxArgs => { + if max_args.is_some() { + return Err(de::Error::duplicate_field("max_args")); + } + max_args = Some(map.next_value()?); + } + } + } + + if let (Some(min_args), Some(max_args)) = (min_args, max_args) { + if min_args > max_args { + { + return Err(de::Error::invalid_value( + Unexpected::Other(&format!( + "argument count range {min_args} to {max_args}", + )), + &"a valid argument count range", + )); + } + } + } + + let path = path.ok_or_else(|| de::Error::missing_field("path"))?; + Ok(Route { + route_type: get_route_type(&path), + path, + hidden: hidden.unwrap_or_default(), + description, + min_args, + max_args, + }) } - } } - let path = path.ok_or_else(|| de::Error::missing_field("path"))?; - Ok(Route { - route_type: get_route_type(&path), - path, - hidden: hidden.unwrap_or_default(), - description, - min_args, - max_args, - }) - } + deserializer.deserialize_any(RouteVisitor) } - - deserializer.deserialize_any(RouteVisitor) - } } impl std::fmt::Display for Route { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self { - route_type: RouteType::External, - path, - .. - } => write!(f, "raw ({path})"), - Self { - route_type: RouteType::Internal, - path, - .. - } => write!(f, "file ({path})"), + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self { + route_type: RouteType::External, + path, + .. + } => write!(f, "raw ({path})"), + Self { + route_type: RouteType::Internal, + path, + .. + } => write!(f, "file ({path})"), + } } - } } /// Classifies the path depending on if the there exists a local file. fn get_route_type(path: &str) -> RouteType { - if std::path::Path::new(path).exists() { - debug!("Parsed {path} as a valid local path."); - RouteType::Internal - } else { - debug!("{path} does not exist on disk, assuming web path."); - RouteType::External - } + if std::path::Path::new(path).exists() { + debug!("Parsed {path} as a valid local path."); + RouteType::Internal + } else { + debug!("{path} does not exist on disk, assuming web path."); + RouteType::External + } } /// There exists two route types: an external path (e.g. a URL) or an internal /// path (to a file). #[derive(Debug, PartialEq, Clone, Serialize)] pub enum RouteType { - External, - Internal, + External, + Internal, } pub struct FileData { - pub path: PathBuf, - pub file: File, + pub path: PathBuf, + pub file: File, } /// If a provided config path isn't found, this function checks known good @@ -227,206 +227,200 @@ pub struct FileData { /// system-wide config location (`/etc/`, in Linux), followed by the config /// folder, followed by the user's home folder. pub fn get_config_data() -> Result { - // Locations to check, with highest priority first - let locations: Vec<_> = { - let mut folders = vec![PathBuf::from("/etc/")]; + // Locations to check, with highest priority first + let locations: Vec<_> = { + let mut folders = vec![PathBuf::from("/etc/")]; - // Config folder - if let Some(folder) = config_dir() { - folders.push(folder); + // Config folder + if let Some(folder) = config_dir() { + folders.push(folder); + } + + // Home folder + if let Some(folder) = home_dir() { + folders.push(folder); + } + + folders + .iter_mut() + .for_each(|folder| folder.push(CONFIG_FILENAME)); + + folders + }; + + debug!("Checking locations for config file: {:?}", &locations); + + for location in &locations { + let file = OpenOptions::new().read(true).open(location); + match file { + Ok(file) => { + debug!("Found file at {location:?}."); + return Ok(FileData { + path: location.clone(), + file, + }); + } + Err(e) => { + debug!("Tried to read '{location:?}' but failed due to error: {e}"); + } + } } - // Home folder - if let Some(folder) = home_dir() { - folders.push(folder); + debug!("Failed to find any config. Now trying to find first writable path"); + + // If we got here, we failed to read any file paths, meaning no config exists + // yet. In that case, try to return the first location that we can write to, + // after writing the default config + for location in locations { + let file = OpenOptions::new() + .write(true) + .create_new(true) + .open(location.clone()); + match file { + Ok(mut file) => { + info!("Creating new config file at {location:?}."); + file.write_all(DEFAULT_CONFIG)?; + + let file = OpenOptions::new().read(true).open(location.clone())?; + return Ok(FileData { + path: location, + file, + }); + } + Err(e) => { + debug!("Tried to open a new file at '{location:?}' but failed due to error: {e}",) + } + } } - folders - .iter_mut() - .for_each(|folder| folder.push(CONFIG_FILENAME)); - - folders - }; - - debug!("Checking locations for config file: {:?}", &locations); - - for location in &locations { - let file = OpenOptions::new().read(true).open(location); - match file { - Ok(file) => { - debug!("Found file at {location:?}."); - return Ok(FileData { - path: location.clone(), - file, - }); - } - Err(e) => { - debug!("Tried to read '{location:?}' but failed due to error: {e}"); - } - } - } - - debug!("Failed to find any config. Now trying to find first writable path"); - - // If we got here, we failed to read any file paths, meaning no config exists - // yet. In that case, try to return the first location that we can write to, - // after writing the default config - for location in locations { - let file = OpenOptions::new() - .write(true) - .create_new(true) - .open(location.clone()); - match file { - Ok(mut file) => { - info!("Creating new config file at {location:?}."); - file.write_all(DEFAULT_CONFIG)?; - - let file = OpenOptions::new().read(true).open(location.clone())?; - return Ok(FileData { - path: location, - file, - }); - } - Err(e) => debug!( - "Tried to open a new file at '{location:?}' but failed due to error: {e}", - ), - } - } - - Err(BunBunError::NoValidConfigPath) + Err(BunBunError::NoValidConfigPath) } /// Assumes that the user knows what they're talking about and will only try /// to load the config at the given path. -pub fn load_custom_file( - path: impl Into, -) -> Result { - let path = path.into(); - let file = OpenOptions::new() - .read(true) - .open(&path) - .map_err(|e| BunBunError::InvalidConfigPath(path.clone(), e))?; +pub fn load_custom_file(path: impl Into) -> Result { + let path = path.into(); + let file = OpenOptions::new() + .read(true) + .open(&path) + .map_err(|e| BunBunError::InvalidConfigPath(path.clone(), e))?; - Ok(FileData { path, file }) + Ok(FileData { path, file }) } -pub fn load_file( - mut config_file: File, - large_config: bool, -) -> Result { - trace!("Loading config file."); - let file_size = config_file.metadata()?.len(); +pub fn load_file(mut config_file: File, large_config: bool) -> Result { + trace!("Loading config file."); + let file_size = config_file.metadata()?.len(); - // 100 MB - if file_size > LARGE_FILE_SIZE_THRESHOLD && !large_config { - return Err(BunBunError::ConfigTooLarge(file_size)); - } + // 100 MB + if file_size > LARGE_FILE_SIZE_THRESHOLD && !large_config { + return Err(BunBunError::ConfigTooLarge(file_size)); + } - if file_size == 0 { - return Err(BunBunError::ZeroByteConfig); - } + if file_size == 0 { + return Err(BunBunError::ZeroByteConfig); + } - let mut config_data = String::new(); - config_file.read_to_string(&mut config_data)?; - // Reading from memory is faster than reading directly from a reader for some - // reason; see https://github.com/serde-rs/json/issues/160 - Ok(serde_yaml::from_str(&config_data)?) + let mut config_data = String::new(); + config_file.read_to_string(&mut config_data)?; + // Reading from memory is faster than reading directly from a reader for some + // reason; see https://github.com/serde-rs/json/issues/160 + Ok(serde_yaml::from_str(&config_data)?) } #[cfg(test)] mod route { - use super::*; - use anyhow::{Context, Result}; - use serde_yaml::{from_str, to_string}; - use std::path::Path; - use tempfile::NamedTempFile; + use super::*; + use anyhow::{Context, Result}; + use serde_yaml::{from_str, to_string}; + use std::path::Path; + use tempfile::NamedTempFile; - #[test] - fn deserialize_relative_path() -> Result<()> { - let tmpfile = NamedTempFile::new_in(".")?; - let path = tmpfile.path().display().to_string(); - let path = path - .get(path.rfind(".").context("While finding .")?..) - .context("While getting the path")?; - let path = Path::new(path); - assert!(path.is_relative()); - let path = path.to_str().context("While stringifying path")?; - assert_eq!(from_str::(path)?, Route::from(path.to_owned())); - Ok(()) - } + #[test] + fn deserialize_relative_path() -> Result<()> { + let tmpfile = NamedTempFile::new_in(".")?; + let path = tmpfile.path().display().to_string(); + let path = path + .get(path.rfind(".").context("While finding .")?..) + .context("While getting the path")?; + let path = Path::new(path); + assert!(path.is_relative()); + let path = path.to_str().context("While stringifying path")?; + assert_eq!(from_str::(path)?, Route::from(path.to_owned())); + Ok(()) + } - #[test] - fn deserialize_absolute_path() -> Result<()> { - let tmpfile = NamedTempFile::new()?; - let path = format!("{}", tmpfile.path().display()); - assert!(tmpfile.path().is_absolute()); - assert_eq!(from_str::(&path)?, Route::from(path)); + #[test] + fn deserialize_absolute_path() -> Result<()> { + let tmpfile = NamedTempFile::new()?; + let path = format!("{}", tmpfile.path().display()); + assert!(tmpfile.path().is_absolute()); + assert_eq!(from_str::(&path)?, Route::from(path)); - Ok(()) - } + Ok(()) + } - #[test] - fn deserialize_http_path() -> Result<()> { - assert_eq!( - from_str::("http://google.com")?, - Route::from("http://google.com") - ); - Ok(()) - } + #[test] + fn deserialize_http_path() -> Result<()> { + assert_eq!( + from_str::("http://google.com")?, + Route::from("http://google.com") + ); + Ok(()) + } - #[test] - fn deserialize_https_path() -> Result<()> { - assert_eq!( - from_str::("https://google.com")?, - Route::from("https://google.com") - ); - Ok(()) - } + #[test] + fn deserialize_https_path() -> Result<()> { + assert_eq!( + from_str::("https://google.com")?, + Route::from("https://google.com") + ); + Ok(()) + } - #[test] - fn serialize() -> Result<()> { - assert_eq!( - &to_string(&Route::from("hello world"))?, - "---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n" - ); - Ok(()) - } + #[test] + fn serialize() -> Result<()> { + assert_eq!( + &to_string(&Route::from("hello world"))?, + "---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~\n" + ); + Ok(()) + } } #[cfg(test)] mod read_config { - use super::*; - use anyhow::Result; + use super::*; + use anyhow::Result; - #[test] - fn empty_file() -> Result<()> { - let config_file = tempfile::tempfile()?; - assert!(matches!( - load_file(config_file, false), - Err(BunBunError::ZeroByteConfig) - )); - Ok(()) - } - - #[test] - fn config_too_large() -> Result<()> { - let mut config_file = tempfile::tempfile()?; - let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize; - config_file.write(&[0].repeat(size_to_write))?; - match load_file(config_file, false) { - Err(BunBunError::ConfigTooLarge(size)) - if size as usize == size_to_write => {} - Err(BunBunError::ConfigTooLarge(size)) => { - panic!("Mismatched size: {size} != {size_to_write}") - } - res => panic!("Wrong result, got {res:#?}"), + #[test] + fn empty_file() -> Result<()> { + let config_file = tempfile::tempfile()?; + assert!(matches!( + load_file(config_file, false), + Err(BunBunError::ZeroByteConfig) + )); + Ok(()) } - Ok(()) - } - #[test] - fn valid_config() -> Result<()> { - assert!(load_file(File::open("bunbun.default.yaml")?, false).is_ok()); - Ok(()) - } + #[test] + fn config_too_large() -> Result<()> { + let mut config_file = tempfile::tempfile()?; + let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize; + config_file.write(&[0].repeat(size_to_write))?; + match load_file(config_file, false) { + Err(BunBunError::ConfigTooLarge(size)) if size as usize == size_to_write => {} + Err(BunBunError::ConfigTooLarge(size)) => { + panic!("Mismatched size: {size} != {size_to_write}") + } + res => panic!("Wrong result, got {res:#?}"), + } + Ok(()) + } + + #[test] + fn valid_config() -> Result<()> { + assert!(load_file(File::open("bunbun.default.yaml")?, false).is_ok()); + Ok(()) + } } diff --git a/src/error.rs b/src/error.rs index 2231c49..065ad9a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,49 +4,49 @@ use std::fmt; #[derive(Debug)] #[allow(clippy::module_name_repetitions)] pub enum BunBunError { - Io(std::io::Error), - Parse(serde_yaml::Error), - Watch(hotwatch::Error), - LoggerInit(log::SetLoggerError), - CustomProgram(String), - NoValidConfigPath, - InvalidConfigPath(std::path::PathBuf, std::io::Error), - ConfigTooLarge(u64), - ZeroByteConfig, - JsonParse(serde_json::Error), + Io(std::io::Error), + Parse(serde_yaml::Error), + Watch(hotwatch::Error), + LoggerInit(log::SetLoggerError), + CustomProgram(String), + NoValidConfigPath, + InvalidConfigPath(std::path::PathBuf, std::io::Error), + ConfigTooLarge(u64), + ZeroByteConfig, + JsonParse(serde_json::Error), } impl Error for BunBunError {} impl fmt::Display for BunBunError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Io(e) => e.fmt(f), - Self::Parse(e) => e.fmt(f), - Self::Watch(e) => e.fmt(f), - Self::LoggerInit(e) => e.fmt(f), - Self::CustomProgram(msg) => msg.fmt(f), - Self::NoValidConfigPath => write!(f, "No valid config path was found!"), - Self::InvalidConfigPath(path, reason) => { - write!(f, "Failed to access {path:?}: {reason}") - } - Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({size} bytes)! Pass in --large-config to bypass this check."), - Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"), - Self::JsonParse(e) => e.fmt(f), + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Io(e) => e.fmt(f), + Self::Parse(e) => e.fmt(f), + Self::Watch(e) => e.fmt(f), + Self::LoggerInit(e) => e.fmt(f), + Self::CustomProgram(msg) => msg.fmt(f), + Self::NoValidConfigPath => write!(f, "No valid config path was found!"), + Self::InvalidConfigPath(path, reason) => { + write!(f, "Failed to access {path:?}: {reason}") + } + Self::ConfigTooLarge(size) => write!(f, "The config file was too large ({size} bytes)! Pass in --large-config to bypass this check."), + Self::ZeroByteConfig => write!(f, "The config provided reported a size of 0 bytes. Please check your config path!"), + Self::JsonParse(e) => e.fmt(f), + } } - } } /// Generates a from implementation from the specified type to the provided /// bunbun error. macro_rules! from_error { - ($from:ty, $to:ident) => { - impl From<$from> for BunBunError { - fn from(e: $from) -> Self { - Self::$to(e) - } - } - }; + ($from:ty, $to:ident) => { + impl From<$from> for BunBunError { + fn from(e: $from) -> Self { + Self::$to(e) + } + } + }; } from_error!(std::io::Error, Io); diff --git a/src/main.rs b/src/main.rs index cab22f3..9ae1af0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,9 +6,7 @@ //! search engine and quick-jump tool in one small binary. For information on //! usage, please take a look at the readme. -use crate::config::{ - get_config_data, load_custom_file, load_file, FileData, Route, RouteGroup, -}; +use crate::config::{get_config_data, load_custom_file, load_file, FileData, Route, RouteGroup}; use anyhow::Result; use arc_swap::ArcSwap; use axum::routing::get; @@ -35,49 +33,49 @@ mod template_args; /// Dynamic variables that either need to be present at runtime, or can be /// changed during runtime. pub struct State { - public_address: String, - default_route: Option, - groups: Vec, - /// Cached, flattened mapping of all routes and their destinations. - routes: HashMap, + public_address: String, + default_route: Option, + groups: Vec, + /// Cached, flattened mapping of all routes and their destinations. + routes: HashMap, } #[tokio::main] #[cfg(not(tarpaulin_include))] async fn main() -> Result<()> { - let opts = cli::Opts::parse(); + let opts = cli::Opts::parse(); - init_logger(opts.verbose, opts.quiet)?; + init_logger(opts.verbose, opts.quiet)?; - let conf_data = match opts.config { - Some(file_name) => load_custom_file(file_name), - None => get_config_data(), - }?; + let conf_data = match opts.config { + Some(file_name) => load_custom_file(file_name), + None => get_config_data(), + }?; - let conf = load_file(conf_data.file.try_clone()?, opts.large_config)?; - let state = Arc::from(ArcSwap::from_pointee(State { - public_address: conf.public_address, - default_route: conf.default_route, - routes: cache_routes(conf.groups.clone()), - groups: conf.groups, - })); + let conf = load_file(conf_data.file.try_clone()?, opts.large_config)?; + let state = Arc::from(ArcSwap::from_pointee(State { + public_address: conf.public_address, + default_route: conf.default_route, + routes: cache_routes(conf.groups.clone()), + groups: conf.groups, + })); - // Cannot be named _ or Rust will immediately drop it. - let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config); + // Cannot be named _ or Rust will immediately drop it. + let _watch = start_watch(Arc::clone(&state), conf_data, opts.large_config); - let app = Router::new() - .route("/", get(routes::index)) - .route("/bunbunsearch.xml", get(routes::opensearch)) - .route("/ls", get(routes::list)) - .route("/hop", get(routes::hop)) - .layer(Extension(compile_templates()?)) - .layer(Extension(state)); + let app = Router::new() + .route("/", get(routes::index)) + .route("/bunbunsearch.xml", get(routes::opensearch)) + .route("/ls", get(routes::list)) + .route("/hop", get(routes::hop)) + .layer(Extension(compile_templates()?)) + .layer(Extension(state)); - axum::Server::bind(&conf.bind_address.parse()?) - .serve(app.into_make_service()) - .await?; + axum::Server::bind(&conf.bind_address.parse()?) + .serve(app.into_make_service()) + .await?; - Ok(()) + Ok(()) } /// Initializes the logger based on the number of quiet and verbose flags passed @@ -85,52 +83,51 @@ async fn main() -> Result<()> { /// verbose flags is non-zero then the quiet flag is zero, and vice versa. #[cfg(not(tarpaulin_include))] fn init_logger(num_verbose_flags: u8, num_quiet_flags: u8) -> Result<()> { - let log_level = - match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 { - -2 => None, - -1 => Some(log::LevelFilter::Error), - 0 => Some(log::LevelFilter::Warn), - 1 => Some(log::LevelFilter::Info), - 2 => Some(log::LevelFilter::Debug), - 3 => Some(log::LevelFilter::Trace), - _ => unreachable!(), // values are clamped to [0, 3] - [0, 2] + let log_level = match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 { + -2 => None, + -1 => Some(log::LevelFilter::Error), + 0 => Some(log::LevelFilter::Warn), + 1 => Some(log::LevelFilter::Info), + 2 => Some(log::LevelFilter::Debug), + 3 => Some(log::LevelFilter::Trace), + _ => unreachable!(), // values are clamped to [0, 3] - [0, 2] }; - if let Some(level) = log_level { - SimpleLogger::new().with_level(level).init()?; - } + if let Some(level) = log_level { + SimpleLogger::new().with_level(level).init()?; + } - Ok(()) + Ok(()) } /// Generates a hashmap of routes from the data structure created by the config /// file. This should improve runtime performance and is a better solution than /// just iterating over the config object for every hop resolution. fn cache_routes(groups: Vec) -> HashMap { - let mut mapping = HashMap::new(); - for group in groups { - for (kw, dest) in group.routes { - // This function isn't called often enough to not be a performance issue. - match mapping.insert(kw.clone(), dest.clone()) { - None => trace!("Inserting {kw} into mapping."), - Some(old_value) => { - trace!("Overriding {kw} route from {old_value} to {dest}."); + let mut mapping = HashMap::new(); + for group in groups { + for (kw, dest) in group.routes { + // This function isn't called often enough to not be a performance issue. + match mapping.insert(kw.clone(), dest.clone()) { + None => trace!("Inserting {kw} into mapping."), + Some(old_value) => { + trace!("Overriding {kw} route from {old_value} to {dest}."); + } + } } - } } - } - mapping + mapping } /// Returns an instance with all pre-generated templates included into the /// binary. This allows for users to have a portable binary without needed the /// templates at runtime. fn compile_templates() -> Result> { - let mut handlebars = Handlebars::new(); - handlebars.set_strict_mode(true); - handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?; - handlebars.register_partial("bunbun_src", env!("CARGO_PKG_REPOSITORY"))?; - macro_rules! register_template { + let mut handlebars = Handlebars::new(); + handlebars.set_strict_mode(true); + handlebars.register_partial("bunbun_version", env!("CARGO_PKG_VERSION"))?; + handlebars.register_partial("bunbun_src", env!("CARGO_PKG_REPOSITORY"))?; + macro_rules! register_template { [ $( $template:expr ),* ] => { $( handlebars @@ -143,8 +140,8 @@ fn compile_templates() -> Result> { )* }; } - register_template!["index", "list", "opensearch"]; - Ok(handlebars) + register_template!["index", "list", "opensearch"]; + Ok(handlebars) } /// Starts the watch on a file, if possible. This will only return an Error if @@ -158,179 +155,170 @@ fn compile_templates() -> Result> { /// watches. #[cfg(not(tarpaulin_include))] fn start_watch( - state: Arc>, - config_data: FileData, - large_config: bool, + state: Arc>, + config_data: FileData, + large_config: bool, ) -> Result { - let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?; - let FileData { path, mut file } = config_data; - let watch_result = watch.watch(&path, move |e: Event| { - if let Event::Create(ref path) = e { - file = load_custom_file(path).expect("file to exist at path").file; - trace!("Getting new file handler as file was recreated."); - } - - match e { - Event::Write(_) | Event::Create(_) => { - trace!("Grabbing writer lock on state..."); - trace!("Obtained writer lock on state!"); - match load_file( - file.try_clone().expect("Failed to clone file handle"), - large_config, - ) { - Ok(conf) => { - state.store(Arc::new(State { - public_address: conf.public_address, - default_route: conf.default_route, - routes: cache_routes(conf.groups.clone()), - groups: conf.groups, - })); - info!("Successfully updated active state"); - } - Err(e) => warn!("Failed to update config file: {e}"), + let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?; + let FileData { path, mut file } = config_data; + let watch_result = watch.watch(&path, move |e: Event| { + if let Event::Create(ref path) = e { + file = load_custom_file(path).expect("file to exist at path").file; + trace!("Getting new file handler as file was recreated."); } - } - _ => debug!("Saw event {e:#?} but ignored it"), - } - }); - match watch_result { - Ok(_) => info!("Watcher is now watching {path:?}"), - Err(e) => { - warn!( - "Couldn't watch {path:?}: {e}. Changes to this file won't be seen!" - ); - } - } + match e { + Event::Write(_) | Event::Create(_) => { + trace!("Grabbing writer lock on state..."); + trace!("Obtained writer lock on state!"); + match load_file( + file.try_clone().expect("Failed to clone file handle"), + large_config, + ) { + Ok(conf) => { + state.store(Arc::new(State { + public_address: conf.public_address, + default_route: conf.default_route, + routes: cache_routes(conf.groups.clone()), + groups: conf.groups, + })); + info!("Successfully updated active state"); + } + Err(e) => warn!("Failed to update config file: {e}"), + } + } + _ => debug!("Saw event {e:#?} but ignored it"), + } + }); - Ok(watch) + match watch_result { + Ok(_) => info!("Watcher is now watching {path:?}"), + Err(e) => { + warn!("Couldn't watch {path:?}: {e}. Changes to this file won't be seen!"); + } + } + + Ok(watch) } #[cfg(test)] mod init_logger { - use super::*; - use anyhow::Result; + use super::*; + use anyhow::Result; - #[test] - fn defaults_to_warn() -> Result<()> { - init_logger(0, 0)?; - assert_eq!(log::max_level(), log::Level::Warn); - Ok(()) - } + #[test] + fn defaults_to_warn() -> Result<()> { + init_logger(0, 0)?; + assert_eq!(log::max_level(), log::Level::Warn); + Ok(()) + } - // The following tests work but because the log crate is global, initializing - // the logger more than once (read: testing it more than once) leads to a - // panic. These ignored tests must be manually tested. + // The following tests work but because the log crate is global, initializing + // the logger more than once (read: testing it more than once) leads to a + // panic. These ignored tests must be manually tested. - #[test] - #[ignore] - fn caps_to_2_when_log_level_is_lt_2() -> Result<()> { - init_logger(0, 3)?; - assert_eq!(log::max_level(), log::LevelFilter::Off); - Ok(()) - } + #[test] + #[ignore] + fn caps_to_2_when_log_level_is_lt_2() -> Result<()> { + init_logger(0, 3)?; + assert_eq!(log::max_level(), log::LevelFilter::Off); + Ok(()) + } - #[test] - #[ignore] - fn caps_to_3_when_log_level_is_gt_3() -> Result<()> { - init_logger(4, 0)?; - assert_eq!(log::max_level(), log::Level::Trace); - Ok(()) - } + #[test] + #[ignore] + fn caps_to_3_when_log_level_is_gt_3() -> Result<()> { + init_logger(4, 0)?; + assert_eq!(log::max_level(), log::Level::Trace); + Ok(()) + } } #[cfg(test)] mod cache_routes { - use super::*; - use std::iter::FromIterator; + use super::*; + use std::iter::FromIterator; - fn generate_external_routes( - routes: &[(&'static str, &'static str)], - ) -> HashMap { - HashMap::from_iter( - routes - .into_iter() - .map(|(key, value)| ((*key).to_owned(), Route::from(*value))), - ) - } + fn generate_external_routes(routes: &[(&'static str, &'static str)]) -> HashMap { + HashMap::from_iter( + routes + .into_iter() + .map(|(key, value)| ((*key).to_owned(), Route::from(*value))), + ) + } - #[test] - fn empty_groups_yield_empty_routes() { - assert_eq!(cache_routes(Vec::new()), HashMap::new()); - } + #[test] + fn empty_groups_yield_empty_routes() { + assert_eq!(cache_routes(Vec::new()), HashMap::new()); + } - #[test] - fn disjoint_groups_yield_summed_routes() { - let group1 = RouteGroup { - name: String::from("x"), - description: Some(String::from("y")), - routes: generate_external_routes(&[("a", "b"), ("c", "d")]), - hidden: false, - }; + #[test] + fn disjoint_groups_yield_summed_routes() { + let group1 = RouteGroup { + name: String::from("x"), + description: Some(String::from("y")), + routes: generate_external_routes(&[("a", "b"), ("c", "d")]), + hidden: false, + }; - let group2 = RouteGroup { - name: String::from("5"), - description: Some(String::from("6")), - routes: generate_external_routes(&[("1", "2"), ("3", "4")]), - hidden: false, - }; + let group2 = RouteGroup { + name: String::from("5"), + description: Some(String::from("6")), + routes: generate_external_routes(&[("1", "2"), ("3", "4")]), + hidden: false, + }; - assert_eq!( - cache_routes(vec![group1, group2]), - generate_external_routes(&[ - ("a", "b"), - ("c", "d"), - ("1", "2"), - ("3", "4") - ]) - ); - } + assert_eq!( + cache_routes(vec![group1, group2]), + generate_external_routes(&[("a", "b"), ("c", "d"), ("1", "2"), ("3", "4")]) + ); + } - #[test] - fn overlapping_groups_use_latter_routes() { - let group1 = RouteGroup { - name: String::from("x"), - description: Some(String::from("y")), - routes: generate_external_routes(&[("a", "b"), ("c", "d")]), - hidden: false, - }; + #[test] + fn overlapping_groups_use_latter_routes() { + let group1 = RouteGroup { + name: String::from("x"), + description: Some(String::from("y")), + routes: generate_external_routes(&[("a", "b"), ("c", "d")]), + hidden: false, + }; - let group2 = RouteGroup { - name: String::from("5"), - description: Some(String::from("6")), - routes: generate_external_routes(&[("a", "1"), ("c", "2")]), - hidden: false, - }; + let group2 = RouteGroup { + name: String::from("5"), + description: Some(String::from("6")), + routes: generate_external_routes(&[("a", "1"), ("c", "2")]), + hidden: false, + }; - assert_eq!( - cache_routes(vec![group1.clone(), group2]), - generate_external_routes(&[("a", "1"), ("c", "2")]) - ); + assert_eq!( + cache_routes(vec![group1.clone(), group2]), + generate_external_routes(&[("a", "1"), ("c", "2")]) + ); - let group3 = RouteGroup { - name: String::from("5"), - description: Some(String::from("6")), - routes: generate_external_routes(&[("a", "1"), ("b", "2")]), - hidden: false, - }; + let group3 = RouteGroup { + name: String::from("5"), + description: Some(String::from("6")), + routes: generate_external_routes(&[("a", "1"), ("b", "2")]), + hidden: false, + }; - assert_eq!( - cache_routes(vec![group1, group3]), - generate_external_routes(&[("a", "1"), ("b", "2"), ("c", "d")]) - ); - } + assert_eq!( + cache_routes(vec![group1, group3]), + generate_external_routes(&[("a", "1"), ("b", "2"), ("c", "d")]) + ); + } } #[cfg(test)] mod compile_templates { - use super::compile_templates; + use super::compile_templates; - /// Successful compilation of the binary guarantees that the templates will be - /// present to be registered to. Thus, we only really need to see that - /// compilation of the templates don't panic, which is just making sure that - /// the function can be successfully called. - #[test] - fn templates_compile() { - let _ = compile_templates(); - } + /// Successful compilation of the binary guarantees that the templates will be + /// present to be registered to. Thus, we only really need to see that + /// compilation of the templates don't panic, which is just making sure that + /// the function can be successfully called. + #[test] + fn templates_compile() { + let _ = compile_templates(); + } } diff --git a/src/routes.rs b/src/routes.rs index 3bb4110..b4e3d2d 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -17,130 +17,127 @@ use std::sync::Arc; // https://url.spec.whatwg.org/#fragment-percent-encode-set const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS - .add(b' ') - .add(b'"') - .add(b'<') - .add(b'>') - .add(b'`') - .add(b'+') - .add(b'&') // Interpreted as a GET query - .add(b'#') // Interpreted as a hyperlink section target - .add(b'\''); + .add(b' ') + .add(b'"') + .add(b'<') + .add(b'>') + .add(b'`') + .add(b'+') + .add(b'&') // Interpreted as a GET query + .add(b'#') // Interpreted as a hyperlink section target + .add(b'\''); #[allow(clippy::unused_async)] pub async fn index( - Extension(data): Extension>>, - Extension(handlebars): Extension>, + Extension(data): Extension>>, + Extension(handlebars): Extension>, ) -> impl IntoResponse { - handlebars - .render( - "index", - &template_args::hostname(&data.load().public_address), - ) - .map(Html) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + handlebars + .render( + "index", + &template_args::hostname(&data.load().public_address), + ) + .map(Html) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } #[allow(clippy::unused_async)] pub async fn opensearch( - Extension(data): Extension>>, - Extension(handlebars): Extension>, + Extension(data): Extension>>, + Extension(handlebars): Extension>, ) -> impl IntoResponse { - handlebars - .render( - "opensearch", - &template_args::hostname(&data.load().public_address), - ) - .map(|body| { - ( - StatusCode::OK, - [( - header::CONTENT_TYPE, - "application/opensearchdescription+xml", - )], - body, - ) - }) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + handlebars + .render( + "opensearch", + &template_args::hostname(&data.load().public_address), + ) + .map(|body| { + ( + StatusCode::OK, + [( + header::CONTENT_TYPE, + "application/opensearchdescription+xml", + )], + body, + ) + }) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } #[allow(clippy::unused_async)] pub async fn list( - Extension(data): Extension>>, - Extension(handlebars): Extension>, + Extension(data): Extension>>, + Extension(handlebars): Extension>, ) -> impl IntoResponse { - handlebars - .render("list", &data.load().groups) - .map(Html) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + handlebars + .render("list", &data.load().groups) + .map(Html) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } #[derive(Deserialize, Debug)] pub struct SearchQuery { - to: String, + to: String, } #[allow(clippy::unused_async)] pub async fn hop( - Extension(data): Extension>>, - Extension(handlebars): Extension>, - Query(query): Query, + Extension(data): Extension>>, + Extension(handlebars): Extension>, + Query(query): Query, ) -> impl IntoResponse { - let data = data.load(); + let data = data.load(); - match resolve_hop(&query.to, &data.routes, &data.default_route) { - RouteResolution::Resolved { route: path, args } => { - let resolved_template = match path { - ConfigRoute { - route_type: RouteType::Internal, - path, - .. - } => resolve_path(Path::new(path), &args), - ConfigRoute { - route_type: RouteType::External, - path, - .. - } => Ok(HopAction::Redirect(path.clone())), - }; + match resolve_hop(&query.to, &data.routes, &data.default_route) { + RouteResolution::Resolved { route: path, args } => { + let resolved_template = match path { + ConfigRoute { + route_type: RouteType::Internal, + path, + .. + } => resolve_path(Path::new(path), &args), + ConfigRoute { + route_type: RouteType::External, + path, + .. + } => Ok(HopAction::Redirect(path.clone())), + }; - match resolved_template { - Ok(HopAction::Redirect(path)) => { - let rendered = handlebars - .render_template( - &path, - &template_args::query(utf8_percent_encode( - &args, - FRAGMENT_ENCODE_SET, - )), - ) - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - Response::builder() - .status(StatusCode::FOUND) - .header(header::LOCATION, &path) - .body(boxed(Full::from(rendered))) + match resolved_template { + Ok(HopAction::Redirect(path)) => { + let rendered = handlebars + .render_template( + &path, + &template_args::query(utf8_percent_encode(&args, FRAGMENT_ENCODE_SET)), + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Response::builder() + .status(StatusCode::FOUND) + .header(header::LOCATION, &path) + .body(boxed(Full::from(rendered))) + } + Ok(HopAction::Body(body)) => Response::builder() + .status(StatusCode::OK) + .body(boxed(Full::new(Bytes::from(body)))), + Err(e) => { + error!("Failed to redirect user for {path}: {e}"); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(boxed(Full::from("Something went wrong :(\n"))) + } + } } - Ok(HopAction::Body(body)) => Response::builder() - .status(StatusCode::OK) - .body(boxed(Full::new(Bytes::from(body)))), - Err(e) => { - error!("Failed to redirect user for {path}: {e}"); - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(boxed(Full::from("Something went wrong :(\n"))) - } - } + RouteResolution::Unresolved => Response::builder() + .status(StatusCode::NOT_FOUND) + .body(boxed(Full::from("not found\n"))), } - RouteResolution::Unresolved => Response::builder() - .status(StatusCode::NOT_FOUND) - .body(boxed(Full::from("not found\n"))), - } - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } #[derive(Debug, PartialEq)] enum RouteResolution<'a> { - Resolved { route: &'a Route, args: String }, - Unresolved, + Resolved { route: &'a Route, args: String }, + Unresolved, } /// Attempts to resolve the provided string into its route and its arguments. @@ -150,71 +147,71 @@ enum RouteResolution<'a> { /// The first element in the tuple describes the route, while the second element /// returns the remaining arguments. If none remain, an empty string is given. fn resolve_hop<'a>( - query: &str, - routes: &'a HashMap, - default_route: &Option, + query: &str, + routes: &'a HashMap, + default_route: &Option, ) -> RouteResolution<'a> { - let mut split_args = query.split_ascii_whitespace().peekable(); - let maybe_route = { - match split_args.peek() { - Some(command) => routes.get(*command), - None => { - debug!("Found empty query, returning no route."); - return RouteResolution::Unresolved; - } + let mut split_args = query.split_ascii_whitespace().peekable(); + let maybe_route = { + match split_args.peek() { + Some(command) => routes.get(*command), + None => { + debug!("Found empty query, returning no route."); + return RouteResolution::Unresolved; + } + } + }; + + let args = split_args.collect::>(); + let arg_count = args.len(); + + // Try resolving with a matched command + if let Some(route) = maybe_route { + let args = if args.is_empty() { &[] } else { &args[1..] }.join(" "); + let arg_count = arg_count - 1; + if check_route(route, arg_count) { + debug!("Resolved {route} with args {args}"); + return RouteResolution::Resolved { route, args }; + } } - }; - let args = split_args.collect::>(); - let arg_count = args.len(); - - // Try resolving with a matched command - if let Some(route) = maybe_route { - let args = if args.is_empty() { &[] } else { &args[1..] }.join(" "); - let arg_count = arg_count - 1; - if check_route(route, arg_count) { - debug!("Resolved {route} with args {args}"); - return RouteResolution::Resolved { route, args }; + // Try resolving with the default route, if it exists + if let Some(route) = default_route { + if let Some(route) = routes.get(route) { + if check_route(route, arg_count) { + let args = args.join(" "); + debug!("Using default route {route} with args {args}"); + return RouteResolution::Resolved { route, args }; + } + } } - } - // Try resolving with the default route, if it exists - if let Some(route) = default_route { - if let Some(route) = routes.get(route) { - if check_route(route, arg_count) { - let args = args.join(" "); - debug!("Using default route {route} with args {args}"); - return RouteResolution::Resolved { route, args }; - } - } - } - - RouteResolution::Unresolved + RouteResolution::Unresolved } /// Checks if the user provided string has the correct properties required by /// the route to be successfully matched. const fn check_route(route: &Route, arg_count: usize) -> bool { - if let Some(min_args) = route.min_args { - if arg_count < min_args { - return false; + if let Some(min_args) = route.min_args { + if arg_count < min_args { + return false; + } } - } - if let Some(max_args) = route.max_args { - if arg_count > max_args { - return false; + if let Some(max_args) = route.max_args { + if arg_count > max_args { + return false; + } } - } - true + true } #[derive(Deserialize, Debug, PartialEq, Eq)] #[serde(rename_all = "snake_case")] enum HopAction { - Redirect(String), - Body(String), + Redirect(String), + Body(String), } /// Runs the executable with the user's input as a single argument. Returns Ok @@ -222,203 +219,200 @@ enum HopAction { /// file doesn't exist or bunbun did not have permission to read and execute the /// file. fn resolve_path(path: &Path, args: &str) -> Result { - let output = Command::new(path.canonicalize()?) - .args(args.split(' ')) - .output()?; + let output = Command::new(path.canonicalize()?) + .args(args.split(' ')) + .output()?; - if output.status.success() { - Ok(serde_json::from_slice(&output.stdout[..])?) - } else { - error!( - "Program exit code for {} was not 0! Dumping standard error!", - path.display(), - ); - let error = String::from_utf8_lossy(&output.stderr); - Err(BunBunError::CustomProgram(error.to_string())) - } + if output.status.success() { + Ok(serde_json::from_slice(&output.stdout[..])?) + } else { + error!( + "Program exit code for {} was not 0! Dumping standard error!", + path.display(), + ); + let error = String::from_utf8_lossy(&output.stderr); + Err(BunBunError::CustomProgram(error.to_string())) + } } #[cfg(test)] mod resolve_hop { - use super::*; - use anyhow::Result; + use super::*; + use anyhow::Result; - fn generate_route_result<'a>( - keyword: &'a Route, - args: &str, - ) -> RouteResolution<'a> { - RouteResolution::Resolved { - route: keyword, - args: String::from(args), + fn generate_route_result<'a>(keyword: &'a Route, args: &str) -> RouteResolution<'a> { + RouteResolution::Resolved { + route: keyword, + args: String::from(args), + } } - } - #[test] - fn empty_routes_no_default_yields_failed_hop() { - assert_eq!( - resolve_hop("hello world", &HashMap::new(), &None), - RouteResolution::Unresolved - ); - } + #[test] + fn empty_routes_no_default_yields_failed_hop() { + assert_eq!( + resolve_hop("hello world", &HashMap::new(), &None), + RouteResolution::Unresolved + ); + } - #[test] - fn empty_routes_some_default_yields_failed_hop() { - assert_eq!( - resolve_hop( - "hello world", - &HashMap::new(), - &Some(String::from("google")) - ), - RouteResolution::Unresolved - ); - } + #[test] + fn empty_routes_some_default_yields_failed_hop() { + assert_eq!( + resolve_hop( + "hello world", + &HashMap::new(), + &Some(String::from("google")) + ), + RouteResolution::Unresolved + ); + } - #[test] - fn only_default_routes_some_default_yields_default_hop() -> Result<()> { - let mut map: HashMap = HashMap::new(); - map.insert("google".into(), Route::from("https://example.com")); - assert_eq!( - resolve_hop("hello world", &map, &Some(String::from("google"))), - generate_route_result(&Route::from("https://example.com"), "hello world"), - ); - Ok(()) - } + #[test] + fn only_default_routes_some_default_yields_default_hop() -> Result<()> { + let mut map: HashMap = HashMap::new(); + map.insert("google".into(), Route::from("https://example.com")); + assert_eq!( + resolve_hop("hello world", &map, &Some(String::from("google"))), + generate_route_result(&Route::from("https://example.com"), "hello world"), + ); + Ok(()) + } - #[test] - fn non_default_routes_some_default_yields_non_default_hop() -> Result<()> { - let mut map: HashMap = HashMap::new(); - map.insert("google".into(), Route::from("https://example.com")); - assert_eq!( - resolve_hop("google hello world", &map, &Some(String::from("a"))), - generate_route_result(&Route::from("https://example.com"), "hello world"), - ); - Ok(()) - } + #[test] + fn non_default_routes_some_default_yields_non_default_hop() -> Result<()> { + let mut map: HashMap = HashMap::new(); + map.insert("google".into(), Route::from("https://example.com")); + assert_eq!( + resolve_hop("google hello world", &map, &Some(String::from("a"))), + generate_route_result(&Route::from("https://example.com"), "hello world"), + ); + Ok(()) + } - #[test] - fn non_default_routes_no_default_yields_non_default_hop() -> Result<()> { - let mut map: HashMap = HashMap::new(); - map.insert("google".into(), Route::from("https://example.com")); - assert_eq!( - resolve_hop("google hello world", &map, &None), - generate_route_result(&Route::from("https://example.com"), "hello world"), - ); - Ok(()) - } + #[test] + fn non_default_routes_no_default_yields_non_default_hop() -> Result<()> { + let mut map: HashMap = HashMap::new(); + map.insert("google".into(), Route::from("https://example.com")); + assert_eq!( + resolve_hop("google hello world", &map, &None), + generate_route_result(&Route::from("https://example.com"), "hello world"), + ); + Ok(()) + } } #[cfg(test)] mod check_route { - use super::*; + use super::*; - fn create_route( - min_args: impl Into>, - max_args: impl Into>, - ) -> Route { - Route { - description: None, - hidden: false, - max_args: max_args.into(), - min_args: min_args.into(), - path: String::new(), - route_type: RouteType::External, + fn create_route( + min_args: impl Into>, + max_args: impl Into>, + ) -> Route { + Route { + description: None, + hidden: false, + max_args: max_args.into(), + min_args: min_args.into(), + path: String::new(), + route_type: RouteType::External, + } } - } - #[test] - fn no_min_arg_no_max_arg_counts() { - assert!(check_route(&create_route(None, None), 0)); - assert!(check_route(&create_route(None, None), usize::MAX)); - } + #[test] + fn no_min_arg_no_max_arg_counts() { + assert!(check_route(&create_route(None, None), 0)); + assert!(check_route(&create_route(None, None), usize::MAX)); + } - #[test] - fn min_arg_no_max_arg_counts() { - assert!(!check_route(&create_route(3, None), 0)); - assert!(!check_route(&create_route(3, None), 2)); - assert!(check_route(&create_route(3, None), 3)); - assert!(check_route(&create_route(3, None), 4)); - assert!(check_route(&create_route(3, None), usize::MAX)); - } + #[test] + fn min_arg_no_max_arg_counts() { + assert!(!check_route(&create_route(3, None), 0)); + assert!(!check_route(&create_route(3, None), 2)); + assert!(check_route(&create_route(3, None), 3)); + assert!(check_route(&create_route(3, None), 4)); + assert!(check_route(&create_route(3, None), usize::MAX)); + } - #[test] - fn no_min_arg_max_arg_counts() { - assert!(check_route(&create_route(None, 3), 0)); - assert!(check_route(&create_route(None, 3), 2)); - assert!(check_route(&create_route(None, 3), 3)); - assert!(!check_route(&create_route(None, 3), 4)); - assert!(!check_route(&create_route(None, 3), usize::MAX)); - } + #[test] + fn no_min_arg_max_arg_counts() { + assert!(check_route(&create_route(None, 3), 0)); + assert!(check_route(&create_route(None, 3), 2)); + assert!(check_route(&create_route(None, 3), 3)); + assert!(!check_route(&create_route(None, 3), 4)); + assert!(!check_route(&create_route(None, 3), usize::MAX)); + } - #[test] - fn min_arg_max_arg_counts() { - assert!(!check_route(&create_route(2, 3), 1)); - assert!(check_route(&create_route(2, 3), 2)); - assert!(check_route(&create_route(2, 3), 3)); - assert!(!check_route(&create_route(2, 3), 4)); - } + #[test] + fn min_arg_max_arg_counts() { + assert!(!check_route(&create_route(2, 3), 1)); + assert!(check_route(&create_route(2, 3), 2)); + assert!(check_route(&create_route(2, 3), 3)); + assert!(!check_route(&create_route(2, 3), 4)); + } } #[cfg(test)] mod resolve_path { - use crate::error::BunBunError; + use crate::error::BunBunError; - use super::{resolve_path, HopAction}; - use anyhow::Result; - use std::env::current_dir; - use std::io::ErrorKind; - use std::path::{Path, PathBuf}; + use super::{resolve_path, HopAction}; + use anyhow::Result; + use std::env::current_dir; + use std::io::ErrorKind; + use std::path::{Path, PathBuf}; - #[test] - fn invalid_path_returns_err() { - assert!(resolve_path(&Path::new("/bin/aaaa"), "aaaa").is_err()); - } + #[test] + fn invalid_path_returns_err() { + assert!(resolve_path(&Path::new("/bin/aaaa"), "aaaa").is_err()); + } - #[test] - fn valid_path_returns_ok() { - assert!(resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#).is_ok()); - } + #[test] + fn valid_path_returns_ok() { + assert!(resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#).is_ok()); + } - #[test] - fn relative_path_returns_ok() -> Result<()> { - // How many ".." needed to get to / - let nest_level = current_dir()?.ancestors().count() - 1; - let mut rel_path = PathBuf::from("../".repeat(nest_level)); - rel_path.push("./bin/echo"); - assert!(resolve_path(&rel_path, r#"{"body": "a"}"#).is_ok()); - Ok(()) - } + #[test] + fn relative_path_returns_ok() -> Result<()> { + // How many ".." needed to get to / + let nest_level = current_dir()?.ancestors().count() - 1; + let mut rel_path = PathBuf::from("../".repeat(nest_level)); + rel_path.push("./bin/echo"); + assert!(resolve_path(&rel_path, r#"{"body": "a"}"#).is_ok()); + Ok(()) + } - #[test] - fn no_permissions_returns_err() { - let result = match resolve_path(&Path::new("/root/some_exec"), "") { - Err(BunBunError::Io(e)) => e.kind() == ErrorKind::PermissionDenied, - _ => false, - }; - assert!(result); - } + #[test] + fn no_permissions_returns_err() { + let result = match resolve_path(&Path::new("/root/some_exec"), "") { + Err(BunBunError::Io(e)) => e.kind() == ErrorKind::PermissionDenied, + _ => false, + }; + assert!(result); + } - #[test] - fn non_success_exit_code_yields_err() { - // cat-ing a folder always returns exit code 1 - assert!(resolve_path(&Path::new("/bin/cat"), "/").is_err()); - } + #[test] + fn non_success_exit_code_yields_err() { + // cat-ing a folder always returns exit code 1 + assert!(resolve_path(&Path::new("/bin/cat"), "/").is_err()); + } - #[test] - fn return_body() -> Result<()> { - assert_eq!( - resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#)?, - HopAction::Body("a".to_string()) - ); + #[test] + fn return_body() -> Result<()> { + assert_eq!( + resolve_path(&Path::new("/bin/echo"), r#"{"body": "a"}"#)?, + HopAction::Body("a".to_string()) + ); - Ok(()) - } + Ok(()) + } - #[test] - fn return_redirect() -> Result<()> { - assert_eq!( - resolve_path(&Path::new("/bin/echo"), r#"{"redirect": "a"}"#)?, - HopAction::Redirect("a".to_string()) - ); - Ok(()) - } + #[test] + fn return_redirect() -> Result<()> { + assert_eq!( + resolve_path(&Path::new("/bin/echo"), r#"{"redirect": "a"}"#)?, + HopAction::Redirect("a".to_string()) + ); + Ok(()) + } } diff --git a/src/template_args.rs b/src/template_args.rs index ae85c7e..da8321b 100644 --- a/src/template_args.rs +++ b/src/template_args.rs @@ -4,19 +4,19 @@ use percent_encoding::PercentEncode; use serde::Serialize; pub fn query(query: PercentEncode<'_>) -> impl Serialize + '_ { - #[derive(Serialize)] - struct TemplateArgs<'a> { - query: Cow<'a, str>, - } - TemplateArgs { - query: query.into(), - } + #[derive(Serialize)] + struct TemplateArgs<'a> { + query: Cow<'a, str>, + } + TemplateArgs { + query: query.into(), + } } pub fn hostname(hostname: &'_ str) -> impl Serialize + '_ { - #[derive(Serialize)] - pub struct TemplateArgs<'a> { - pub hostname: &'a str, - } - TemplateArgs { hostname } + #[derive(Serialize)] + pub struct TemplateArgs<'a> { + pub hostname: &'a str, + } + TemplateArgs { hostname } }