use crate::BunBunError;
use dirs::{config_dir, home_dir};
use log::{debug, info, trace};
use serde::{
    de::{self, Deserializer, MapAccess, Unexpected, Visitor},
    Deserialize, Serialize,
};
use std::collections::HashMap;
use std::fmt;
use std::fs::{File, OpenOptions};
use std::io::{Read, Write};
use std::path::PathBuf;
use std::str::FromStr;

/// File name appended to every candidate directory when searching for a
/// config file.
const CONFIG_FILENAME: &str = "bunbun.yaml";

/// Contents of the default config, embedded at compile time. Written to disk
/// when no existing config file can be read.
const DEFAULT_CONFIG: &[u8] = include_bytes!("../bunbun.default.yaml");

/// Size in bytes above which `read_config` rejects a config file unless the
/// caller explicitly allows large configs.
#[cfg(not(test))]
const LARGE_FILE_SIZE_THRESHOLD: u64 = 100_000_000;

/// Lower threshold used in test builds, where a file just over the limit is
/// written to disk.
#[cfg(test)]
const LARGE_FILE_SIZE_THRESHOLD: u64 = 1_000_000;

/// Application configuration as deserialized from the config file.
#[derive(Deserialize, Debug, PartialEq)]
pub struct Config {
    pub bind_address: String,
    pub public_address: String,
    // NOTE(review): the type parameters on the next two fields were missing in
    // the reviewed text and are restored from naming/usage -- confirm.
    pub default_route: Option<String>,
    pub groups: Vec<RouteGroup>,
}

/// A named collection of routes.
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
pub struct RouteGroup {
    pub name: String,
    // NOTE(review): type parameter restored as `String` -- confirm.
    pub description: Option<String>,
    /// Defaults to `false` when the key is absent from the config.
    #[serde(default)]
    pub hidden: bool,
    // NOTE(review): key/value types restored as `String` -> `Route` -- confirm.
    pub routes: HashMap<String, Route>,
}

/// A single route: a path plus optional metadata.
///
/// A route can be built from a bare string through [`FromStr`]; in that case
/// every optional field is left unset and `hidden` is `false`.
#[derive(Debug, PartialEq, Clone, Serialize)]
pub struct Route {
    /// Classification of `path`, as computed by `get_route_type`.
    pub route_type: RouteType,
    pub path: String,
    pub hidden: bool,
    // NOTE(review): type parameter restored as `String` -- confirm.
    pub description: Option<String>,
    // NOTE(review): the integer type of the argument bounds was missing in the
    // reviewed text; `usize` is assumed. The code only requires that it is
    // ordered, displayable and deserializable -- confirm.
    pub min_args: Option<usize>,
    pub max_args: Option<usize>,
}

impl FromStr for Route {
    type Err = std::convert::Infallible;

    /// Builds a route from a bare path string. Never fails; the route type is
    /// determined by checking whether `s` exists on disk.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self {
            route_type: get_route_type(s),
            path: s.to_string(),
            hidden: false,
            description: None,
            min_args: None,
            max_args: None,
        })
    }
}

/// Deserialization of the route string into the enum requires us to figure out
/// whether or not the string is valid to run as an executable or not. To
/// determine this, we simply check if it exists on disk or assume that it's a
/// web path. This incurs a disk check operation, but since users shouldn't be
/// updating the config that frequently, it should be fine.
impl<'de> Deserialize<'de> for Route { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { #[derive(Deserialize)] #[serde(field_identifier, rename_all = "snake_case")] enum Field { Path, Hidden, Description, MinArgs, MaxArgs, } struct RouteVisitor; impl<'de> Visitor<'de> for RouteVisitor { type Value = Route; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("string") } fn visit_str(self, path: &str) -> Result where E: serde::de::Error, { // This is infallible Ok(Self::Value::from_str(path).unwrap()) } fn visit_map(self, mut map: M) -> Result where M: MapAccess<'de>, { let mut path = None; let mut hidden = None; let mut description = None; let mut min_args = None; let mut max_args = None; while let Some(key) = map.next_key()? { match key { Field::Path => { if path.is_some() { return Err(de::Error::duplicate_field("path")); } path = Some(map.next_value::()?); } Field::Hidden => { if hidden.is_some() { return Err(de::Error::duplicate_field("hidden")); } hidden = map.next_value()?; } Field::Description => { if description.is_some() { return Err(de::Error::duplicate_field("description")); } description = Some(map.next_value()?); } Field::MinArgs => { if min_args.is_some() { return Err(de::Error::duplicate_field("min_args")); } min_args = Some(map.next_value()?); } Field::MaxArgs => { if max_args.is_some() { return Err(de::Error::duplicate_field("max_args")); } max_args = Some(map.next_value()?); } } } if let (Some(min_args), Some(max_args)) = (min_args, max_args) { if min_args > max_args { { return Err(de::Error::invalid_value( Unexpected::Other(&format!( "argument count range {} to {}", min_args, max_args )), &"a valid argument count range", )); } } } let path = path.ok_or_else(|| de::Error::missing_field("path"))?; Ok(Route { route_type: get_route_type(&path), path, hidden: hidden.unwrap_or_default(), description, min_args, max_args, }) } } deserializer.deserialize_any(RouteVisitor) } } impl std::fmt::Display 
for Route { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self { route_type: RouteType::External, path, .. } => write!(f, "raw ({})", path), Self { route_type: RouteType::Internal, path, .. } => write!(f, "file ({})", path), } } } /// Classifies the path depending on if the there exists a local file. fn get_route_type(path: &str) -> RouteType { if std::path::Path::new(path).exists() { debug!("Parsed {} as a valid local path.", path); RouteType::Internal } else { debug!("{} does not exist on disk, assuming web path.", path); RouteType::External } } /// There exists two route types: an external path (e.g. a URL) or an internal /// path (to a file). #[derive(Debug, PartialEq, Clone, Serialize)] pub enum RouteType { External, Internal, } pub struct ConfigData { pub path: PathBuf, pub file: File, } /// If a provided config path isn't found, this function checks known good /// locations for a place to write a config file to. In order, it checks the /// system-wide config location (`/etc/`, in Linux), followed by the config /// folder, followed by the user's home folder. pub fn get_config_data() -> Result { // Locations to check, with highest priority first let locations: Vec<_> = { let mut folders = vec![PathBuf::from("/etc/")]; // Config folder if let Some(folder) = config_dir() { folders.push(folder) } // Home folder if let Some(folder) = home_dir() { folders.push(folder) } folders .iter_mut() .for_each(|folder| folder.push(CONFIG_FILENAME)); folders }; debug!("Checking locations for config file: {:?}", &locations); for location in &locations { let file = OpenOptions::new().read(true).open(location.clone()); match file { Ok(file) => { debug!("Found file at {:?}.", location); return Ok(ConfigData { path: location.clone(), file, }); } Err(e) => debug!( "Tried to read '{:?}' but failed due to error: {}", location, e ), } } debug!("Failed to find any config. 
Now trying to find first writable path"); // If we got here, we failed to read any file paths, meaning no config exists // yet. In that case, try to return the first location that we can write to, // after writing the default config for location in locations { let file = OpenOptions::new() .write(true) .create_new(true) .open(location.clone()); match file { Ok(mut file) => { info!("Creating new config file at {:?}.", location); file.write_all(DEFAULT_CONFIG)?; let file = OpenOptions::new().read(true).open(location.clone())?; return Ok(ConfigData { path: location, file, }); } Err(e) => debug!( "Tried to open a new file at '{:?}' but failed due to error: {}", location, e ), } } Err(BunBunError::NoValidConfigPath) } /// Assumes that the user knows what they're talking about and will only try /// to load the config at the given path. pub fn load_custom_path_config( path: impl Into, ) -> Result { let path = path.into(); let file = OpenOptions::new() .read(true) .open(&path) .map_err(|e| BunBunError::InvalidConfigPath(path.clone(), e))?; Ok(ConfigData { file, path }) } pub fn read_config( mut config_file: File, large_config: bool, ) -> Result { trace!("Loading config file."); let file_size = config_file.metadata()?.len(); // 100 MB if file_size > LARGE_FILE_SIZE_THRESHOLD && !large_config { return Err(BunBunError::ConfigTooLarge(file_size)); } if file_size == 0 { return Err(BunBunError::ZeroByteConfig); } let mut config_data = String::new(); config_file.read_to_string(&mut config_data)?; // Reading from memory is faster than reading directly from a reader for some // reason; see https://github.com/serde-rs/json/issues/160 Ok(serde_yaml::from_str(&config_data)?) 
} #[cfg(test)] mod route { use super::*; use serde_yaml::{from_str, to_string}; use tempfile::NamedTempFile; #[test] fn deserialize_relative_path() { let tmpfile = NamedTempFile::new_in(".").unwrap(); let path = format!("{}", tmpfile.path().display()); let path = path.get(path.rfind(".").unwrap()..).unwrap(); let path = std::path::Path::new(path); assert!(path.is_relative()); let path = path.to_str().unwrap(); assert_eq!( from_str::(path).unwrap(), Route::from_str(path).unwrap() ); } #[test] fn deserialize_absolute_path() { let tmpfile = NamedTempFile::new().unwrap(); let path = format!("{}", tmpfile.path().display()); assert!(tmpfile.path().is_absolute()); assert_eq!( from_str::(&path).unwrap(), Route::from_str(&path).unwrap() ); } #[test] fn deserialize_http_path() { assert_eq!( from_str::("http://google.com").unwrap(), Route::from_str("http://google.com").unwrap() ); } #[test] fn deserialize_https_path() { assert_eq!( from_str::("https://google.com").unwrap(), Route::from_str("https://google.com").unwrap() ); } #[test] fn serialize() { assert_eq!( &to_string(&Route::from_str("hello world").unwrap()).unwrap(), "---\nroute_type: External\npath: hello world\nhidden: false\ndescription: ~\nmin_args: ~\nmax_args: ~" ); } } #[cfg(test)] mod read_config { use super::*; #[test] fn empty_file() { let config_file = tempfile::tempfile().unwrap(); assert!(matches!( read_config(config_file, false), Err(BunBunError::ZeroByteConfig) )); } #[test] fn config_too_large() { let mut config_file = tempfile::tempfile().unwrap(); let size_to_write = (LARGE_FILE_SIZE_THRESHOLD + 1) as usize; config_file.write(&[0].repeat(size_to_write)).unwrap(); match read_config(config_file, false) { Err(BunBunError::ConfigTooLarge(size)) if size as usize == size_to_write => {} Err(BunBunError::ConfigTooLarge(size)) => { panic!("Mismatched size: {} != {}", size, size_to_write) } res => panic!("Wrong result, got {:#?}", res), } } #[test] fn valid_config() { let config_file = 
File::open("bunbun.default.yaml").unwrap(); assert!(read_config(config_file, false).is_ok()); } }