use crate::api::methods::sync::SyncResponse; use reqwest::{ header::{HeaderMap, HeaderValue, CONTENT_TYPE, USER_AGENT}, Client as reqwest_client, Response, StatusCode, }; use serde::Deserialize; use std::{collections::HashMap, error::Error, fmt, time}; use url::{ParseError, Url}; #[cfg(test)] use mockito; const V2_API_PATH: &str = "/_matrix/client/r0"; const SUPPORTED_VERSION: &str = "r0.5.0"; #[derive(Debug)] pub enum MatrixParseError { ParseError(ParseError), EmptyScheme, } pub enum PresenceState { Offline, Online, Unavailable, } impl From for MatrixParseError { fn from(parse_error: ParseError) -> Self { MatrixParseError::ParseError(parse_error) } } #[derive(Debug)] pub struct ResponseError { code: StatusCode, content: String, } impl fmt::Display for ResponseError { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "{}: {}", self.code, self.content) } } impl Error for ResponseError {} pub enum MatrixHTTPMethod { Get, Put, Delete, Post, } #[derive(Debug)] pub struct Client { homeserver_url: String, access_token: Option, mxid: Option, default_492_wait_ms: u64, reqwest_client: reqwest_client, } #[derive(Deserialize)] /// Response struct for [Section 2.1 **GET** /_matrix/client/versions](https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-versions). pub struct SupportedSpecs { pub versions: Vec, pub unstable_features: Option>, } impl Client { pub fn new( homeserver_url: &str, access_token: Option, mxid: Option, default_492_wait_ms: Option, ) -> Result> { let url = Url::parse(homeserver_url)?; if url.scheme().is_empty() { panic!("todo: implement handling"); } let client = Client { homeserver_url: homeserver_url.to_string(), access_token, mxid, default_492_wait_ms: default_492_wait_ms.unwrap_or_else(|| 5000), reqwest_client: reqwest_client::new(), }; if !client .supported_versions()? .versions .contains(&SUPPORTED_VERSION.to_string()) { // TODO: Implement proper response panic!("server version doesn't support client"); } Ok(client) } /// Implementation of [Section 2.1 **GET** /_matrix/client/versions](https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-versions). /// /// Returns a list of matrix specifications a server supports, as well as /// a map of unstable features the server has advertised. pub fn supported_versions(&self) -> Result> { Ok(self .send( MatrixHTTPMethod::Get, Some("/_matrix/client/versions"), None, None, None, )? .json()?) } /// Sends an API request to the homeserver using the specified method and /// path, returning either an error or the response text. /// /// The header will automatically be populated with a user agent and have /// the content type set to `application/json`. If a token was provided, it /// will be used for the Authorization header. /// /// This is a blocking, synchronous send. If the response from the /// homeserver indicates that too many requests were sent, it will attempt /// to wait the specified duration (or a provided default) before retrying. /// TODO: Make async fn send( &self, method: MatrixHTTPMethod, path: Option<&str>, content: Option, query_params: Option>, headers: Option, ) -> Result> { let mut query_params = query_params.unwrap_or_default(); let mut headers = headers.unwrap_or_default(); #[cfg(test)] let url = &mockito::server_url(); #[cfg(not(test))] let url = &self.homeserver_url; let endpoint = &format!("{}{}", url, path.unwrap_or_else(|| V2_API_PATH)); let mut request = match method { MatrixHTTPMethod::Get => self.reqwest_client.get(endpoint), MatrixHTTPMethod::Put => self.reqwest_client.put(endpoint), MatrixHTTPMethod::Delete => self.reqwest_client.delete(endpoint), MatrixHTTPMethod::Post => self.reqwest_client.post(endpoint), }; if !headers.contains_key(&USER_AGENT) { let user_agent = &format!("libmatrix-client/{}", env!("CARGO_PKG_VERSION")); headers.insert(USER_AGENT, HeaderValue::from_str(user_agent)?); } if !headers.contains_key(&CONTENT_TYPE) { headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json")?); } if let Some(token) = &self.access_token { request = request.bearer_auth(token); } if let Some(id) = &self.mxid { query_params.insert("user_id".to_string(), id.to_string()); } request = request.headers(headers).query(&query_params); if let Some(content) = content { request = request.body(content); } loop { let mut res = request .try_clone() .expect("Unable to clone request") .send()?; if res.status().is_success() { return Ok(res); } else if res.status() == StatusCode::TOO_MANY_REQUESTS { let mut body: HashMap = res.json()?; if let Some(value) = body.get("error") { body = serde_json::from_str(value)?; } if let Some(value) = body.get("retry_after_ms") { std::thread::sleep(time::Duration::from_millis(value.parse::()?)); } else { std::thread::sleep(time::Duration::from_millis(self.default_492_wait_ms)); } } else { return Err(Box::from(ResponseError { code: res.status(), content: res.text()?, })); } } } /// Helper method for sending query-based requests. fn send_query( &self, method: MatrixHTTPMethod, path: &str, query_params: HashMap, ) -> Result> { self.send(method, Some(path), None, Some(query_params), None) } pub fn sync( &self, bookmark_token: Option<&str>, timeout_ms: Option, filter: Option<&str>, get_full_state: Option, set_presence: Option, ) -> Result> { let mut params: HashMap = HashMap::with_capacity(5); params.insert( "timeout".to_string(), timeout_ms.unwrap_or_else(|| 30000).to_string(), ); if let Some(token) = bookmark_token { params.insert("since".to_string(), token.to_string()); } if let Some(filter) = filter { params.insert("filter".to_string(), filter.to_string()); } if let Some(true) = get_full_state { params.insert("full_state".to_string(), "true".to_string()); } params.insert( "full_state".to_string(), match set_presence { Some(PresenceState::Online) => "online", Some(PresenceState::Unavailable) => "unavailable", None | Some(PresenceState::Offline) => "offline", } .to_string(), ); Ok(self .send_query(MatrixHTTPMethod::Get, "/sync", params)? .json()?) } } #[cfg(test)] mod tests { use super::*; use mockito::mock; #[test] fn client_init_properly() {} #[test] fn supported_versions_complete_resp() { let _m = mock("GET", "/_matrix/client/versions") .with_body( r#"{ "versions": ["r0.4.0", "r0.5.0"], "unstable_features": { "m.lazy_load_members": true } }"#, ) .create(); // "valid" location must be supplied as reqwest attempts to parse it. let resp = Client::new("http://dummy.website", None, None, None) .unwrap() .supported_versions() .unwrap(); assert_eq!(resp.versions, vec!["r0.4.0", "r0.5.0"]); assert!(resp .unstable_features .unwrap_or_default() .get("m.lazy_load_members") .unwrap_or_else(|| &false)); } #[test] fn supported_versions_just_versions() { let _m = mock("GET", "/_matrix/client/versions") .with_body( r#"{ "versions": ["r0.4.0", "r0.5.0"] }"#, ) .create(); // "valid" location must be supplied as reqwest attempts to parse it. let resp = Client::new("http://dummy.website", None, None, None) .unwrap() .supported_versions() .unwrap(); assert_eq!(resp.versions, vec!["r0.4.0", "r0.5.0"]); assert!(resp.unstable_features.is_none()); } } #[derive(Default)] pub struct ApiError {} impl fmt::Display for ApiError { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "an error occurred!") } }