bunbun/src/config.rs

259 lines
6.9 KiB
Rust
Raw Normal View History

2020-01-01 04:07:01 +00:00
use crate::BunBunError;
2019-12-29 05:48:02 +00:00
use log::{debug, error, info, trace};
2020-01-01 04:07:01 +00:00
use serde::{
de::{Deserializer, Visitor},
Deserialize, Serialize, Serializer,
};
2019-12-29 05:48:02 +00:00
use std::collections::HashMap;
2020-01-01 04:07:01 +00:00
use std::fmt;
2019-12-29 05:48:02 +00:00
use std::fs::{read_to_string, OpenOptions};
use std::io::Write;
/// Raw bytes of `../bunbun.default.yaml` (path relative to this source file),
/// embedded at compile time via `include_bytes!`. `read_config` writes these
/// bytes to disk and parses them when it falls back to the default config.
const DEFAULT_CONFIG: &[u8] = include_bytes!("../bunbun.default.yaml");
/// Top-level application configuration, deserialized from the YAML config
/// file by `read_config`.
#[derive(Debug, Deserialize, PartialEq)]
pub struct Config {
  /// Required. NOTE(review): presumably the address the server listens on —
  /// confirm against the code that consumes it.
  pub bind_address: String,
  /// Required. NOTE(review): presumably the externally visible address —
  /// confirm against the code that consumes it.
  pub public_address: String,
  /// Optional; `None` when the key is absent from the config file.
  pub default_route: Option<String>,
  /// Required list of route groups; may be empty, but the key must be present.
  pub groups: Vec<RouteGroup>,
}
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct RouteGroup {
pub name: String,
pub description: Option<String>,
2019-12-31 22:36:21 +00:00
pub routes: HashMap<String, Route>,
2019-12-29 05:48:02 +00:00
}
2020-01-01 04:07:01 +00:00
/// A route target. Both variants wrap the string exactly as it was written in
/// the config; which variant is chosen is decided during deserialization.
#[derive(Clone, Debug, PartialEq)]
pub enum Route {
  /// A target that did not exist on disk when the config was parsed, and is
  /// therefore assumed to be a web path.
  External(String),
  /// A target that existed on disk when the config was parsed.
  Path(String),
}
/// A `Route` serializes transparently as its inner string. Serde's
/// `#[serde(transparent)]` attribute isn't available on enums, so the impl is
/// written by hand.
impl Serialize for Route {
  fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
  where
    S: Serializer,
  {
    // The variant tag is deliberately dropped: only the wrapped string is
    // emitted, whichever variant this is.
    let target = match self {
      Self::External(target) | Self::Path(target) => target,
    };
    serializer.serialize_str(target)
  }
}
/// Turning a route string into the enum means deciding whether it names
/// something local or a web path. The rule is simple: if the string exists as
/// a path on disk it becomes `Route::Path`, otherwise it becomes
/// `Route::External`. This costs one filesystem check per route, which should
/// be acceptable because the config is not expected to change often.
impl<'de> Deserialize<'de> for Route {
  fn deserialize<D>(deserializer: D) -> Result<Route, D::Error>
  where
    D: Deserializer<'de>,
  {
    /// Visitor that accepts a string and classifies it as a `Route`.
    struct RouteVisitor;

    impl<'de> Visitor<'de> for RouteVisitor {
      type Value = Route;

      fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("string")
      }

      fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
      where
        E: serde::de::Error,
      {
        // Relative strings are resolved against the process's current
        // working directory by the existence check.
        let exists_on_disk = std::path::Path::new(value).exists();
        if !exists_on_disk {
          debug!("{} does not exist on disk, assuming web path.", value);
          return Ok(Route::External(value.to_owned()));
        }
        debug!("Parsed {} as a valid local path.", value);
        Ok(Route::Path(value.to_owned()))
      }
    }

    deserializer.deserialize_str(RouteVisitor)
  }
}
impl std::fmt::Display for Route {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::External(s) => write!(f, "raw ({})", s),
Self::Path(s) => write!(f, "file ({})", s),
}
}
}
2019-12-29 05:48:02 +00:00
/// Attempts to read and parse the config file at `config_file_path`.
///
/// If the file does not exist, a default config file is generated at that
/// path (best effort) and the built-in default config is parsed instead.
///
/// # Errors
///
/// Returns an error if:
/// - the file exists but cannot be read (any I/O error other than "not
///   found", e.g. insufficient permissions or contents that are not valid
///   UTF-8) — the default config is deliberately *not* substituted in that
///   case, so a broken or unreadable config is never silently ignored;
/// - the default config file was created but writing its contents failed;
/// - the contents are not valid YAML for a [`Config`].
pub fn read_config(config_file_path: &str) -> Result<Config, BunBunError> {
  trace!("Loading config file...");
  let config_str = match read_to_string(config_file_path) {
    Ok(conf_str) => {
      debug!("Successfully loaded config file into memory.");
      conf_str
    }
    // Only a missing file triggers default generation.
    Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
      info!(
        "Unable to find a {} file. Creating default!",
        config_file_path
      );
      // `create_new` guarantees an existing file is never overwritten, even
      // if one appears between the failed read above and this call.
      let fd = OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(config_file_path);
      match fd {
        Ok(mut fd) => fd.write_all(DEFAULT_CONFIG)?,
        // Not being able to persist the default is not fatal: the in-memory
        // default below is still used.
        Err(e) => {
          error!("Failed to write to {}: {}. Default config will be loaded but not saved.", config_file_path, e);
        }
      };
      String::from_utf8_lossy(DEFAULT_CONFIG).into_owned()
    }
    // The file is there but unreadable; surface the real cause to the caller
    // instead of pretending the file was missing.
    Err(e) => return Err(e.into()),
  };

  // Reading from memory is faster than reading directly from a reader for some
  // reason; see https://github.com/serde-rs/json/issues/160
  Ok(serde_yaml::from_str(&config_str)?)
}
2020-01-01 04:07:01 +00:00
/// Tests for `Route` (de)serialization.
#[cfg(test)]
mod route {
  use super::*;
  use serde_yaml::{from_str, to_string};
  use tempfile::NamedTempFile;

  /// A relative path to an existing file must deserialize as `Route::Path`.
  #[test]
  fn deserialize_relative_path() {
    let tmpfile = NamedTempFile::new_in(".").unwrap();
    let full = tmpfile.path().display().to_string();
    // Keep only the part of the path from its last '.' onwards.
    // NOTE(review): this presumably yields the bare temp file name (relies on
    // the generated name containing a dot) — confirm against `tempfile`.
    let last_dot = full.rfind('.').unwrap();
    let relative = std::path::Path::new(full.get(last_dot..).unwrap());
    assert!(relative.is_relative());
    let relative = relative.to_str().unwrap();
    assert_eq!(
      from_str::<Route>(relative).unwrap(),
      Route::Path(relative.into())
    );
  }

  /// An absolute path to an existing file must deserialize as `Route::Path`.
  #[test]
  fn deserialize_absolute_path() {
    let tmpfile = NamedTempFile::new().unwrap();
    let absolute = tmpfile.path().display().to_string();
    assert!(tmpfile.path().is_absolute());
    assert_eq!(
      from_str::<Route>(&absolute).unwrap(),
      Route::Path(absolute)
    );
  }

  /// An `http` URL does not exist on disk, so it must be `Route::External`.
  #[test]
  fn deserialize_http_path() {
    let parsed = from_str::<Route>("http://google.com").unwrap();
    assert_eq!(parsed, Route::External("http://google.com".into()));
  }

  /// An `https` URL does not exist on disk, so it must be `Route::External`.
  #[test]
  fn deserialize_https_path() {
    let parsed = from_str::<Route>("https://google.com").unwrap();
    assert_eq!(parsed, Route::External("https://google.com".into()));
  }

  /// Both variants serialize to nothing but their inner string.
  #[test]
  fn serialize() {
    let routes = [
      Route::External("hello world".into()),
      Route::Path("hello world".into()),
    ];
    for route in routes.iter() {
      assert_eq!(&to_string(route).unwrap(), "---\nhello world");
    }
  }
}
2019-12-29 05:48:02 +00:00
// Tests for `read_config`: fallback to the default config, the error
// messages for unparsable input, and a minimal valid config.
#[cfg(test)]
mod read_config {
use super::*;
use tempfile::NamedTempFile;
// A path that cannot be read yields the same `Config` as parsing
// `DEFAULT_CONFIG` directly.
// NOTE(review): `read_config` also tries to create the file; this test
// presumably relies on the filesystem root not being writable. If it is
// writable, the file is created and left behind — confirm this is acceptable.
#[test]
fn returns_default_config_if_path_does_not_exist() {
assert_eq!(
read_config("/this_is_a_non_existent_file").unwrap(),
serde_yaml::from_slice(DEFAULT_CONFIG).unwrap()
);
}
// Empty input must be reported as a parse error rather than a default config.
// NOTE(review): assumes a host where `/dev/null` exists and reads as empty.
#[test]
fn returns_error_if_given_empty_config() {
assert_eq!(
read_config("/dev/null").unwrap_err().to_string(),
"EOF while parsing a value"
);
}
// A bare scalar is not a `Config`; the type error must name the offending
// value.
#[test]
fn returns_error_if_given_invalid_config() -> Result<(), std::io::Error> {
let mut tmp_file = NamedTempFile::new()?;
tmp_file.write_all(b"g")?;
assert_eq!(
read_config(tmp_file.path().to_str().unwrap())
.unwrap_err()
.to_string(),
r#"invalid type: string "g", expected struct Config at line 1 column 1"#
);
Ok(())
}
// Omitting the required `groups` key must produce a "missing field" error.
// NOTE(review): the asserted line/column is produced by the YAML parser and
// so presumably depends on the exact bytes of the raw string below, leading
// whitespace included — verify the literal before re-indenting this test.
#[test]
fn returns_error_if_config_missing_field() -> Result<(), std::io::Error> {
let mut tmp_file = NamedTempFile::new()?;
tmp_file.write_all(
br#"
bind_address: "localhost"
public_address: "localhost"
"#,
)?;
assert_eq!(
read_config(tmp_file.path().to_str().unwrap())
.unwrap_err()
.to_string(),
"missing field `groups` at line 2 column 19"
);
Ok(())
}
// A config with both addresses and an empty `groups` list parses; the omitted
// `default_route` becomes `None`.
#[test]
fn returns_ok_if_valid_config() -> Result<(), std::io::Error> {
let mut tmp_file = NamedTempFile::new()?;
tmp_file.write_all(
br#"
bind_address: "a"
public_address: "b"
groups: []"#,
)?;
assert_eq!(
read_config(tmp_file.path().to_str().unwrap()).unwrap(),
Config {
bind_address: String::from("a"),
public_address: String::from("b"),
groups: vec![],
default_route: None,
}
);
Ok(())
}
}