250 lines
7.7 KiB
Rust
250 lines
7.7 KiB
Rust
#![warn(clippy::pedantic, clippy::nursery)]
|
|
|
|
use axum::extract::ConnectInfo;
|
|
use axum::handler::get;
|
|
use axum::http::{HeaderMap, HeaderValue};
|
|
use axum::{Router, Server};
|
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
let app = Router::new().route("/", get(root));
|
|
|
|
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
|
|
Server::bind(&addr)
|
|
.serve(app.into_make_service_with_connect_info::<SocketAddr, _>())
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
#[allow(clippy::unused_async)]
|
|
async fn root(header: HeaderMap, ConnectInfo(socket_info): ConnectInfo<SocketAddr>) -> String {
|
|
if let Some(Ok(value)) = header.get("Forwarded").map(HeaderValue::to_str) {
|
|
let match_str = "or="; // the `for` key is case-insensitive
|
|
let maybe_index = value.find(match_str);
|
|
if let Some(index) = maybe_index {
|
|
let ip_start = value.split_at(index + match_str.len()).1;
|
|
|
|
let mut chars = ip_start.chars().peekable();
|
|
|
|
if let Some('"') = chars.peek() {
|
|
// We're dealing with an ipv6 address
|
|
|
|
// skip first character, since we know it's an `"`
|
|
chars.next();
|
|
// skip next character, since we know it's an [
|
|
chars.next();
|
|
let res: String = chars.take_while(|c| *c != ']').collect();
|
|
if res.parse::<Ipv6Addr>().is_ok() {
|
|
return res;
|
|
}
|
|
} else {
|
|
let res: String = chars.take_while(|c| *c != ',').collect();
|
|
if res.parse::<Ipv4Addr>().is_ok() {
|
|
return res;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(Ok(value)) = header.get("X-Forwarded-For").map(HeaderValue::to_str) {
|
|
let value = match value.split_once(',').map(|v| v.0) {
|
|
Some(v) => v,
|
|
None => value,
|
|
}
|
|
.to_string();
|
|
if value.parse::<IpAddr>().is_ok() {
|
|
return value;
|
|
}
|
|
}
|
|
|
|
socket_info.ip().to_string()
|
|
}
|
|
|
|
// End-to-end tests: each test starts the real router on an ephemeral loopback port and talks to
// it over HTTP with hyper's client.
#[cfg(test)]
mod tests {
    use crate::root;
    use axum::body::Body;
    use axum::handler::get;
    use axum::http::{Request, StatusCode};
    use axum::Router;
    use itertools::Itertools;
    use std::net::{SocketAddr, TcpListener};

    /// Starts the app on an ephemeral port in a background task and returns its base URI.
    ///
    /// Binds `127.0.0.1` rather than `0.0.0.0`: connecting to the unspecified address is not
    /// portable. Binding loopback also guarantees that the peer address the handler falls back
    /// to is exactly `127.0.0.1`, which the assertions rely on.
    fn spawn_server() -> String {
        let app = Router::new().route("/", get(root));

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            axum::Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service_with_connect_info::<SocketAddr, _>())
                .await
                .unwrap();
        });

        format!("http://{}", addr)
    }

    /// Sends `GET uri`, optionally with one extra `(name, value)` header.
    ///
    /// Asserts a `200 OK` response and returns its body.
    async fn fetch(uri: &str, header: Option<(&str, &str)>) -> hyper::body::Bytes {
        let mut request = Request::builder().uri(uri);
        if let Some((name, value)) = header {
            request = request.header(name, value);
        }

        let resp = hyper::Client::new()
            .request(request.body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        hyper::body::to_bytes(resp.into_body()).await.unwrap()
    }

    #[tokio::test]
    async fn benign_requests() {
        let uri = spawn_server();

        // No headers: the handler falls back to the TCP peer address.
        assert_eq!(fetch(&uri, None).await, "127.0.0.1");

        // X-Forwarded-For: every non-empty ordered subset of addresses; the first one wins.
        let ip_combinations =
            IntoIterator::into_iter(["0.1.2.3", "255.255.255.0", "ff00::0", "ffab::1"])
                .powerset()
                .skip(1) // skip empty set
                .collect::<Vec<_>>();
        for ips in &ip_combinations {
            let header = ips.join(", ");
            let body = fetch(&uri, Some(("X-Forwarded-For", header.as_str()))).await;
            assert_eq!(body, ips[0]);
        }

        // Forwarded: same idea, with IPv6 nodes quoted and bracketed as RFC 7239 requires.
        let ip_combinations =
            IntoIterator::into_iter(["0.1.2.3", "255.255.255.0", "\"[ff00::0]\"", "\"[ffab::1]\""])
                .powerset()
                .skip(1) // skip empty set
                .collect::<Vec<_>>();
        for ips in &ip_combinations {
            let header = ips.iter().map(|v| format!("for={}", v)).join(", ");
            let body = fetch(&uri, Some(("Forwarded", header.as_str()))).await;

            // The handler answers with the bare address, without the surrounding `"[` and `]"`.
            let first = ips[0];
            let expected = first
                .strip_prefix("\"[")
                .and_then(|v| v.strip_suffix("]\""))
                .unwrap_or(first);
            assert_eq!(body, expected);
        }
    }

    #[tokio::test]
    async fn malicious_requests() {
        let uri = spawn_server();

        // Every one of these must be rejected, so the handler falls back to the TCP peer address.
        let cases = [
            ("X-Forwarded-For", "1.2.3"),       // incomplete IPv4
            ("X-Forwarded-For", "1:4"),         // incomplete IPv6
            ("X-Forwarded-For", "256.0.0.1"),   // out-of-range IPv4 octet
            ("X-Forwarded-For", "ff::ff::ff"),  // IPv6 with two `::`
            ("X-Forwarded-For", "hello world"), // not an address at all
            ("Forwarded", "for=1.2.3"),         // incomplete IPv4
            ("Forwarded", "for=\"[ff::ff::ff]\""), // bad bracketed IPv6
            ("Forwarded", "for=unknown"),       // RFC 7239 "unknown" identifier
            ("Forwarded", "for=\"_hidden\""),   // RFC 7239 obfuscated identifier
        ];

        for &(name, value) in &cases {
            let body = fetch(&uri, Some((name, value))).await;
            assert_eq!(body, "127.0.0.1", "{}: {} should have been rejected", name, value);
        }
    }
}
|