diff --git a/src/config.rs b/src/config.rs index ee959f5..87b8aca 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use crate::BunBunError; +use crate::{routes::Route, BunBunError}; use log::{debug, error, info, trace}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -19,7 +19,7 @@ pub struct Config { pub struct RouteGroup { pub name: String, pub description: Option, - pub routes: HashMap, + pub routes: HashMap, } // TODO implement rlua: @@ -34,13 +34,6 @@ pub struct RouteGroup { // # Ok(()) // # } -#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] -struct Route { - dest: Option, - source: Option, - script: Option, -} - /// Attempts to read the config file. If it doesn't exist, generate one a /// default config file before attempting to parse it. pub fn read_config(config_file_path: &str) -> Result { diff --git a/src/main.rs b/src/main.rs index 047f3bc..459fc30 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,7 +24,7 @@ pub struct State { default_route: Option, groups: Vec, /// Cached, flattened mapping of all routes and their destinations. - routes: HashMap, + routes: HashMap, } #[actix_rt::main] @@ -97,7 +97,7 @@ fn init_logger( /// Generates a hashmap of routes from the data structure created by the config /// file. This should improve runtime performance and is a better solution than /// just iterating over the config object for every hop resolution. -fn cache_routes(groups: &[RouteGroup]) -> HashMap { +fn cache_routes(groups: &[RouteGroup]) -> HashMap { let mut mapping = HashMap::new(); for group in groups { for (kw, dest) in &group.routes { @@ -216,11 +216,13 @@ mod cache_routes { use super::*; use std::iter::FromIterator; - fn generate_routes(routes: &[(&str, &str)]) -> HashMap { + fn generate_external_routes( + routes: &[(&str, &str)], + ) -> HashMap { HashMap::from_iter( routes .iter() - .map(|(k, v)| (String::from(*k), String::from(*v))), + .map(|(k, v)| ((*k).into(), routes::Route::External((*v).into()))), ) } @@ -234,18 +236,23 @@ mod cache_routes { let group1 = RouteGroup { name: String::from("x"), description: Some(String::from("y")), - routes: generate_routes(&[("a", "b"), ("c", "d")]), + routes: generate_external_routes(&[("a", "b"), ("c", "d")]), }; let group2 = RouteGroup { name: String::from("5"), description: Some(String::from("6")), - routes: generate_routes(&[("1", "2"), ("3", "4")]), + routes: generate_external_routes(&[("1", "2"), ("3", "4")]), }; assert_eq!( cache_routes(&[group1, group2]), - generate_routes(&[("a", "b"), ("c", "d"), ("1", "2"), ("3", "4")]) + generate_external_routes(&[ + ("a", "b"), + ("c", "d"), + ("1", "2"), + ("3", "4") + ]) ); } @@ -254,29 +261,29 @@ mod cache_routes { let group1 = RouteGroup { name: String::from("x"), description: Some(String::from("y")), - routes: generate_routes(&[("a", "b"), ("c", "d")]), + routes: generate_external_routes(&[("a", "b"), ("c", "d")]), }; let group2 = RouteGroup { name: String::from("5"), description: Some(String::from("6")), - routes: generate_routes(&[("a", "1"), ("c", "2")]), + routes: generate_external_routes(&[("a", "1"), ("c", "2")]), }; assert_eq!( cache_routes(&[group1.clone(), group2]), - generate_routes(&[("a", "1"), ("c", "2")]) + generate_external_routes(&[("a", "1"), ("c", "2")]) ); let group3 = RouteGroup { name: String::from("5"), description: Some(String::from("6")), - routes: generate_routes(&[("a", "1"), ("b", "2")]), + routes: generate_external_routes(&[("a", "1"), ("b", "2")]), }; assert_eq!( cache_routes(&[group1, group3]), - generate_routes(&[("a", "1"), ("b", "2"), ("c", "d")]) + generate_external_routes(&[("a", "1"), ("b", "2"), ("c", "d")]) ); } } diff --git a/src/routes.rs b/src/routes.rs index 06b20ab..7158b18 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -8,12 +8,11 @@ use handlebars::Handlebars; use itertools::Itertools; use log::debug; use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; -use serde::Deserialize; +use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; use std::collections::HashMap; +use std::fmt; use std::sync::{Arc, RwLock}; -type StateData = Data>>; - /// https://url.spec.whatwg.org/#fragment-percent-encode-set const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS .add(b' ') @@ -23,6 +22,67 @@ const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS .add(b'`') .add(b'+'); +type StateData = Data>>; + +#[derive(Debug, PartialEq, Clone)] +pub enum Route { + External(String), + Path(String), +} + +impl Serialize for Route { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Self::External(s) => serializer.serialize_str(s), + Self::Path(s) => serializer.serialize_str(s), + } + } +} + +impl<'de> Deserialize<'de> for Route { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct RouteVisitor; + impl<'de> Visitor<'de> for RouteVisitor { + type Value = Route; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // Return early if it's a path, don't go through URL parsing + if std::path::Path::new(value).exists() { + debug!("Parsed {} as a valid local path.", value); + Ok(Route::Path(value.into())) + } else { + debug!("{} does not exist on disk, assuming web path.", value); + Ok(Route::External(value.into())) + } + } + } + + deserializer.deserialize_str(RouteVisitor) + } +} + +impl std::fmt::Display for Route { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::External(s) => write!(f, "raw path ({})", s), + Self::Path(s) => write!(f, "file path ({})", s), + } + } +} + #[get("/ls")] pub async fn list( data: Data>>, @@ -59,7 +119,10 @@ pub async fn hop( .app_data::() .unwrap() .render_template( - &path, + match path { + Route::Path(s) => s, // TODO: try resolve path + Route::External(s) => s, + }, &template_args::query( utf8_percent_encode(&args, FRAGMENT_ENCODE_SET).to_string(), ), @@ -77,11 +140,11 @@ pub async fn hop( /// /// The first element in the tuple describes the route, while the second element /// returns the remaining arguments. If none remain, an empty string is given. -fn resolve_hop( +fn resolve_hop<'a>( query: &str, - routes: &HashMap, + routes: &'a HashMap, default_route: &Option, -) -> (Option, String) { +) -> (Option<&'a Route>, String) { let mut split_args = query.split_ascii_whitespace().peekable(); let command = match split_args.peek() { Some(command) => command, @@ -94,7 +157,7 @@ fn resolve_hop( match (routes.get(*command), default_route) { // Found a route (Some(resolved), _) => ( - Some(resolved.clone()), + Some(resolved), match split_args.next() { // Discard the first result, we found the route using the first arg Some(_) => { @@ -113,7 +176,7 @@ fn resolve_hop( let args = split_args.join(" "); debug!("Using default route {} with args {}", route, args); match routes.get(route) { - Some(v) => (Some(v.to_owned()), args), + Some(v) => (Some(v), args), None => (None, String::new()), } } @@ -160,15 +223,69 @@ pub async fn opensearch(data: StateData, req: HttpRequest) -> impl Responder { ) } +#[cfg(test)] +mod route { + use super::*; + use serde_yaml::{from_str, to_string}; + use tempfile::NamedTempFile; + + #[test] + fn deserialize_relative_path() { + let tmpfile = NamedTempFile::new_in(".").unwrap(); + let path = format!("{}", tmpfile.path().display()); + let path = path.get(path.rfind(".").unwrap()..).unwrap(); + let path = std::path::Path::new(path); + assert!(path.is_relative()); + let path = path.to_str().unwrap(); + assert_eq!(from_str::(path).unwrap(), Route::Path(path.into())); + } + + #[test] + fn deserialize_absolute_path() { + let tmpfile = NamedTempFile::new().unwrap(); + let path = format!("{}", tmpfile.path().display()); + assert!(tmpfile.path().is_absolute()); + assert_eq!(from_str::(&path).unwrap(), Route::Path(path)); + } + + #[test] + fn deserialize_http_path() { + assert_eq!( + from_str::("http://google.com").unwrap(), + Route::External("http://google.com".into()) + ); + } + + #[test] + fn deserialize_https_path() { + assert_eq!( + from_str::("https://google.com").unwrap(), + Route::External("https://google.com".into()) + ); + } + + #[test] + fn serialize() { + assert_eq!( + &to_string(&Route::External("hello world".into())).unwrap(), + "---\nhello world" + ); + assert_eq!( + &to_string(&Route::Path("hello world".into())).unwrap(), + "---\nhello world" + ); + } +} + #[cfg(test)] mod resolve_hop { use super::*; - fn generate_route_result( - keyword: &str, + fn generate_route_result<'a>( + keyword: &'a Route, args: &str, - ) -> (Option, String) { - (Some(String::from(keyword)), String::from(args)) + ) -> (Option<&'a Route>, String) { + (Some(keyword), String::from(args)) } #[test] @@ -193,31 +310,49 @@ mod resolve_hop { #[test] fn only_default_routes_some_default_yields_default_hop() { - let mut map = HashMap::new(); - map.insert(String::from("google"), String::from("https://example.com")); + let mut map: HashMap = HashMap::new(); + map.insert( + "google".into(), + Route::External("https://example.com".into()), + ); assert_eq!( resolve_hop("hello world", &map, &Some(String::from("google"))), - generate_route_result("https://example.com", "hello world"), + generate_route_result( + &Route::External("https://example.com".into()), + "hello world" + ), ); } #[test] fn non_default_routes_some_default_yields_non_default_hop() { - let mut map = HashMap::new(); - map.insert(String::from("google"), String::from("https://example.com")); + let mut map: HashMap = HashMap::new(); + map.insert( + "google".into(), + Route::External("https://example.com".into()), + ); assert_eq!( resolve_hop("google hello world", &map, &Some(String::from("a"))), - generate_route_result("https://example.com", "hello world"), + generate_route_result( + &Route::External("https://example.com".into()), + "hello world" + ), ); } #[test] fn non_default_routes_no_default_yields_non_default_hop() { - let mut map = HashMap::new(); - map.insert(String::from("google"), String::from("https://example.com")); + let mut map: HashMap = HashMap::new(); + map.insert( + "google".into(), + Route::External("https://example.com".into()), + ); assert_eq!( resolve_hop("google hello world", &map, &None), - generate_route_result("https://example.com", "hello world"), + generate_route_result( + &Route::External("https://example.com".into()), + "hello world" + ), ); } }