diff --git a/Cargo.lock b/Cargo.lock index bed9396..0a770a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -73,6 +73,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "either" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" + [[package]] name = "fnv" version = "1.0.7" @@ -200,7 +206,19 @@ name = "ip_reflect" version = "0.1.0" dependencies = [ "axum", + "hyper", + "itertools", "tokio", + "tower", +] + +[[package]] +name = "itertools" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" +dependencies = [ + "either", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 0c54a83..10bfe31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,9 @@ edition = "2018" [dependencies] axum = "0.2" -tokio = { version = "1", features = ["macros", "rt-multi-thread"] } \ No newline at end of file +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } + +[dev-dependencies] +tower = "0.4" +hyper = { version = "0.14", features = ["client"] } +itertools = "0.10" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 894019f..1147918 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,7 +4,7 @@ use axum::extract::ConnectInfo; use axum::handler::get; use axum::http::{HeaderMap, HeaderValue}; use axum::{Router, Server}; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; #[tokio::main] async fn main() { @@ -32,7 +32,9 @@ async fn root(header: HeaderMap, ConnectInfo(socket_info): ConnectInfo().is_ok() { return res; } @@ -46,12 +48,202 @@ async fn root(header: HeaderMap, ConnectInfo(socket_info): ConnectInfo v, None => value, } .to_string(); + if value.parse::().is_ok() { + return value; + } } socket_info.ip().to_string() } + +#[cfg(test)] +mod tests { + use crate::root; + use axum::body::Body; + use axum::handler::get; + use axum::http::{Request, StatusCode}; + use axum::Router; + use itertools::Itertools; + use std::net::{SocketAddr, TcpListener}; + + #[tokio::test] + async fn benign_requests() { + let app = Router::new().route("/", get(root)); + + let listener = TcpListener::bind("0.0.0.0:0".parse::().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + axum::Server::from_tcp(listener) + .unwrap() + .serve(app.into_make_service_with_connect_info::()) + .await + .unwrap(); + }); + let uri = format!("http://{}", addr); + let client = hyper::Client::new(); + + // No headers + let resp = client + .request( + Request::builder() + .uri(format!("http://{}", addr)) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let body = hyper::body::to_bytes(resp.into_body()).await.unwrap(); + assert_eq!(body, "127.0.0.1"); + + // X-Forwarded-For + let ip_combinations = + IntoIterator::into_iter(["0.1.2.3", "255.255.255.0", "ff00::0", "ffab::1"]) + .powerset() + .skip(1) // skip empty set + .collect::>(); + for ips in &ip_combinations { + let first = ips[0]; + let resp = client + .request( + Request::builder() + .uri(&uri) + .header("X-Forwarded-For", ips.join(", ")) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let body = hyper::body::to_bytes(resp.into_body()).await.unwrap(); + assert_eq!(body, first); + } + + // Forwarded + let ip_combinations = + IntoIterator::into_iter(["0.1.2.3", "255.255.255.0", "\"[ff00::0]\"", "\"[ffab::1]\""]) + .powerset() + .skip(1) // skip empty set + .collect::>(); + for ips in &ip_combinations { + let first = ips[0]; + let resp = client + .request( + Request::builder() + .uri(&uri) + .header( + "Forwarded", + ips.into_iter().map(|v| format!("for={}", v)).join(", "), + ) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let body = hyper::body::to_bytes(resp.into_body()).await.unwrap(); + if matches!(first.chars().next(), Some('"')) { + assert_eq!(body, first[2..first.len() - 2]); + } else { + assert_eq!(body, first); + } + } + } + + macro_rules! assert_passthrough { + ($resp:expr) => { + let body = hyper::body::to_bytes($resp.into_body()).await.unwrap(); + assert_eq!(body, "127.0.0.1"); + }; + } + + #[tokio::test] + async fn malicious_requests() { + let app = Router::new().route("/", get(root)); + + let listener = TcpListener::bind("0.0.0.0:0".parse::().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + axum::Server::from_tcp(listener) + .unwrap() + .serve(app.into_make_service_with_connect_info::()) + .await + .unwrap(); + }); + let uri = format!("http://{}", addr); + let client = hyper::Client::new(); + + // X-Forwarded-For incomplete IPv4 + assert_passthrough!(client + .request( + Request::builder() + .uri(&uri) + .header("X-Forwarded-For", "1.2.3") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap()); + + // X-Forwarded-For incomplete IPv6 + assert_passthrough!(client + .request( + Request::builder() + .uri(&uri) + .header("X-Forwarded-For", "1:4") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap()); + + // X-Forwarded-For bad ipv4 + assert_passthrough!(client + .request( + Request::builder() + .uri(&uri) + .header("X-Forwarded-For", "256.0.0.1") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap()); + + // X-Forwarded-For bad ipv6 + assert_passthrough!(client + .request( + Request::builder() + .uri(&uri) + .header("X-Forwarded-For", "ff::ff::ff") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap()); + + // nonsensical X-Forwarded-For + assert_passthrough!(client + .request( + Request::builder() + .uri(&uri) + .header("X-Forwarded-For", "hello world") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap()); + } +}