bunbun/src/main.rs

411 lines
11 KiB
Rust
Raw Normal View History

2019-12-29 05:09:18 +00:00
#![forbid(unsafe_code)]
2019-12-27 05:02:12 +00:00
use actix_web::{middleware::Logger, App, HttpServer};
2019-12-24 00:33:09 +00:00
use clap::{crate_authors, crate_version, load_yaml, App as ClapApp};
2019-12-26 22:01:45 +00:00
use error::BunBunError;
2019-12-15 16:07:36 +00:00
use handlebars::Handlebars;
2019-12-21 19:04:13 +00:00
use hotwatch::{Event, Hotwatch};
2019-12-23 19:09:49 +00:00
use log::{debug, error, info, trace, warn};
2019-12-24 05:05:56 +00:00
use serde::{Deserialize, Serialize};
2019-12-24 03:13:38 +00:00
use std::cmp::min;
use std::collections::HashMap;
use std::fs::{read_to_string, OpenOptions};
2019-12-21 20:43:04 +00:00
use std::io::Write;
2019-12-15 17:49:16 +00:00
use std::sync::{Arc, RwLock};
2019-12-21 19:04:13 +00:00
use std::time::Duration;
2019-12-15 16:07:36 +00:00
2019-12-26 22:01:45 +00:00
mod error;
2019-12-23 15:02:21 +00:00
mod routes;
mod template_args;
2019-12-24 03:59:12 +00:00
/// Default config file, embedded into the binary at compile time.
/// `read_config` writes it to disk when the config file is missing and parses
/// it as the fallback config.
const DEFAULT_CONFIG: &[u8] = include_bytes!("../bunbun.default.yaml");
2019-12-15 16:07:36 +00:00
2019-12-23 01:05:01 +00:00
/// Dynamic variables that either need to be present at runtime, or can be
/// changed during runtime.
2019-12-23 15:02:21 +00:00
pub struct State {
2019-12-21 19:21:13 +00:00
public_address: String,
default_route: Option<String>,
2019-12-24 05:05:56 +00:00
groups: Vec<RouteGroup>,
/// Cached, flattened mapping of all routes and their destinations.
routes: HashMap<String, String>,
2019-12-15 17:49:16 +00:00
}
2019-12-26 20:06:00 +00:00
#[actix_rt::main]
async fn main() -> Result<(), BunBunError> {
2019-12-24 00:33:09 +00:00
let yaml = load_yaml!("cli.yaml");
let matches = ClapApp::from(yaml)
.version(crate_version!())
.author(crate_authors!())
.get_matches();
2019-12-23 19:09:49 +00:00
2019-12-24 03:59:12 +00:00
init_logger(
matches.occurrences_of("verbose"),
matches.occurrences_of("quiet"),
)?;
2019-12-24 00:33:09 +00:00
2019-12-24 03:59:12 +00:00
// config has default location provided, unwrapping is fine.
2019-12-24 00:33:09 +00:00
let conf_file_location = String::from(matches.value_of("config").unwrap());
let conf = read_config(&conf_file_location)?;
2019-12-21 19:21:13 +00:00
let state = Arc::from(RwLock::new(State {
public_address: conf.public_address,
default_route: conf.default_route,
2019-12-24 05:05:56 +00:00
routes: cache_routes(&conf.groups),
groups: conf.groups,
2019-12-21 19:21:13 +00:00
}));
2019-12-24 00:33:09 +00:00
2019-12-26 20:49:42 +00:00
let _watch = start_watch(state.clone(), conf_file_location)?;
2019-12-23 19:09:49 +00:00
2019-12-21 19:21:13 +00:00
HttpServer::new(move || {
App::new()
2019-12-26 20:49:42 +00:00
.data(state.clone())
.app_data(compile_templates())
2019-12-23 19:09:49 +00:00
.wrap(Logger::default())
2019-12-23 15:02:21 +00:00
.service(routes::hop)
.service(routes::list)
.service(routes::index)
.service(routes::opensearch)
2019-12-21 19:21:13 +00:00
})
.bind(&conf.bind_address)?
2019-12-26 20:06:00 +00:00
.run()
.await?;
2019-12-21 20:43:04 +00:00
Ok(())
2019-12-15 17:49:16 +00:00
}
2019-12-15 16:07:36 +00:00
2019-12-24 05:05:56 +00:00
/// Initializes the logger based on the number of quiet and verbose flags passed
/// in. Usually, these values are mutually exclusive, that is, if the number of
/// verbose flags is non-zero then the quiet flag is zero, and vice versa.
2019-12-24 03:59:12 +00:00
fn init_logger(
num_verbose_flags: u64,
num_quiet_flags: u64,
) -> Result<(), BunBunError> {
let log_level =
match min(num_verbose_flags, 3) as i8 - min(num_quiet_flags, 2) as i8 {
-2 => None,
-1 => Some(log::Level::Error),
0 => Some(log::Level::Warn),
1 => Some(log::Level::Info),
2 => Some(log::Level::Debug),
3 => Some(log::Level::Trace),
_ => unreachable!(), // values are clamped to [0, 3] - [0, 2]
};
if let Some(level) = log_level {
simple_logger::init_with_level(level)?;
}
Ok(())
}
2019-12-29 05:08:13 +00:00
#[derive(Deserialize, Debug, PartialEq)]
2019-12-23 01:05:01 +00:00
struct Config {
bind_address: String,
public_address: String,
default_route: Option<String>,
2019-12-24 05:05:56 +00:00
groups: Vec<RouteGroup>,
}
2019-12-29 05:08:13 +00:00
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
2019-12-24 05:05:56 +00:00
struct RouteGroup {
name: String,
description: Option<String>,
routes: HashMap<String, String>,
2019-12-23 01:05:01 +00:00
}
/// Attempts to read the config file. If it doesn't exist, generate one a
/// default config file before attempting to parse it.
2019-12-21 20:43:04 +00:00
fn read_config(config_file_path: &str) -> Result<Config, BunBunError> {
2019-12-23 19:09:49 +00:00
trace!("Loading config file...");
let config_str = match read_to_string(config_file_path) {
2019-12-23 19:09:49 +00:00
Ok(conf_str) => {
debug!("Successfully loaded config file into memory.");
conf_str
}
2019-12-21 19:21:13 +00:00
Err(_) => {
2019-12-23 19:09:49 +00:00
info!(
2019-12-21 19:21:13 +00:00
"Unable to find a {} file. Creating default!",
config_file_path
);
2019-12-23 19:09:49 +00:00
let fd = OpenOptions::new()
2019-12-21 19:21:13 +00:00
.write(true)
.create_new(true)
2019-12-23 19:09:49 +00:00
.open(config_file_path);
match fd {
Ok(mut fd) => fd.write_all(DEFAULT_CONFIG)?,
Err(e) => {
error!("Failed to write to {}: {}. Default config will be loaded but not saved.", config_file_path, e);
}
};
String::from_utf8_lossy(DEFAULT_CONFIG).into_owned()
2019-12-21 19:21:13 +00:00
}
};
// Reading from memory is faster than reading directly from a reader for some
// reason; see https://github.com/serde-rs/json/issues/160
Ok(serde_yaml::from_str(&config_str)?)
2019-12-21 19:16:47 +00:00
}
/// Generates a hashmap of routes from the data structure created by the config
/// file. This should improve runtime performance and is a better solution than
/// just iterating over the config object for every hop resolution.
2019-12-24 05:05:56 +00:00
fn cache_routes(groups: &[RouteGroup]) -> HashMap<String, String> {
let mut mapping = HashMap::new();
for group in groups {
for (kw, dest) in &group.routes {
match mapping.insert(kw.clone(), dest.clone()) {
None => trace!("Inserting {} into mapping.", kw),
Some(old_value) => {
debug!("Overriding {} route from {} to {}.", kw, old_value, dest)
}
}
}
}
mapping
}
2019-12-23 01:05:01 +00:00
/// Returns an instance with all pre-generated templates included into the
/// binary. This allows for users to have a portable binary without needed the
/// templates at runtime.
2019-12-15 17:49:16 +00:00
fn compile_templates() -> Handlebars {
2019-12-21 19:21:13 +00:00
let mut handlebars = Handlebars::new();
2019-12-22 04:34:03 +00:00
macro_rules! register_template {
2019-12-23 01:05:01 +00:00
[ $( $template:expr ),* ] => {
2019-12-22 04:34:03 +00:00
$(
handlebars
.register_template_string(
$template,
String::from_utf8_lossy(
include_bytes!(concat!("templates/", $template, ".hbs")))
)
.unwrap();
2019-12-23 19:09:49 +00:00
debug!("Loaded {} template.", $template);
2019-12-22 04:34:03 +00:00
)*
};
}
2019-12-23 01:05:01 +00:00
register_template!["index", "list", "opensearch"];
2019-12-21 19:21:13 +00:00
handlebars
2019-12-15 16:07:36 +00:00
}
2019-12-26 20:49:42 +00:00
/// Starts the watch on a file, if possible. This will only return an Error if
/// the notify library (used by Hotwatch) fails to initialize, which is
/// considered to be a more serve error as it may be indicative of a low-level
/// problem. If a watch was unsuccessfully obtained (the most common is due to
/// the file not existing), then this will simply warn before returning a watch
/// object.
///
/// This watch object should be kept in scope as dropping it releases all
/// watches.
2019-12-26 20:49:42 +00:00
fn start_watch(
state: Arc<RwLock<State>>,
config_file_path: String,
) -> Result<Hotwatch, BunBunError> {
let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?;
// TODO: keep retry watching in separate thread
// Closures need their own copy of variables for proper lifecycle management
let config_file_path_clone = config_file_path.clone();
let watch_result = watch.watch(&config_file_path, move |e: Event| {
if let Event::Write(_) = e {
trace!("Grabbing writer lock on state...");
let mut state = state.write().unwrap();
trace!("Obtained writer lock on state!");
match read_config(&config_file_path_clone) {
Ok(conf) => {
state.public_address = conf.public_address;
state.default_route = conf.default_route;
state.routes = cache_routes(&conf.groups);
state.groups = conf.groups;
info!("Successfully updated active state");
}
Err(e) => warn!("Failed to update config file: {}", e),
}
} else {
debug!("Saw event {:#?} but ignored it", e);
}
});
match watch_result {
Ok(_) => info!("Watcher is now watching {}", &config_file_path),
Err(e) => warn!(
"Couldn't watch {}: {}. Changes to this file won't be seen!",
&config_file_path, e
),
}
2019-12-26 20:49:42 +00:00
Ok(watch)
}
2019-12-29 05:08:13 +00:00
#[cfg(test)]
mod init_logger {
    use super::*;

    // NOTE(review): the `log` crate only lets a global logger be installed
    // once per process. That is presumably why only one test here runs by
    // default and the others are `#[ignore]`d (run them individually).

    #[test]
    fn defaults_to_warn() -> Result<(), BunBunError> {
        // Neither verbose nor quiet flags given.
        init_logger(0, 0)?;
        assert_eq!(log::max_level(), log::LevelFilter::Warn);
        Ok(())
    }

    #[test]
    #[ignore]
    fn caps_to_2_when_log_level_is_lt_2() -> Result<(), BunBunError> {
        // Three quiet flags are clamped to two, which disables logging.
        init_logger(0, 3)?;
        assert_eq!(log::max_level(), log::LevelFilter::Off);
        Ok(())
    }

    #[test]
    #[ignore]
    fn caps_to_3_when_log_level_is_gt_3() -> Result<(), BunBunError> {
        // Four verbose flags are clamped to three, the most verbose level.
        init_logger(4, 0)?;
        assert_eq!(log::max_level(), log::LevelFilter::Trace);
        Ok(())
    }
}
#[cfg(test)]
mod read_config {
    use super::*;
    use tempfile::{tempdir, NamedTempFile};

    #[test]
    fn returns_default_config_if_path_does_not_exist(
    ) -> Result<(), std::io::Error> {
        // Use a missing path inside a private temporary directory rather than
        // a fixed path under `/`. `read_config` writes the default config to
        // a missing path, and with write access to `/` (e.g. running as root
        // in a container) a fixed path would leave a file behind that changes
        // the outcome of later runs. The directory is removed on drop.
        let dir = tempdir()?;
        let missing = dir.path().join("bunbun.yaml");
        assert_eq!(
            read_config(missing.to_str().unwrap()).unwrap(),
            serde_yaml::from_slice(DEFAULT_CONFIG).unwrap()
        );
        // The default config should also have been saved to the missing path.
        assert_eq!(std::fs::read(&missing)?, DEFAULT_CONFIG);
        Ok(())
    }

    #[test]
    fn returns_error_if_given_empty_config() {
        assert_eq!(
            read_config("/dev/null").unwrap_err().to_string(),
            "EOF while parsing a value"
        );
    }

    #[test]
    fn returns_error_if_given_invalid_config() -> Result<(), std::io::Error> {
        let mut tmp_file = NamedTempFile::new()?;
        tmp_file.write_all(b"g")?;
        assert_eq!(
            read_config(tmp_file.path().to_str().unwrap())
                .unwrap_err()
                .to_string(),
            r#"invalid type: string "g", expected struct Config at line 1 column 1"#
        );
        Ok(())
    }

    #[test]
    fn returns_error_if_config_missing_field() -> Result<(), std::io::Error> {
        let mut tmp_file = NamedTempFile::new()?;
        tmp_file.write_all(
            br#"
bind_address: "localhost"
public_address: "localhost"
"#,
        )?;
        assert_eq!(
            read_config(tmp_file.path().to_str().unwrap())
                .unwrap_err()
                .to_string(),
            "missing field `groups` at line 2 column 19"
        );
        Ok(())
    }

    #[test]
    fn returns_ok_if_valid_config() -> Result<(), std::io::Error> {
        let mut tmp_file = NamedTempFile::new()?;
        tmp_file.write_all(
            br#"
bind_address: "a"
public_address: "b"
groups: []"#,
        )?;
        assert_eq!(
            read_config(tmp_file.path().to_str().unwrap()).unwrap(),
            Config {
                bind_address: String::from("a"),
                public_address: String::from("b"),
                groups: vec![],
                default_route: None,
            }
        );
        Ok(())
    }
}
#[cfg(test)]
mod cache_routes {
    use super::*;

    /// Builds an owned route map from borrowed `(keyword, destination)` pairs.
    fn generate_routes(routes: &[(&str, &str)]) -> HashMap<String, String> {
        routes
            .iter()
            .map(|&(keyword, destination)| {
                (keyword.to_owned(), destination.to_owned())
            })
            .collect()
    }

    /// Builds a described route group from borrowed route pairs.
    fn group(
        name: &str,
        description: &str,
        routes: &[(&str, &str)],
    ) -> RouteGroup {
        RouteGroup {
            name: String::from(name),
            description: Some(String::from(description)),
            routes: generate_routes(routes),
        }
    }

    #[test]
    fn empty_groups_yield_empty_routes() {
        assert!(cache_routes(&[]).is_empty());
    }

    #[test]
    fn disjoint_groups_yield_summed_routes() {
        let first = group("x", "y", &[("a", "b"), ("c", "d")]);
        let second = group("5", "6", &[("1", "2"), ("3", "4")]);
        assert_eq!(
            cache_routes(&[first, second]),
            generate_routes(&[("a", "b"), ("c", "d"), ("1", "2"), ("3", "4")])
        );
    }

    #[test]
    fn overlapping_groups_use_latter_routes() {
        let base = group("x", "y", &[("a", "b"), ("c", "d")]);

        // Every keyword of `base` is overridden by the later group.
        let full_override = group("5", "6", &[("a", "1"), ("c", "2")]);
        assert_eq!(
            cache_routes(&[base.clone(), full_override]),
            generate_routes(&[("a", "1"), ("c", "2")])
        );

        // Only `a` is overridden; `c` survives and `b` is added.
        let partial_override = group("5", "6", &[("a", "1"), ("b", "2")]);
        assert_eq!(
            cache_routes(&[base, partial_override]),
            generate_routes(&[("a", "1"), ("b", "2"), ("c", "d")])
        );
    }
}