use enum for routes

This commit is contained in:
Edward Shen 2019-12-31 17:36:21 -05:00
parent a4543c48ec
commit 1385045013
Signed by: edward
GPG key ID: F350507060ED6C90
3 changed files with 178 additions and 43 deletions

View file

@ -1,4 +1,4 @@
use crate::BunBunError; use crate::{routes::Route, BunBunError};
use log::{debug, error, info, trace}; use log::{debug, error, info, trace};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
@ -19,7 +19,7 @@ pub struct Config {
pub struct RouteGroup { pub struct RouteGroup {
pub name: String, pub name: String,
pub description: Option<String>, pub description: Option<String>,
pub routes: HashMap<String, String>, pub routes: HashMap<String, Route>,
} }
// TODO implement rlua: // TODO implement rlua:
@ -34,13 +34,6 @@ pub struct RouteGroup {
// # Ok(()) // # Ok(())
// # } // # }
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
struct Route {
dest: Option<String>,
source: Option<String>,
script: Option<String>,
}
/// Attempts to read the config file. If it doesn't exist, generate one a /// Attempts to read the config file. If it doesn't exist, generate one a
/// default config file before attempting to parse it. /// default config file before attempting to parse it.
pub fn read_config(config_file_path: &str) -> Result<Config, BunBunError> { pub fn read_config(config_file_path: &str) -> Result<Config, BunBunError> {

View file

@ -24,7 +24,7 @@ pub struct State {
default_route: Option<String>, default_route: Option<String>,
groups: Vec<RouteGroup>, groups: Vec<RouteGroup>,
/// Cached, flattened mapping of all routes and their destinations. /// Cached, flattened mapping of all routes and their destinations.
routes: HashMap<String, String>, routes: HashMap<String, routes::Route>,
} }
#[actix_rt::main] #[actix_rt::main]
@ -97,7 +97,7 @@ fn init_logger(
/// Generates a hashmap of routes from the data structure created by the config /// Generates a hashmap of routes from the data structure created by the config
/// file. This should improve runtime performance and is a better solution than /// file. This should improve runtime performance and is a better solution than
/// just iterating over the config object for every hop resolution. /// just iterating over the config object for every hop resolution.
fn cache_routes(groups: &[RouteGroup]) -> HashMap<String, String> { fn cache_routes(groups: &[RouteGroup]) -> HashMap<String, routes::Route> {
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for group in groups { for group in groups {
for (kw, dest) in &group.routes { for (kw, dest) in &group.routes {
@ -216,11 +216,13 @@ mod cache_routes {
use super::*; use super::*;
use std::iter::FromIterator; use std::iter::FromIterator;
fn generate_routes(routes: &[(&str, &str)]) -> HashMap<String, String> { fn generate_external_routes(
routes: &[(&str, &str)],
) -> HashMap<String, routes::Route> {
HashMap::from_iter( HashMap::from_iter(
routes routes
.iter() .iter()
.map(|(k, v)| (String::from(*k), String::from(*v))), .map(|(k, v)| ((*k).into(), routes::Route::External((*v).into()))),
) )
} }
@ -234,18 +236,23 @@ mod cache_routes {
let group1 = RouteGroup { let group1 = RouteGroup {
name: String::from("x"), name: String::from("x"),
description: Some(String::from("y")), description: Some(String::from("y")),
routes: generate_routes(&[("a", "b"), ("c", "d")]), routes: generate_external_routes(&[("a", "b"), ("c", "d")]),
}; };
let group2 = RouteGroup { let group2 = RouteGroup {
name: String::from("5"), name: String::from("5"),
description: Some(String::from("6")), description: Some(String::from("6")),
routes: generate_routes(&[("1", "2"), ("3", "4")]), routes: generate_external_routes(&[("1", "2"), ("3", "4")]),
}; };
assert_eq!( assert_eq!(
cache_routes(&[group1, group2]), cache_routes(&[group1, group2]),
generate_routes(&[("a", "b"), ("c", "d"), ("1", "2"), ("3", "4")]) generate_external_routes(&[
("a", "b"),
("c", "d"),
("1", "2"),
("3", "4")
])
); );
} }
@ -254,29 +261,29 @@ mod cache_routes {
let group1 = RouteGroup { let group1 = RouteGroup {
name: String::from("x"), name: String::from("x"),
description: Some(String::from("y")), description: Some(String::from("y")),
routes: generate_routes(&[("a", "b"), ("c", "d")]), routes: generate_external_routes(&[("a", "b"), ("c", "d")]),
}; };
let group2 = RouteGroup { let group2 = RouteGroup {
name: String::from("5"), name: String::from("5"),
description: Some(String::from("6")), description: Some(String::from("6")),
routes: generate_routes(&[("a", "1"), ("c", "2")]), routes: generate_external_routes(&[("a", "1"), ("c", "2")]),
}; };
assert_eq!( assert_eq!(
cache_routes(&[group1.clone(), group2]), cache_routes(&[group1.clone(), group2]),
generate_routes(&[("a", "1"), ("c", "2")]) generate_external_routes(&[("a", "1"), ("c", "2")])
); );
let group3 = RouteGroup { let group3 = RouteGroup {
name: String::from("5"), name: String::from("5"),
description: Some(String::from("6")), description: Some(String::from("6")),
routes: generate_routes(&[("a", "1"), ("b", "2")]), routes: generate_external_routes(&[("a", "1"), ("b", "2")]),
}; };
assert_eq!( assert_eq!(
cache_routes(&[group1, group3]), cache_routes(&[group1, group3]),
generate_routes(&[("a", "1"), ("b", "2"), ("c", "d")]) generate_external_routes(&[("a", "1"), ("b", "2"), ("c", "d")])
); );
} }
} }

View file

@ -8,12 +8,11 @@ use handlebars::Handlebars;
use itertools::Itertools; use itertools::Itertools;
use log::debug; use log::debug;
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
use serde::Deserialize; use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
type StateData = Data<Arc<RwLock<State>>>;
/// https://url.spec.whatwg.org/#fragment-percent-encode-set /// https://url.spec.whatwg.org/#fragment-percent-encode-set
const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
.add(b' ') .add(b' ')
@ -23,6 +22,67 @@ const FRAGMENT_ENCODE_SET: &AsciiSet = &CONTROLS
.add(b'`') .add(b'`')
.add(b'+'); .add(b'+');
type StateData = Data<Arc<RwLock<State>>>;
#[derive(Debug, PartialEq, Clone)]
pub enum Route {
External(String),
Path(String),
}
impl Serialize for Route {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Self::External(s) => serializer.serialize_str(s),
Self::Path(s) => serializer.serialize_str(s),
}
}
}
impl<'de> Deserialize<'de> for Route {
fn deserialize<D>(deserializer: D) -> Result<Route, D::Error>
where
D: Deserializer<'de>,
{
struct RouteVisitor;
impl<'de> Visitor<'de> for RouteVisitor {
type Value = Route;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
// Return early if it's a path, don't go through URL parsing
if std::path::Path::new(value).exists() {
debug!("Parsed {} as a valid local path.", value);
Ok(Route::Path(value.into()))
} else {
debug!("{} does not exist on disk, assuming web path.", value);
Ok(Route::External(value.into()))
}
}
}
deserializer.deserialize_str(RouteVisitor)
}
}
impl std::fmt::Display for Route {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::External(s) => write!(f, "raw path ({})", s),
Self::Path(s) => write!(f, "file path ({})", s),
}
}
}
#[get("/ls")] #[get("/ls")]
pub async fn list( pub async fn list(
data: Data<Arc<RwLock<State>>>, data: Data<Arc<RwLock<State>>>,
@ -59,7 +119,10 @@ pub async fn hop(
.app_data::<Handlebars>() .app_data::<Handlebars>()
.unwrap() .unwrap()
.render_template( .render_template(
&path, match path {
Route::Path(s) => s, // TODO: try resolve path
Route::External(s) => s,
},
&template_args::query( &template_args::query(
utf8_percent_encode(&args, FRAGMENT_ENCODE_SET).to_string(), utf8_percent_encode(&args, FRAGMENT_ENCODE_SET).to_string(),
), ),
@ -77,11 +140,11 @@ pub async fn hop(
/// ///
/// The first element in the tuple describes the route, while the second element /// The first element in the tuple describes the route, while the second element
/// returns the remaining arguments. If none remain, an empty string is given. /// returns the remaining arguments. If none remain, an empty string is given.
fn resolve_hop( fn resolve_hop<'a>(
query: &str, query: &str,
routes: &HashMap<String, String>, routes: &'a HashMap<String, Route>,
default_route: &Option<String>, default_route: &Option<String>,
) -> (Option<String>, String) { ) -> (Option<&'a Route>, String) {
let mut split_args = query.split_ascii_whitespace().peekable(); let mut split_args = query.split_ascii_whitespace().peekable();
let command = match split_args.peek() { let command = match split_args.peek() {
Some(command) => command, Some(command) => command,
@ -94,7 +157,7 @@ fn resolve_hop(
match (routes.get(*command), default_route) { match (routes.get(*command), default_route) {
// Found a route // Found a route
(Some(resolved), _) => ( (Some(resolved), _) => (
Some(resolved.clone()), Some(resolved),
match split_args.next() { match split_args.next() {
// Discard the first result, we found the route using the first arg // Discard the first result, we found the route using the first arg
Some(_) => { Some(_) => {
@ -113,7 +176,7 @@ fn resolve_hop(
let args = split_args.join(" "); let args = split_args.join(" ");
debug!("Using default route {} with args {}", route, args); debug!("Using default route {} with args {}", route, args);
match routes.get(route) { match routes.get(route) {
Some(v) => (Some(v.to_owned()), args), Some(v) => (Some(v), args),
None => (None, String::new()), None => (None, String::new()),
} }
} }
@ -160,15 +223,69 @@ pub async fn opensearch(data: StateData, req: HttpRequest) -> impl Responder {
) )
} }
#[cfg(test)]
mod route {
use super::*;
use serde_yaml::{from_str, to_string};
use tempfile::NamedTempFile;
#[test]
fn deserialize_relative_path() {
let tmpfile = NamedTempFile::new_in(".").unwrap();
let path = format!("{}", tmpfile.path().display());
let path = path.get(path.rfind(".").unwrap()..).unwrap();
let path = std::path::Path::new(path);
assert!(path.is_relative());
let path = path.to_str().unwrap();
assert_eq!(from_str::<Route>(path).unwrap(), Route::Path(path.into()));
}
#[test]
fn deserialize_absolute_path() {
let tmpfile = NamedTempFile::new().unwrap();
let path = format!("{}", tmpfile.path().display());
assert!(tmpfile.path().is_absolute());
assert_eq!(from_str::<Route>(&path).unwrap(), Route::Path(path));
}
#[test]
fn deserialize_http_path() {
assert_eq!(
from_str::<Route>("http://google.com").unwrap(),
Route::External("http://google.com".into())
);
}
#[test]
fn deserialize_https_path() {
assert_eq!(
from_str::<Route>("https://google.com").unwrap(),
Route::External("https://google.com".into())
);
}
#[test]
fn serialize() {
assert_eq!(
&to_string(&Route::External("hello world".into())).unwrap(),
"---\nhello world"
);
assert_eq!(
&to_string(&Route::Path("hello world".into())).unwrap(),
"---\nhello world"
);
}
}
#[cfg(test)] #[cfg(test)]
mod resolve_hop { mod resolve_hop {
use super::*; use super::*;
fn generate_route_result( fn generate_route_result<'a>(
keyword: &str, keyword: &'a Route,
args: &str, args: &str,
) -> (Option<String>, String) { ) -> (Option<&'a Route>, String) {
(Some(String::from(keyword)), String::from(args)) (Some(keyword), String::from(args))
} }
#[test] #[test]
@ -193,31 +310,49 @@ mod resolve_hop {
#[test] #[test]
fn only_default_routes_some_default_yields_default_hop() { fn only_default_routes_some_default_yields_default_hop() {
let mut map = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert(String::from("google"), String::from("https://example.com")); map.insert(
"google".into(),
Route::External("https://example.com".into()),
);
assert_eq!( assert_eq!(
resolve_hop("hello world", &map, &Some(String::from("google"))), resolve_hop("hello world", &map, &Some(String::from("google"))),
generate_route_result("https://example.com", "hello world"), generate_route_result(
&Route::External("https://example.com".into()),
"hello world"
),
); );
} }
#[test] #[test]
fn non_default_routes_some_default_yields_non_default_hop() { fn non_default_routes_some_default_yields_non_default_hop() {
let mut map = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert(String::from("google"), String::from("https://example.com")); map.insert(
"google".into(),
Route::External("https://example.com".into()),
);
assert_eq!( assert_eq!(
resolve_hop("google hello world", &map, &Some(String::from("a"))), resolve_hop("google hello world", &map, &Some(String::from("a"))),
generate_route_result("https://example.com", "hello world"), generate_route_result(
&Route::External("https://example.com".into()),
"hello world"
),
); );
} }
#[test] #[test]
fn non_default_routes_no_default_yields_non_default_hop() { fn non_default_routes_no_default_yields_non_default_hop() {
let mut map = HashMap::new(); let mut map: HashMap<String, Route> = HashMap::new();
map.insert(String::from("google"), String::from("https://example.com")); map.insert(
"google".into(),
Route::External("https://example.com".into()),
);
assert_eq!( assert_eq!(
resolve_hop("google hello world", &map, &None), resolve_hop("google hello world", &map, &None),
generate_route_result("https://example.com", "hello world"), generate_route_result(
&Route::External("https://example.com".into()),
"hello world"
),
); );
} }
} }