implement intelligent config location selection.

This commit is contained in:
Edward Shen 2020-07-04 20:17:12 -04:00
parent 633a152f89
commit 7faf15889a
Signed by: edward
GPG key ID: 19182661E818369F
5 changed files with 771 additions and 578 deletions

1101
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -12,6 +12,7 @@ exclude = ["/aux/"]
[dependencies] [dependencies]
actix-web = "2.0" actix-web = "2.0"
actix-rt = "1.0" actix-rt = "1.0"
dirs = "3.0"
serde = "1.0" serde = "1.0"
serde_yaml = "0.8" serde_yaml = "0.8"
handlebars = "2.0" handlebars = "2.0"

View file

@ -1,14 +1,17 @@
use crate::BunBunError; use crate::BunBunError;
use log::{debug, error, info, trace}; use dirs::{config_dir, home_dir};
use log::{debug, info, trace};
use serde::{ use serde::{
de::{Deserializer, Visitor}, de::{Deserializer, Visitor},
Deserialize, Serialize, Serializer, Deserialize, Serialize, Serializer,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
use std::fs::{read_to_string, OpenOptions}; use std::fs::{File, OpenOptions};
use std::io::Write; use std::io::{Read, Write};
use std::path::PathBuf;
const CONFIG_FILENAME: &str = "bunbun.yaml";
const DEFAULT_CONFIG: &[u8] = include_bytes!("../bunbun.default.yaml"); const DEFAULT_CONFIG: &[u8] = include_bytes!("../bunbun.default.yaml");
#[derive(Deserialize, Debug, PartialEq)] #[derive(Deserialize, Debug, PartialEq)]
@ -93,40 +96,104 @@ impl std::fmt::Display for Route {
} }
} }
/// Attempts to read the config file. If it doesn't exist, generate one a pub struct ConfigData {
/// default config file before attempting to parse it. pub path: PathBuf,
pub fn read_config(config_file_path: &str) -> Result<Config, BunBunError> { pub file: File,
trace!("Loading config file...");
let config_str = match read_to_string(config_file_path) {
Ok(conf_str) => {
debug!("Successfully loaded config file into memory.");
conf_str
} }
Err(_) => {
info!(
"Unable to find a {} file. Creating default!",
config_file_path
);
let fd = OpenOptions::new() /// If a provided config path isn't found, this function checks known good
/// locations for a place to write a config file to. In order, it checks the
/// system-wide config location (`/etc/`, in Linux), followed by the config
/// folder, followed by the user's home folder.
pub fn get_config_data() -> Result<ConfigData, BunBunError> {
// Locations to check, with highest priority first
let locations: Vec<_> = {
let mut folders = vec![PathBuf::from("/etc/")];
// Config folder
if let Some(folder) = config_dir() { folders.push(folder) }
// Home folder
if let Some(folder) = home_dir() { folders.push(folder) }
folders
.iter_mut()
.for_each(|folder| folder.push(CONFIG_FILENAME));
folders
};
debug!("Checking locations for config file: {:?}", &locations);
for location in &locations {
let file = OpenOptions::new().read(true).open(location.clone());
match file {
Ok(file) => {
debug!("Found file at {:?}.", location);
return Ok(ConfigData {
path: location.clone(),
file,
})
}
Err(e) => debug!(
"Tried to read '{:?}' but failed due to error: {}",
location,
e
),
}
}
debug!("Failed to find any config. Now trying to find first writable path");
// If we got here, we failed to read any file paths, meaning no config exists
// yet. In that case, try to return the first location that we can write to,
// after writing the default config
for location in locations {
let file = OpenOptions::new()
.write(true) .write(true)
.create_new(true) .create_new(true)
.open(config_file_path); .open(location.clone());
match file {
Ok(mut file) => {
info!("Creating new config file at {:?}.", location);
file.write_all(DEFAULT_CONFIG)?;
match fd { let file = OpenOptions::new().read(true).open(location.clone())?;
Ok(mut fd) => fd.write_all(DEFAULT_CONFIG)?, return Ok(ConfigData {
Err(e) => { path: location,
error!("Failed to write to {}: {}. Default config will be loaded but not saved.", config_file_path, e); file,
});
}
Err(e) => debug!(
"Tried to open a new file at '{:?}' but failed due to error: {}",
location,
e
),
} }
};
String::from_utf8_lossy(DEFAULT_CONFIG).into_owned()
} }
};
Err(BunBunError::NoValidConfigPath)
}
/// Assumes that the user knows what they're talking about and will only try
/// to load the config at the given path.
pub fn load_custom_path_config(
path: impl Into<PathBuf>,
) -> Result<ConfigData, BunBunError> {
let path = path.into();
Ok(ConfigData {
file: OpenOptions::new().read(true).open(&path)?,
path,
})
}
pub fn read_config(mut config_file: File) -> Result<Config, BunBunError> {
trace!("Loading config file...");
let mut config_data = String::new();
config_file.read_to_string(&mut config_data)?;
// Reading from memory is faster than reading directly from a reader for some // Reading from memory is faster than reading directly from a reader for some
// reason; see https://github.com/serde-rs/json/issues/160 // reason; see https://github.com/serde-rs/json/issues/160
Ok(serde_yaml::from_str(&config_str)?) Ok(serde_yaml::from_str(&config_data)?)
} }
#[cfg(test)] #[cfg(test)]
@ -182,77 +249,3 @@ mod route {
); );
} }
} }
#[cfg(test)]
mod read_config {
use super::*;
use tempfile::NamedTempFile;
#[test]
fn returns_default_config_if_path_does_not_exist() {
assert_eq!(
read_config("/this_is_a_non_existent_file").unwrap(),
serde_yaml::from_slice(DEFAULT_CONFIG).unwrap()
);
}
#[test]
fn returns_error_if_given_empty_config() {
assert_eq!(
read_config("/dev/null").unwrap_err().to_string(),
"EOF while parsing a value"
);
}
#[test]
fn returns_error_if_given_invalid_config() -> Result<(), std::io::Error> {
let mut tmp_file = NamedTempFile::new()?;
tmp_file.write_all(b"g")?;
assert_eq!(
read_config(tmp_file.path().to_str().unwrap())
.unwrap_err()
.to_string(),
r#"invalid type: string "g", expected struct Config at line 1 column 1"#
);
Ok(())
}
#[test]
fn returns_error_if_config_missing_field() -> Result<(), std::io::Error> {
let mut tmp_file = NamedTempFile::new()?;
tmp_file.write_all(
br#"
bind_address: "localhost"
public_address: "localhost"
"#,
)?;
assert_eq!(
read_config(tmp_file.path().to_str().unwrap())
.unwrap_err()
.to_string(),
"missing field `groups` at line 2 column 19"
);
Ok(())
}
#[test]
fn returns_ok_if_valid_config() -> Result<(), std::io::Error> {
let mut tmp_file = NamedTempFile::new()?;
tmp_file.write_all(
br#"
bind_address: "a"
public_address: "b"
groups: []"#,
)?;
assert_eq!(
read_config(tmp_file.path().to_str().unwrap()).unwrap(),
Config {
bind_address: String::from("a"),
public_address: String::from("b"),
groups: vec![],
default_route: None,
}
);
Ok(())
}
}

View file

@ -9,6 +9,7 @@ pub enum BunBunError {
WatchError(hotwatch::Error), WatchError(hotwatch::Error),
LoggerInitError(log::SetLoggerError), LoggerInitError(log::SetLoggerError),
CustomProgramError(String), CustomProgramError(String),
NoValidConfigPath
} }
impl Error for BunBunError {} impl Error for BunBunError {}
@ -21,6 +22,7 @@ impl fmt::Display for BunBunError {
Self::WatchError(e) => e.fmt(f), Self::WatchError(e) => e.fmt(f),
Self::LoggerInitError(e) => e.fmt(f), Self::LoggerInitError(e) => e.fmt(f),
Self::CustomProgramError(msg) => write!(f, "{}", msg), Self::CustomProgramError(msg) => write!(f, "{}", msg),
Self::NoValidConfigPath => write!(f, "No valid config path was found!"),
} }
} }
} }

View file

@ -1,6 +1,9 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
use crate::config::{read_config, Route, RouteGroup}; use crate::config::{
get_config_data, load_custom_path_config, read_config, ConfigData, Route,
RouteGroup,
};
use actix_web::{middleware::Logger, App, HttpServer}; use actix_web::{middleware::Logger, App, HttpServer};
use clap::{crate_authors, crate_version, load_yaml, App as ClapApp}; use clap::{crate_authors, crate_version, load_yaml, App as ClapApp};
use error::BunBunError; use error::BunBunError;
@ -41,8 +44,12 @@ async fn main() -> Result<(), BunBunError> {
)?; )?;
// config has default location provided, unwrapping is fine. // config has default location provided, unwrapping is fine.
let conf_file_location = String::from(matches.value_of("config").unwrap()); let conf_data = match matches.value_of("config") {
let conf = read_config(&conf_file_location)?; Some(file_name) => load_custom_path_config(file_name),
None => get_config_data(),
}?;
let conf = read_config(conf_data.file.try_clone()?)?;
let state = Arc::from(RwLock::new(State { let state = Arc::from(RwLock::new(State {
public_address: conf.public_address, public_address: conf.public_address,
default_route: conf.default_route, default_route: conf.default_route,
@ -50,7 +57,7 @@ async fn main() -> Result<(), BunBunError> {
groups: conf.groups, groups: conf.groups,
})); }));
let _watch = start_watch(state.clone(), conf_file_location)?; let _watch = start_watch(state.clone(), conf_data)?;
HttpServer::new(move || { HttpServer::new(move || {
App::new() App::new()
@ -104,7 +111,7 @@ fn cache_routes(groups: &[RouteGroup]) -> HashMap<String, Route> {
match mapping.insert(kw.clone(), dest.clone()) { match mapping.insert(kw.clone(), dest.clone()) {
None => trace!("Inserting {} into mapping.", kw), None => trace!("Inserting {} into mapping.", kw),
Some(old_value) => { Some(old_value) => {
debug!("Overriding {} route from {} to {}.", kw, old_value, dest) trace!("Overriding {} route from {} to {}.", kw, old_value, dest)
} }
} }
} }
@ -153,18 +160,25 @@ fn compile_templates() -> Handlebars {
/// watches. /// watches.
fn start_watch( fn start_watch(
state: Arc<RwLock<State>>, state: Arc<RwLock<State>>,
config_file_path: String, config_data: ConfigData,
) -> Result<Hotwatch, BunBunError> { ) -> Result<Hotwatch, BunBunError> {
let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?; let mut watch = Hotwatch::new_with_custom_delay(Duration::from_millis(500))?;
// TODO: keep retry watching in separate thread
// Closures need their own copy of variables for proper life cycle management // Closures need their own copy of variables for proper life cycle management
let config_file_path_clone = config_file_path.clone(); let config_data = Arc::new(config_data);
let watch_result = watch.watch(&config_file_path, move |e: Event| { let config_data_ref = Arc::clone(&config_data);
let watch_result = watch.watch(&config_data.path, move |e: Event| {
if let Event::Write(_) = e { if let Event::Write(_) = e {
trace!("Grabbing writer lock on state..."); trace!("Grabbing writer lock on state...");
let mut state = state.write().unwrap(); let mut state = state.write().expect("Failed to get write lock on state");
trace!("Obtained writer lock on state!"); trace!("Obtained writer lock on state!");
match read_config(&config_file_path_clone) { match read_config(
config_data_ref
.file
.try_clone()
.expect("Failed to clone file handle"),
) {
Ok(conf) => { Ok(conf) => {
state.public_address = conf.public_address; state.public_address = conf.public_address;
state.default_route = conf.default_route; state.default_route = conf.default_route;
@ -180,10 +194,10 @@ fn start_watch(
}); });
match watch_result { match watch_result {
Ok(_) => info!("Watcher is now watching {}", &config_file_path), Ok(_) => info!("Watcher is now watching {:?}", &config_data.path),
Err(e) => warn!( Err(e) => warn!(
"Couldn't watch {}: {}. Changes to this file won't be seen!", "Couldn't watch {:?}: {}. Changes to this file won't be seen!",
&config_file_path, e &config_data.path, e
), ),
} }