#![warn(clippy::pedantic, clippy::nursery)] use axum::extract::ConnectInfo; use axum::handler::get; use axum::http::{HeaderMap, HeaderValue}; use axum::{Router, Server}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; #[tokio::main] async fn main() { let app = Router::new().route("/", get(root)); let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); Server::bind(&addr) .serve(app.into_make_service_with_connect_info::()) .await .unwrap(); } #[allow(clippy::unused_async)] async fn root(header: HeaderMap, ConnectInfo(socket_info): ConnectInfo) -> String { if let Some(Ok(value)) = header.get("Forwarded").map(HeaderValue::to_str) { let match_str = "or="; // the `for` key is case-insensitive let maybe_index = value.find(match_str); if let Some(index) = maybe_index { let ip_start = value.split_at(index + match_str.len()).1; let mut chars = ip_start.chars().peekable(); if let Some('"') = chars.peek() { // We're dealing with an ipv6 address // skip first character, since we know it's an `"` chars.next(); // skip next character, since we know it's an [ chars.next(); let res: String = chars.take_while(|c| *c != ']').collect(); if res.parse::().is_ok() { return res; } } else { let res: String = chars.take_while(|c| *c != ',').collect(); if res.parse::().is_ok() { return res; } } } } if let Some(Ok(value)) = header.get("X-Forwarded-For").map(HeaderValue::to_str) { let value = match value.split_once(',').map(|v| v.0) { Some(v) => v, None => value, } .to_string(); if value.parse::().is_ok() { return value; } } socket_info.ip().to_string() } #[cfg(test)] mod tests { use crate::root; use axum::body::Body; use axum::handler::get; use axum::http::{Request, StatusCode}; use axum::Router; use itertools::Itertools; use std::net::{SocketAddr, TcpListener}; #[tokio::test] async fn benign_requests() { let app = Router::new().route("/", get(root)); let listener = TcpListener::bind("0.0.0.0:0".parse::().unwrap()).unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { axum::Server::from_tcp(listener) .unwrap() .serve(app.into_make_service_with_connect_info::()) .await .unwrap(); }); let uri = format!("http://{}", addr); let client = hyper::Client::new(); // No headers let resp = client .request( Request::builder() .uri(format!("http://{}", addr)) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = hyper::body::to_bytes(resp.into_body()).await.unwrap(); assert_eq!(body, "127.0.0.1"); // X-Forwarded-For let ip_combinations = IntoIterator::into_iter(["0.1.2.3", "255.255.255.0", "ff00::0", "ffab::1"]) .powerset() .skip(1) // skip empty set .collect::>(); for ips in &ip_combinations { let first = ips[0]; let resp = client .request( Request::builder() .uri(&uri) .header("X-Forwarded-For", ips.join(", ")) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = hyper::body::to_bytes(resp.into_body()).await.unwrap(); assert_eq!(body, first); } // Forwarded let ip_combinations = IntoIterator::into_iter(["0.1.2.3", "255.255.255.0", "\"[ff00::0]\"", "\"[ffab::1]\""]) .powerset() .skip(1) // skip empty set .collect::>(); for ips in &ip_combinations { let first = ips[0]; let resp = client .request( Request::builder() .uri(&uri) .header( "Forwarded", ips.into_iter().map(|v| format!("for={}", v)).join(", "), ) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = hyper::body::to_bytes(resp.into_body()).await.unwrap(); if matches!(first.chars().next(), Some('"')) { assert_eq!(body, first[2..first.len() - 2]); } else { assert_eq!(body, first); } } } macro_rules! assert_passthrough { ($resp:expr) => { let body = hyper::body::to_bytes($resp.into_body()).await.unwrap(); assert_eq!(body, "127.0.0.1"); }; } #[tokio::test] async fn malicious_requests() { let app = Router::new().route("/", get(root)); let listener = TcpListener::bind("0.0.0.0:0".parse::().unwrap()).unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { axum::Server::from_tcp(listener) .unwrap() .serve(app.into_make_service_with_connect_info::()) .await .unwrap(); }); let uri = format!("http://{}", addr); let client = hyper::Client::new(); // X-Forwarded-For incomplete IPv4 assert_passthrough!(client .request( Request::builder() .uri(&uri) .header("X-Forwarded-For", "1.2.3") .body(Body::empty()) .unwrap(), ) .await .unwrap()); // X-Forwarded-For incomplete IPv6 assert_passthrough!(client .request( Request::builder() .uri(&uri) .header("X-Forwarded-For", "1:4") .body(Body::empty()) .unwrap(), ) .await .unwrap()); // X-Forwarded-For bad ipv4 assert_passthrough!(client .request( Request::builder() .uri(&uri) .header("X-Forwarded-For", "256.0.0.1") .body(Body::empty()) .unwrap(), ) .await .unwrap()); // X-Forwarded-For bad ipv6 assert_passthrough!(client .request( Request::builder() .uri(&uri) .header("X-Forwarded-For", "ff::ff::ff") .body(Body::empty()) .unwrap(), ) .await .unwrap()); // nonsensical X-Forwarded-For assert_passthrough!(client .request( Request::builder() .uri(&uri) .header("X-Forwarded-For", "hello world") .body(Body::empty()) .unwrap(), ) .await .unwrap()); } }