use crate::api::methods::{login::*, sync::SyncResponse};
use reqwest::{
    header::{HeaderMap, HeaderValue, CONTENT_TYPE, USER_AGENT},
    Client as reqwest_client, Response, StatusCode,
};
use serde::Deserialize;
use serde_json::{json, Value};
use std::{
    collections::HashMap,
    error::Error,
    fmt,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    thread, time,
};
use url::{ParseError, Url};

#[cfg(test)]
use mockito;

/// Specification version the homeserver must advertise for `Client::new` to
/// accept it.
const SUPPORTED_VERSION: &str = "r0.5.0";

/// Errors related to interpreting a homeserver URL.
#[derive(Debug)]
pub enum MatrixParseError {
    /// Wraps a parse failure reported by the `url` crate.
    ParseError(ParseError),
    /// Presumably meant for URLs without a scheme; nothing in the visible code
    /// constructs this variant yet.
    EmptyScheme,
}

/// Presence values a caller can pass to `Client::sync`.
pub enum PresenceState {
    Offline,
    Online,
    Unavailable,
}

// The type argument is required: without it the impl does not name the source
// type and cannot compile.
impl From<ParseError> for MatrixParseError {
    fn from(parse_error: ParseError) -> Self {
        MatrixParseError::ParseError(parse_error)
    }
}

/// Error produced by `Client::send` when the homeserver answers with a status
/// that is neither a success nor 429 (Too Many Requests).
#[derive(Debug)]
pub struct ResponseError {
    /// HTTP status code of the failed response.
    code: StatusCode,
    /// Response body text.
    content: String,
}

impl fmt::Display for ResponseError {
    /// Formats the error as `<status>: <body>`.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        write!(formatter, "{}: {}", self.code, self.content)
    }
}

impl Error for ResponseError {}

/// HTTP methods that `Client::send` can issue.
pub enum MatrixHTTPMethod {
    Get,
    Put,
    Delete,
    Post,
}

/// Blocking client for a single homeserver.
#[derive(Debug)]
pub struct Client {
    /// Base URL that request paths are appended to.
    homeserver_url: String,
    /// When set, sent as a bearer token with every request.
    access_token: Option<String>,
    /// When set, sent as the `user_id` query parameter with every request.
    mxid: Option<String>,
    /// Milliseconds to sleep before retrying a request that was answered with
    /// 429 (Too Many Requests) when the response carries no usable
    /// `retry_after_ms`. Despite the field name, the status checked is 429.
    default_492_wait_ms: u64,
    /// Underlying HTTP client used for every request.
    reqwest_client: reqwest_client,
}

#[derive(Deserialize)]
/// Response struct for [Section 2.1 **GET** /_matrix/client/versions](https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-versions).
pub struct SupportedSpecs {
    /// Specification versions listed in the server's response.
    pub versions: Vec<String>,
    /// Unstable feature flags from the response; `None` when the key is
    /// absent.
    pub unstable_features: Option<HashMap<String, bool>>,
}

/// Response struct for **GET** `/.well-known/matrix/client`.
#[derive(Deserialize)]
pub struct DiscoveryInformation {
    /// Contents of the `m.homeserver` object.
    #[serde(rename = "m.homeserver")]
    pub homeserver: HomeserverInformation,
    /// Contents of the `m.identity_server` object, when present. The rename is
    /// required: without it serde looks for a key literally named
    /// `identity_server`, which the well-known document does not use, so the
    /// field would always deserialize to `None`.
    #[serde(rename = "m.identity_server")]
    pub identity_server: Option<IdentityServerInformation>,
}

/// The `m.homeserver` section of a discovery response.
#[derive(Deserialize)]
pub struct HomeserverInformation {
    /// Base URL reported for the homeserver.
    pub base_url: String,
}

/// The `m.identity_server` section of a discovery response.
#[derive(Deserialize)]
pub struct IdentityServerInformation {
    /// Base URL reported for the identity server.
    pub base_url: String,
}

impl Client {
    /// Creates a client for `homeserver_url` and verifies that the server
    /// advertises support for the specification version this crate targets.
    ///
    /// * `homeserver_url` - base URL of the homeserver; must parse as a URL.
    /// * `access_token` - optional token sent as bearer authorization.
    /// * `mxid` - optional user ID sent as the `user_id` query parameter.
    /// * `default_492_wait_ms` - milliseconds to wait before retrying a
    ///   rate-limited request when the server gives no delay; defaults to
    ///   5000.
    ///
    /// # Errors
    ///
    /// Returns an error if the URL cannot be parsed or has an empty scheme, if
    /// the versions request fails, or if the server does not list the
    /// supported version. The last two conditions previously panicked.
    pub fn new(
        homeserver_url: &str,
        access_token: Option<String>,
        mxid: Option<String>,
        default_492_wait_ms: Option<u64>,
    ) -> Result<Self, Box<dyn Error>> {
        let url = Url::parse(homeserver_url)?;
        if url.scheme().is_empty() {
            return Err(Box::from("homeserver URL has an empty scheme"));
        }

        let client = Client {
            homeserver_url: homeserver_url.to_string(),
            access_token,
            mxid,
            default_492_wait_ms: default_492_wait_ms.unwrap_or(5000),
            reqwest_client: reqwest_client::new(),
        };

        // Performs a network round trip: the server must list the version this
        // client was written against.
        let is_supported = client
            .supported_versions()?
            .versions
            .iter()
            .any(|version| version.as_str() == SUPPORTED_VERSION);
        if !is_supported {
            return Err(Box::from(format!(
                "homeserver does not advertise support for {}",
                SUPPORTED_VERSION
            )));
        }

        Ok(client)
    }

    /// Implementation of [Section 4.1.1 **GET** `/.well-known/matrix/client`](https://matrix.org/docs/spec/client_server/r0.5.0#get-well-known-matrix-client).
    ///
    /// This does not implement the expected behavior a client *should* perform
    /// as described in [Section 4.1 Well-Known URI](https://matrix.org/docs/spec/client_server/r0.5.0#well-known-uri).
    /// Higher level clients must implement this behavior.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails or the body cannot be
    /// deserialized into [`DiscoveryInformation`].
    pub fn discovery_information(&self) -> Result<DiscoveryInformation, Box<dyn Error>> {
        Ok(self
            .send(
                MatrixHTTPMethod::Get,
                "/.well-known/matrix/client",
                None,
                None,
                None,
            )?
            .json()?)
    }

    /// Implementation of [Section 2.1 **GET** `/_matrix/client/versions`](https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-versions).
    ///
    /// Returns a list of matrix specifications a server supports, as well as
    /// a map of unstable features the server has advertised. This is not rate-
    /// limited nor requires authorization.
pub fn supported_versions(&self) -> Result> { Ok(self .send( MatrixHTTPMethod::Get, "/_matrix/client/versions", None, None, None, )? .json()?) } /// Sends an API request to the homeserver using the specified method and /// path, returning either an error or the response text. /// /// The header will automatically be populated with a user agent and have /// the content type set to `application/json`. If a token was provided, it /// will be used for the Authorization header. /// /// This is a blocking, synchronous send. If the response from the /// homeserver indicates that too many requests were sent, it will attempt /// to wait the specified duration (or a provided default) before retrying. /// TODO: Make async fn send( &self, method: MatrixHTTPMethod, path: &str, content: Option, query_params: Option>, headers: Option, ) -> Result> { let mut query_params = query_params.unwrap_or_default(); let mut headers = headers.unwrap_or_default(); #[cfg(test)] let url = &mockito::server_url(); #[cfg(not(test))] let url = &self.homeserver_url; let endpoint = &format!("{}{}", url, path); let mut request = match method { MatrixHTTPMethod::Get => self.reqwest_client.get(endpoint), MatrixHTTPMethod::Put => self.reqwest_client.put(endpoint), MatrixHTTPMethod::Delete => self.reqwest_client.delete(endpoint), MatrixHTTPMethod::Post => self.reqwest_client.post(endpoint), }; if !headers.contains_key(&USER_AGENT) { let user_agent = &format!("libmatrix-client/{}", env!("CARGO_PKG_VERSION")); headers.insert(USER_AGENT, HeaderValue::from_str(user_agent)?); } if !headers.contains_key(&CONTENT_TYPE) { headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json")?); } if let Some(token) = &self.access_token { request = request.bearer_auth(token); } if let Some(id) = &self.mxid { query_params.insert("user_id".to_string(), id.to_string()); } request = request.headers(headers).query(&query_params); if let Some(content) = content { request = request.body(content.to_string()); } loop { let 
mut res = request .try_clone() .expect("Unable to clone request") .send()?; if res.status().is_success() { return Ok(res); } else if res.status() == StatusCode::TOO_MANY_REQUESTS { let mut body: &Value = &res.json()?; if let Some(value) = body.get("error") { body = value; } std::thread::sleep(time::Duration::from_millis( body.get("retry_after_ms").map_or_else( || self.default_492_wait_ms, |v| v.as_u64().unwrap_or_else(|| self.default_492_wait_ms), ), )); } else { return Err(Box::from(ResponseError { code: res.status(), content: res.text()?, })); } } } /// Helper method for sending query-based requests. fn send_query( &self, method: MatrixHTTPMethod, path: &str, query_params: HashMap, ) -> Result> { self.send(method, path, None, Some(query_params), None) } pub fn sync( &self, bookmark_token: Option<&str>, timeout_ms: Option, filter: Option<&str>, get_full_state: Option, set_presence: Option, ) -> Result> { let mut params: HashMap = HashMap::with_capacity(5); params.insert( "timeout".to_string(), timeout_ms.unwrap_or_else(|| 30000).to_string(), ); if let Some(token) = bookmark_token { params.insert("since".to_string(), token.to_string()); } if let Some(filter) = filter { params.insert("filter".to_string(), filter.to_string()); } if let Some(true) = get_full_state { params.insert("full_state".to_string(), "true".to_string()); } params.insert( "full_state".to_string(), match set_presence { Some(PresenceState::Online) => "online", Some(PresenceState::Unavailable) => "unavailable", None | Some(PresenceState::Offline) => "offline", } .to_string(), ); Ok(self .send_query(MatrixHTTPMethod::Get, "/sync", params)? .json()?) } } // Holds implementation for Section 5.4, Login impl Client { /// Implementation of [Section 5.4.1 **GET** `/_matrix/client/r0/login`](https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-login). /// /// This request is rate-limited but does not require authentication. 
pub fn login_flows(&self) -> Result> {
        // NOTE(review): the generic arguments of this return type look lost
        // (`Option<&str>` below survived intact) - restore them from the
        // response type declared in `crate::api::methods::login`.
        let mut response = self.send(
            MatrixHTTPMethod::Get,
            "/_matrix/client/r0/login",
            None,
            None,
            None,
        )?;
        let flows = response.json()?;
        Ok(flows)
    }

    /// Password login via
    /// [5.4.2 **POST** `/_matrix/client/r0/login`](https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-login)
    ///
    /// Builds a request body of type `m.login.password` carrying `password`
    /// and hands it, together with the remaining arguments, to the shared
    /// login routine.
    ///
    /// When a device ID is supplied the server will try to use it. Reusing an
    /// existing device ID may invalidate access tokens previously issued to
    /// that device; an unknown ID is created, and when none is given the
    /// server generates one.
    ///
    /// A display name may be supplied for a device ID the server does not know
    /// yet. It is ignored when the device ID already exists.
    ///
    /// This request is rate-limited but does not require authentication.
    pub fn login_password(
        &self,
        identifier: IdentifierType,
        password: &str,
        device_id: Option<&str>,
        initial_device_display_name: Option<&str>,
    ) -> Result> {
        // NOTE(review): return type generics appear lost here as well.
        let credentials = json!({
            "type": "m.login.password",
            "password": password,
        });
        self.login(
            credentials,
            identifier,
            device_id,
            initial_device_display_name,
        )
    }

    /// Implementation of token-based login for
    /// [5.4.2 **POST** `/_matrix/client/r0/login`](https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-login)
    ///
    /// If a specific device ID is specified, then the server will attempt to
    /// use it. If an existing device ID is used, then any previously assigned
    /// access tokens to that device ID may be invalidated. If the device ID
    /// does not exist, then it will be created. If none is provided, then the
    /// server will generate one.
    ///
    /// You may specify a device display name if the provided device ID is not
    /// known. It is ignored if the device ID already exists.
    ///
    /// This request is rate-limited but does not require authentication.
///
    /// Note that the `password` parameter carries the login token; it is sent
    /// as the `token` field of an `m.login.token` request.
    pub fn login_token(
        &self,
        identifier: IdentifierType,
        password: &str,
        device_id: Option<&str>,
        initial_device_display_name: Option<&str>,
    ) -> Result> {
        // NOTE(review): the generic arguments of this return type look lost
        // (`Option<&str>` above survived intact) - restore them from the
        // response type declared in `crate::api::methods::login`.
        self.login(
            json!({
                "type": "m.login.token",
                "token": password,
            }),
            identifier,
            device_id,
            initial_device_display_name,
        )
    }

    /// Actual method that performs the login request. Accepts a body that
    /// should already be populated with the login-specific values needed for
    /// that particular login type.
    ///
    /// Adds `identifier` to the body, plus `device_id` and
    /// `initial_device_display_name` when they are provided, then posts the
    /// result to `/_matrix/client/r0/login` and deserializes the response.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails or the body cannot be
    /// deserialized.
    fn login(
        &self,
        body: Value,
        identifier: IdentifierType,
        device_id: Option<&str>,
        initial_device_display_name: Option<&str>,
    ) -> Result> {
        // NOTE(review): return type generics appear lost here as well.
        let mut body = body;
        body["identifier"] = json!(identifier);

        if let Some(id) = device_id {
            body["device_id"] = json!(id);
        }
        // The display name is independent of `device_id`: it also applies when
        // the server generates the device ID. It used to be nested under the
        // `device_id` check and was silently dropped in that case.
        if let Some(display_name) = initial_device_display_name {
            body["initial_device_display_name"] = json!(display_name);
        }

        Ok(self
            .send(
                MatrixHTTPMethod::Post,
                "/_matrix/client/r0/login",
                Some(body),
                None,
                None,
            )?
            .json()?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mockito::mock;

    #[test]
    fn client_init_properly() {}

    #[test]
    fn supported_versions_complete_resp() {
        let _m = mock("GET", "/_matrix/client/versions")
            .with_body(
                r#"{ "versions": ["r0.4.0", "r0.5.0"], "unstable_features": { "m.lazy_load_members": true } }"#,
            )
            .create();

        // "valid" location must be supplied as reqwest attempts to parse it.
        let resp = Client::new("http://dummy.website", None, None, None)
            .unwrap()
            .supported_versions()
            .unwrap();

        assert_eq!(resp.versions, vec!["r0.4.0", "r0.5.0"]);
        assert!(resp
            .unstable_features
            .unwrap_or_default()
            .get("m.lazy_load_members")
            .unwrap_or_else(|| &false));
    }

    #[test]
    fn supported_versions_just_versions() {
        let _m = mock("GET", "/_matrix/client/versions")
            .with_body(r#"{ "versions": ["r0.4.0", "r0.5.0"] }"#)
            .create();

        // "valid" location must be supplied as reqwest attempts to parse it.
        let resp = Client::new("http://dummy.website", None, None, None)
            .unwrap()
            .supported_versions()
            .unwrap();

        assert_eq!(resp.versions, vec!["r0.4.0", "r0.5.0"]);
        assert!(resp.unstable_features.is_none());
    }

    #[test]
    fn send_handles_too_many_requests() {
        let _mock_version_check = mock("GET", "/_matrix/client/versions")
            .with_body(r#"{"versions": ["r0.5.0"]}"#)
            .create();

        let _m2 = mock("GET", "/hello")
            .with_body(
                r#"{ "errcode": "M_LIMIT_EXCEEDED", "error": "Too many requests", "retry_after_ms": 2000 }"#,
            )
            .with_status(429)
            .create();

        // Run request in separate thread. Once async lands use async/await.
        // The result comes back over a channel so the assertion runs on the
        // test thread. Previously the spawned thread was never joined, so a
        // failing assertion inside it could not fail the test.
        let (result_tx, result_rx) = std::sync::mpsc::channel();
        thread::spawn(move || {
            let resp: Value = Client::new("http://dummy.website", None, None, None)
                .expect("failed to get valid client")
                .send(MatrixHTTPMethod::Get, "/hello", None, None, None)
                .expect("failed to unwrap Response")
                .json()
                .unwrap();
            result_tx.send(resp).expect("test thread stopped listening");
        });

        // Override 429 response with valid response after initial request
        // returns a 429. This should really be done with a synchronization
        // primitive (e.g. condvar) but I don't know how to do that.
        thread::sleep(time::Duration::from_secs(1));
        let _m = mock("GET", "/hello")
            .with_body(r#"{"hello": "world"}"#)
            .create();

        // Bounded wait: long enough for the default retry delay, yet the test
        // fails instead of hanging if the retry never succeeds. A panic in the
        // request thread drops the sender and surfaces here as an error.
        let resp = result_rx
            .recv_timeout(time::Duration::from_secs(10))
            .expect("no successful response after retrying");
        assert_eq!(resp["hello"], "world");
    }
}

/// Field-less error type whose `Display` output is a fixed message.
#[derive(Default)]
pub struct ApiError {}

impl fmt::Display for ApiError {
    /// Writes the fixed text `an error occurred!`.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        write!(formatter, "an error occurred!")
    }
}