tests
This commit is contained in:
parent
282b12b62d
commit
d899fc1649
3 changed files with 219 additions and 4 deletions
18
Cargo.lock
generated
18
Cargo.lock
generated
|
@ -73,6 +73,12 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
@ -200,7 +206,19 @@ name = "ip_reflect"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"hyper",
|
||||
"itertools",
|
||||
"tokio",
|
||||
"tower",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
@ -8,3 +8,8 @@ edition = "2018"
|
|||
[dependencies]
|
||||
axum = "0.2"
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tower = "0.4"
|
||||
hyper = { version = "0.14", features = ["client"] }
|
||||
itertools = "0.10"
|
198
src/main.rs
198
src/main.rs
|
@ -4,7 +4,7 @@ use axum::extract::ConnectInfo;
|
|||
use axum::handler::get;
|
||||
use axum::http::{HeaderMap, HeaderValue};
|
||||
use axum::{Router, Server};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
|
@ -32,7 +32,9 @@ async fn root(header: HeaderMap, ConnectInfo(socket_info): ConnectInfo<SocketAdd
|
|||
|
||||
// skip first character, since we know it's an `"`
|
||||
chars.next();
|
||||
let res: String = chars.take_while(|c| *c != '"').collect();
|
||||
// skip next character, since we know it's an [
|
||||
chars.next();
|
||||
let res: String = chars.take_while(|c| *c != ']').collect();
|
||||
if res.parse::<Ipv6Addr>().is_ok() {
|
||||
return res;
|
||||
}
|
||||
|
@ -46,12 +48,202 @@ async fn root(header: HeaderMap, ConnectInfo(socket_info): ConnectInfo<SocketAdd
|
|||
}
|
||||
|
||||
if let Some(Ok(value)) = header.get("X-Forwarded-For").map(HeaderValue::to_str) {
|
||||
return match value.split_once(',').map(|v| v.0) {
|
||||
let value = match value.split_once(',').map(|v| v.0) {
|
||||
Some(v) => v,
|
||||
None => value,
|
||||
}
|
||||
.to_string();
|
||||
if value.parse::<IpAddr>().is_ok() {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
socket_info.ip().to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::root;
|
||||
use axum::body::Body;
|
||||
use axum::handler::get;
|
||||
use axum::http::{Request, StatusCode};
|
||||
use axum::Router;
|
||||
use itertools::Itertools;
|
||||
use std::net::{SocketAddr, TcpListener};
|
||||
|
||||
#[tokio::test]
|
||||
async fn benign_requests() {
|
||||
let app = Router::new().route("/", get(root));
|
||||
|
||||
let listener = TcpListener::bind("0.0.0.0:0".parse::<SocketAddr>().unwrap()).unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
axum::Server::from_tcp(listener)
|
||||
.unwrap()
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr, _>())
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
let uri = format!("http://{}", addr);
|
||||
let client = hyper::Client::new();
|
||||
|
||||
// No headers
|
||||
let resp = client
|
||||
.request(
|
||||
Request::builder()
|
||||
.uri(format!("http://{}", addr))
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = hyper::body::to_bytes(resp.into_body()).await.unwrap();
|
||||
assert_eq!(body, "127.0.0.1");
|
||||
|
||||
// X-Forwarded-For
|
||||
let ip_combinations =
|
||||
IntoIterator::into_iter(["0.1.2.3", "255.255.255.0", "ff00::0", "ffab::1"])
|
||||
.powerset()
|
||||
.skip(1) // skip empty set
|
||||
.collect::<Vec<_>>();
|
||||
for ips in &ip_combinations {
|
||||
let first = ips[0];
|
||||
let resp = client
|
||||
.request(
|
||||
Request::builder()
|
||||
.uri(&uri)
|
||||
.header("X-Forwarded-For", ips.join(", "))
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = hyper::body::to_bytes(resp.into_body()).await.unwrap();
|
||||
assert_eq!(body, first);
|
||||
}
|
||||
|
||||
// Forwarded
|
||||
let ip_combinations =
|
||||
IntoIterator::into_iter(["0.1.2.3", "255.255.255.0", "\"[ff00::0]\"", "\"[ffab::1]\""])
|
||||
.powerset()
|
||||
.skip(1) // skip empty set
|
||||
.collect::<Vec<_>>();
|
||||
for ips in &ip_combinations {
|
||||
let first = ips[0];
|
||||
let resp = client
|
||||
.request(
|
||||
Request::builder()
|
||||
.uri(&uri)
|
||||
.header(
|
||||
"Forwarded",
|
||||
ips.into_iter().map(|v| format!("for={}", v)).join(", "),
|
||||
)
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = hyper::body::to_bytes(resp.into_body()).await.unwrap();
|
||||
if matches!(first.chars().next(), Some('"')) {
|
||||
assert_eq!(body, first[2..first.len() - 2]);
|
||||
} else {
|
||||
assert_eq!(body, first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! assert_passthrough {
|
||||
($resp:expr) => {
|
||||
let body = hyper::body::to_bytes($resp.into_body()).await.unwrap();
|
||||
assert_eq!(body, "127.0.0.1");
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malicious_requests() {
|
||||
let app = Router::new().route("/", get(root));
|
||||
|
||||
let listener = TcpListener::bind("0.0.0.0:0".parse::<SocketAddr>().unwrap()).unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
axum::Server::from_tcp(listener)
|
||||
.unwrap()
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr, _>())
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
let uri = format!("http://{}", addr);
|
||||
let client = hyper::Client::new();
|
||||
|
||||
// X-Forwarded-For incomplete IPv4
|
||||
assert_passthrough!(client
|
||||
.request(
|
||||
Request::builder()
|
||||
.uri(&uri)
|
||||
.header("X-Forwarded-For", "1.2.3")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap());
|
||||
|
||||
// X-Forwarded-For incomplete IPv6
|
||||
assert_passthrough!(client
|
||||
.request(
|
||||
Request::builder()
|
||||
.uri(&uri)
|
||||
.header("X-Forwarded-For", "1:4")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap());
|
||||
|
||||
// X-Forwarded-For bad ipv4
|
||||
assert_passthrough!(client
|
||||
.request(
|
||||
Request::builder()
|
||||
.uri(&uri)
|
||||
.header("X-Forwarded-For", "256.0.0.1")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap());
|
||||
|
||||
// X-Forwarded-For bad ipv6
|
||||
assert_passthrough!(client
|
||||
.request(
|
||||
Request::builder()
|
||||
.uri(&uri)
|
||||
.header("X-Forwarded-For", "ff::ff::ff")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap());
|
||||
|
||||
// nonsensical X-Forwarded-For
|
||||
assert_passthrough!(client
|
||||
.request(
|
||||
Request::builder()
|
||||
.uri(&uri)
|
||||
.header("X-Forwarded-For", "hello world")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue