bunbun/src/config.rs

419 lines
11 KiB
Rust
Raw Normal View History

2020-01-01 04:07:01 +00:00
use crate::BunBunError;
use dirs::{config_dir, home_dir};
use log::{debug, info, trace};
2020-01-01 04:07:01 +00:00
use serde::{
de::{self, Deserializer, MapAccess, Unexpected, Visitor},
Deserialize, Serialize,
2020-01-01 04:07:01 +00:00
};
2019-12-29 05:48:02 +00:00
use std::collections::HashMap;
2020-01-01 04:07:01 +00:00
use std::fmt;
use std::fs::{File, OpenOptions};
use std::io::{Read, Write};
use std::path::PathBuf;
use std::str::FromStr;
2019-12-29 05:48:02 +00:00
const CONFIG_FILENAME: &str = "bunbun.yaml";
2019-12-29 05:48:02 +00:00
const DEFAULT_CONFIG: &[u8] = include_bytes!("../bunbun.default.yaml");
2020-11-23 22:52:22 +00:00
#[cfg(not(test))]
const LARGE_FILE_SIZE_THRESHOLD: u64 = 100_000_000;
#[cfg(test)]
const LARGE_FILE_SIZE_THRESHOLD: u64 = 1_000_000;
2019-12-29 05:48:02 +00:00
/// Top-level configuration, deserialized from the YAML config file by
/// `read_config`.
#[derive(Debug, PartialEq, Deserialize)]
pub struct Config {
    pub bind_address: String,
    pub public_address: String,
    /// Optional; `None` when the key is absent from the config.
    pub default_route: Option<String>,
    /// Named groups of routes.
    pub groups: Vec<RouteGroup>,
}
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct RouteGroup {
pub name: String,
pub description: Option<String>,
#[serde(default)]
pub hidden: bool,
2019-12-31 22:36:21 +00:00
pub routes: HashMap<String, Route>,
2019-12-29 05:48:02 +00:00
}
/// A single route entry. It can be deserialized either from a bare path string
/// or from a map that has a `path` key.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct Route {
    /// Whether `path` names a local file or an external (web) path, as
    /// decided by `get_route_type` when the route is built.
    pub route_type: RouteType,
    pub path: String,
    /// `false` unless explicitly set.
    pub hidden: bool,
    pub description: Option<String>,
    /// Optional lower bound on the argument count. Deserialization rejects
    /// configs where this exceeds `max_args`.
    pub min_args: Option<usize>,
    /// Optional upper bound on the argument count.
    pub max_args: Option<usize>,
}
impl FromStr for Route {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self {
route_type: get_route_type(s),
path: s.to_string(),
hidden: false,
description: None,
2020-09-27 21:02:43 +00:00
min_args: None,
max_args: None,
})
2020-01-01 04:07:01 +00:00
}
}
/// Deserialization of the route string into the enum requires us to figure out
/// whether or not the string is valid to run as an executable or not. To
/// determine this, we simply check if it exists on disk or assume that it's a
/// web path. This incurs a disk check operation, but since users shouldn't be
/// updating the config that frequently, it should be fine.
impl<'de> Deserialize<'de> for Route {
fn deserialize<D>(deserializer: D) -> Result<Route, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
2020-09-28 01:33:32 +00:00
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
Path,
Hidden,
Description,
2020-09-27 21:02:43 +00:00
MinArgs,
MaxArgs,
}
2020-01-01 04:07:01 +00:00
struct RouteVisitor;
2020-01-01 04:07:01 +00:00
impl<'de> Visitor<'de> for RouteVisitor {
type Value = Route;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("string")
}
fn visit_str<E>(self, path: &str) -> Result<Self::Value, E>
2020-01-01 04:07:01 +00:00
where
E: serde::de::Error,
{
// This is infallible
Ok(Self::Value::from_str(path).unwrap())
}
fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut path = None;
let mut hidden = None;
let mut description = None;
2020-09-27 21:02:43 +00:00
let mut min_args = None;
let mut max_args = None;
while let Some(key) = map.next_key()? {
match key {
Field::Path => {
if path.is_some() {
return Err(de::Error::duplicate_field("path"));
}
path = Some(map.next_value::<String>()?);
}
Field::Hidden => {
if hidden.is_some() {
return Err(de::Error::duplicate_field("hidden"));
}
hidden = map.next_value()?;
}
Field::Description => {
if description.is_some() {
return Err(de::Error::duplicate_field("description"));
}
description = Some(map.next_value()?);
}
2020-09-27 21:02:43 +00:00
Field::MinArgs => {
if min_args.is_some() {
return Err(de::Error::duplicate_field("min_args"));
}
min_args = Some(map.next_value()?);
}
Field::MaxArgs => {
if max_args.is_some() {
return Err(de::Error::duplicate_field("max_args"));
}
max_args = Some(map.next_value()?);
}
}
2020-01-01 04:07:01 +00:00
}
if let (Some(min_args), Some(max_args)) = (min_args, max_args) {
if min_args > max_args {
{
return Err(de::Error::invalid_value(
Unexpected::Other(&format!(
"argument count range {} to {}",
min_args, max_args
)),
&"a valid argument count range",
));
}
}
}
let path = path.ok_or_else(|| de::Error::missing_field("path"))?;
Ok(Route {
route_type: get_route_type(&path),
path,
hidden: hidden.unwrap_or_default(),
description,
2020-09-27 21:02:43 +00:00
min_args,
max_args,
})
2020-01-01 04:07:01 +00:00
}
}
deserializer.deserialize_any(RouteVisitor)
}
}
2020-01-01 04:07:01 +00:00
impl std::fmt::Display for Route {
    /// Renders web routes as `raw (<path>)` and local-file routes as
    /// `file (<path>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self.route_type {
            RouteType::External => "raw",
            RouteType::Internal => "file",
        };
        write!(f, "{} ({})", label, self.path)
    }
}
2020-07-05 03:32:08 +00:00
/// Classifies `path` as `RouteType::Internal` if it exists on the local
/// filesystem, otherwise as `RouteType::External` (assumed to be a web path).
fn get_route_type(path: &str) -> RouteType {
    let exists_on_disk = std::path::Path::new(path).exists();
    if !exists_on_disk {
        debug!("{} does not exist on disk, assuming web path.", path);
        return RouteType::External;
    }
    debug!("Parsed {} as a valid local path.", path);
    RouteType::Internal
}
/// The kind of target a route points at.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub enum RouteType {
    /// An external path, e.g. a URL.
    External,
    /// An internal path to a local file.
    Internal,
}
/// An opened config file together with the path it was opened from.
pub struct ConfigData {
    /// Location of the config file on disk.
    pub path: PathBuf,
    /// Handle to the config file, opened for reading.
    pub file: File,
}
2019-12-29 05:48:02 +00:00
/// If a provided config path isn't found, this function checks known good
/// locations for a place to write a config file to. In order, it checks the
/// system-wide config location (`/etc/`, in Linux), followed by the config
/// folder, followed by the user's home folder.
pub fn get_config_data() -> Result<ConfigData, BunBunError> {
// Locations to check, with highest priority first
let locations: Vec<_> = {
let mut folders = vec![PathBuf::from("/etc/")];
2019-12-29 05:48:02 +00:00
// Config folder
2020-07-05 00:41:55 +00:00
if let Some(folder) = config_dir() {
folders.push(folder)
}
2019-12-29 05:48:02 +00:00
// Home folder
2020-07-05 00:41:55 +00:00
if let Some(folder) = home_dir() {
folders.push(folder)
}
folders
.iter_mut()
.for_each(|folder| folder.push(CONFIG_FILENAME));
folders
2019-12-29 05:48:02 +00:00
};
debug!("Checking locations for config file: {:?}", &locations);
for location in &locations {
let file = OpenOptions::new().read(true).open(location.clone());
match file {
Ok(file) => {
debug!("Found file at {:?}.", location);
return Ok(ConfigData {
path: location.clone(),
file,
2020-07-05 00:41:55 +00:00
});
}
Err(e) => debug!(
"Tried to read '{:?}' but failed due to error: {}",
2020-07-05 00:41:55 +00:00
location, e
),
}
}
debug!("Failed to find any config. Now trying to find first writable path");
// If we got here, we failed to read any file paths, meaning no config exists
// yet. In that case, try to return the first location that we can write to,
// after writing the default config
for location in locations {
let file = OpenOptions::new()
.write(true)
.create_new(true)
.open(location.clone());
match file {
Ok(mut file) => {
info!("Creating new config file at {:?}.", location);
file.write_all(DEFAULT_CONFIG)?;
let file = OpenOptions::new().read(true).open(location.clone())?;
return Ok(ConfigData {
path: location,
file,
});
}
Err(e) => debug!(
"Tried to open a new file at '{:?}' but failed due to error: {}",
2020-07-05 00:41:55 +00:00
location, e
),
}
}
Err(BunBunError::NoValidConfigPath)
}
/// Assumes that the user knows what they're talking about and will only try
/// to load the config at the given path.
pub fn load_custom_path_config(
path: impl Into<PathBuf>,
) -> Result<ConfigData, BunBunError> {
let path = path.into();
2020-07-05 00:41:55 +00:00
let file = OpenOptions::new()
.read(true)
.open(&path)
.map_err(|e| BunBunError::InvalidConfigPath(path.clone(), e))?;
Ok(ConfigData { file, path })
}
2020-07-05 16:17:29 +00:00
pub fn read_config(
mut config_file: File,
large_config: bool,
) -> Result<Config, BunBunError> {
2020-07-05 03:32:08 +00:00
trace!("Loading config file.");
2020-07-05 16:17:29 +00:00
let file_size = config_file.metadata()?.len();
// 100 MB
2020-11-23 22:52:22 +00:00
if file_size > LARGE_FILE_SIZE_THRESHOLD && !large_config {
2020-07-05 16:17:29 +00:00
return Err(BunBunError::ConfigTooLarge(file_size));
}
if file_size == 0 {
return Err(BunBunError::ZeroByteConfig);
}
let mut config_data = String::new();
config_file.read_to_string(&mut config_data)?;
2019-12-29 05:48:02 +00:00
// Reading from memory is faster than reading directly from a reader for some
// reason; see https://github.com/serde-rs/json/issues/160
Ok(serde_yaml::from_str(&config_data)?)
2019-12-29 05:48:02 +00:00
}
2020-01-01 04:07:01 +00:00
#[cfg(test)]
mod route {
    use super::*;
    use serde_yaml::{from_str, to_string};
    use tempfile::NamedTempFile;

    #[test]
    fn deserialize_relative_path() {
        let tmpfile = NamedTempFile::new_in(".").unwrap();
        // The temp file lives in the current directory, so its bare file name
        // is a relative path that exists on disk. Using `file_name()` avoids
        // string slicing that depends on the shape of the random name.
        let file_name = tmpfile.path().file_name().unwrap();
        let path = std::path::Path::new(file_name);
        assert!(path.is_relative());
        let path = path.to_str().unwrap();
        assert_eq!(
            from_str::<Route>(path).unwrap(),
            Route::from_str(path).unwrap()
        );
    }

    #[test]
    fn deserialize_absolute_path() {
        let tmpfile = NamedTempFile::new().unwrap();
        let path = tmpfile.path().display().to_string();
        assert!(tmpfile.path().is_absolute());
        assert_eq!(
            from_str::<Route>(&path).unwrap(),
            Route::from_str(&path).unwrap()
        );
    }

    #[test]
    fn deserialize_http_path() {
        assert_eq!(
            from_str::<Route>("http://google.com").unwrap(),
            Route::from_str("http://google.com").unwrap()
        );
    }

    #[test]
    fn deserialize_https_path() {
        assert_eq!(
            from_str::<Route>("https://google.com").unwrap(),
            Route::from_str("https://google.com").unwrap()
        );
    }

    #[test]
    fn serialize() {
        assert_eq!(
            &to_string(&Route::from_str("hello world").unwrap()).unwrap(),
            "---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~"
        );
    }
}
2020-11-23 22:52:22 +00:00
#[cfg(test)]
mod read_config {
    use super::*;

    #[test]
    fn empty_file() {
        let config_file = tempfile::tempfile().unwrap();
        assert!(matches!(
            read_config(config_file, false),
            Err(BunBunError::ZeroByteConfig)
        ));
    }

    #[test]
    fn config_too_large() {
        let mut config_file = tempfile::tempfile().unwrap();
        let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize;
        // `write` may perform a short write and report fewer bytes; `write_all`
        // guarantees the file really ends up one byte over the threshold.
        config_file.write_all(&vec![0u8; size_to_write]).unwrap();
        match read_config(config_file, false) {
            Err(BunBunError::ConfigTooLarge(size))
                if size as usize == size_to_write => {}
            Err(BunBunError::ConfigTooLarge(size)) => {
                panic!("Mismatched size: {} != {}", size, size_to_write)
            }
            res => panic!("Wrong result, got {:#?}", res),
        }
    }

    #[test]
    fn valid_config() {
        let config_file = File::open("bunbun.default.yaml").unwrap();
        assert!(read_config(config_file, false).is_ok());
    }
}