#![feature(try_blocks, byte_slice_trim_ascii)]
use std::{
convert::{identity as id, Infallible},
error::Error,
fmt::{Display, Debug},
net::SocketAddr,
path::PathBuf,
process::ExitCode,
sync::{Arc, Mutex, atomic::AtomicBool},
time::Duration,
};
use std::sync::atomic::Ordering::SeqCst;
use bytes::{Bytes, Buf};
use futures::stream::FuturesUnordered;
use h3::error::{ErrorLevel, Code};
use http::{Response, Request, HeaderName, header, HeaderMap, HeaderValue, Method, StatusCode, uri::Scheme, Uri};
use hyper::{
service::{make_service_fn, service_fn},
Body, Server, body::HttpBody,
server::conn::AddrStream,
};
use rand::{seq::IteratorRandom, thread_rng};
use rustls::server::AllowAnyAuthenticatedClient;
use sd_notify::{notify, NotifyState};
use structopt::StructOpt;
use tokio::{net::lookup_host, select, time::sleep, try_join};
use tokio_util::sync::CancellationToken;
use tracing::{error, info};
use h3_quinn::quinn;
mod config;
mod utils;
use crate::config::PeerMode;
use crate::utils::{cancellable, drain_stream, with_background};
static ALPN: &[u8] = b"h3";
// Command-line options of the proxy binary.
// Plain `//` comments (not `///`) are used inside this struct so the derive macro does not
// pick them up as help/about text; the help output stays exactly as configured below.
#[derive(Debug, StructOpt)]
#[structopt(about)]
struct Opt {
    // Location of the TOML configuration file; TLS file paths inside it are resolved
    // relative to this file's directory.
    #[structopt(
        short,
        long,
        help = "Path to configuration file",
        default_value = "/etc/ptproxy/config.toml"
    )]
    pub config: PathBuf,
}
/// Error returned by `real_main`: the step that failed plus the underlying cause.
#[derive(Debug)]
struct AppError {
    /// Human-readable name of the failed step (e.g. "parsing configuration").
    task: String,
    /// The error that made the step fail.
    inner: Box<dyn Error>,
}

impl Display for AppError {
    /// Renders as `failed at <task>: <inner>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { task, inner } = self;
        write!(f, "failed at {task}: {inner}")
    }
}

/// Lets `AppError` be used wherever a `std::error::Error` is expected. The cause is only
/// exposed through `Display`; `source()` keeps its default (`None`).
impl Error for AppError {}
/// Builds a `map_err` adapter that wraps any error convertible into `Box<dyn Error>` in an
/// `AppError`.
///
/// `task` is only invoked when an error actually occurs, so callers can defer building the
/// (possibly formatted) step description until it is needed.
fn wrap_fn<'a, F, T>(task: F) -> impl FnOnce(T) -> AppError + 'a
where
    F: FnOnce() -> String + 'a,
    T: Into<Box<dyn Error>>,
{
    move |err| {
        // Same evaluation order as a struct literal: describe the task first, then convert.
        let task = task();
        let inner = err.into();
        AppError { task, inner }
    }
}
/// `map_err` adapter for a step described by a fixed string.
///
/// The description is copied into an owned `String` only if an error occurs.
fn wrap<'a, T>(task: &'a str) -> impl FnOnce(T) -> AppError + 'a
where
    T: Into<Box<dyn Error>>,
{
    let describe = move || String::from(task);
    wrap_fn(describe)
}
/// `map_err` adapter for a step that concerns a specific value (a path, an address, ...).
///
/// On error the step is described as `<task> (<arg:?>)`; formatting only happens on failure.
fn wrap_arg<'a, T, A>(task: &'a str, arg: &'a A) -> impl FnOnce(T) -> AppError + 'a
where
    T: Into<Box<dyn Error>>,
    A: Debug,
{
    wrap_fn(move || format!("{task} ({arg:?})"))
}
/// Process entry point: runs `real_main` and maps its outcome to an exit code.
///
/// On failure the error is sent to the service manager as a status line (with the cause in
/// `Debug` form) and logged via `Display`, then the process exits with `FAILURE`.
#[tokio::main]
async fn main() -> ExitCode {
    match real_main().await {
        Ok(()) => ExitCode::SUCCESS,
        Err(err) => {
            let status = format!("failed at {}: {:?}", err.task, err.inner);
            // Best effort: there may be no service manager listening.
            let _ = notify(false, &[NotifyState::Status(&status)]);
            error!("{}", err);
            ExitCode::FAILURE
        }
    }
}
async fn real_main() -> Result<(), AppError> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL)
.with_writer(std::io::stderr)
.with_max_level(tracing::Level::INFO)
.init();
let opt = Opt::from_args();
let config = tokio::fs::read_to_string(opt.config.clone())
.await
.map_err(wrap_arg("reading configuration file", &opt.config))?;
let config: config::Config = toml::from_str(&config)
.map_err(wrap("parsing configuration"))?;
let config_base = opt.config.parent().unwrap();
let general = &config.general;
let connect_interval = Duration::from_millis(
config
.system
.connect_interval
.unwrap_or(config::default_connect_interval()),
);
let roots = crate::utils::load_root_certs(&config, &config_base)
.map_err(wrap("loading root CA certificates"))?;
let cert = config_base.join(config.tls.cert);
let cert = crate::utils::load_certificates_from_pem(&cert)
.map_err(wrap_arg("loading certificate", &cert))?;
let key = config_base.join(config.tls.key);
let key = crate::utils::load_private_key_from_file(&key)
.map_err(wrap_arg("loading certificate key", &key))?;
let endpoint_config = quinn::EndpointConfig::default();
let transport_config = crate::utils::build_transport_config(
general.mode,
&config.transport,
)
.map_err(wrap("building tranport config"))?;
let transport_config = Arc::new(transport_config);
let client_config = {
let mut tls_config = rustls::ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS13])
.expect("invalid TLS parameters?")
.with_root_certificates(roots.clone())
.with_client_auth_cert(cert.clone(), key.clone())
.map_err(wrap("building client TLS config"))?;
tls_config.enable_early_data = true;
tls_config.alpn_protocols = vec![ALPN.into()];
tls_config.key_log = Arc::new(rustls::KeyLogFile::new());
let mut config = quinn::ClientConfig::new(Arc::new(tls_config));
config.transport_config(transport_config.clone());
config
};
let server_config = {
let cert_verifier = Arc::new(crate::utils::StrictClientCertVerifier {
inner: AllowAnyAuthenticatedClient::new(roots.clone()),
server_name: general.peer_hostname.as_str().try_into()
.map_err(wrap_arg("parsing peer hostname", &general.peer_hostname))?,
});
let mut tls_config = rustls::ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&rustls::version::TLS13])
.expect("invalid TLS parameters?")
.with_client_cert_verifier(cert_verifier)
.with_single_cert(cert.clone(), key.clone())
.map_err(wrap("building server TLS config"))?;
tls_config.max_early_data_size = u32::MAX;
tls_config.alpn_protocols = vec![ALPN.into()];
tls_config.key_log = Arc::new(rustls::KeyLogFile::new());
let mut config = quinn::ServerConfig::with_crypto(Arc::new(tls_config));
config.transport_config(transport_config.clone());
config.migration(true);
config.use_retry(false);
config
};
let client_addr = (
general
.connect_address
.clone()
.unwrap_or(general.peer_hostname.clone()),
general.quic_port,
);
let endpoint_addr = SocketAddr::new(
general.bind_address,
match general.mode {
PeerMode::Client => 0,
PeerMode::Server => general.quic_port,
},
);
'err: {
if general.mode != PeerMode::Client && general.http_bind_address.is_some() {
break 'err Err("http_bind_address can only be present in client mode")
}
if general.mode != PeerMode::Server && general.http_connect_address.is_some() {
break 'err Err("http_connect_address can only be present in server mode")
}
if general.mode == PeerMode::Server && general.http_connect_address.is_none() {
break 'err Err("http_connect_address must be present in server mode")
}
Ok(())
}
.map_err(wrap("validating configuration"))?;
let listener_addr = general
.http_bind_address
.clone()
.unwrap_or(config::default_http_bind_address());
let upstream_url = match general.http_connect_address.as_ref() {
None => None,
Some(addr) => Some(Uri::builder()
.scheme(Scheme::HTTP)
.authority(addr.as_str())
.path_and_query("/") .build()
.map_err(wrap_arg("parsing connect address", addr))?),
};
let tcp_nodelay = config
.system
.tcp_nodelay
.unwrap_or(config::default_tcp_nodelay());
let mut http_connector = hyper::client::connect::HttpConnector::new();
http_connector.set_keepalive(Some(Duration::from_millis(15000)));
http_connector.set_nodelay(tcp_nodelay);
let http_client = hyper::Client::builder()
.http1_title_case_headers(true)
.set_host(false)
.build(http_connector);
let http_server = match general.mode {
PeerMode::Server => None,
PeerMode::Client => Some(Server::try_bind(&listener_addr)
.map_err(wrap_arg("binding listener socket", &listener_addr))?
.tcp_nodelay(tcp_nodelay)
.http1_title_case_headers(true)
.http1_only(true))
};
let socket = std::net::UdpSocket::bind(endpoint_addr)
.map_err(wrap_arg("binding QUIC endpoint socket", &endpoint_addr))?;
crate::utils::configure_endpoint_socket(&socket, &config.transport)
.map_err(wrap("configuring QUIC endpoint socket"))?;
let endpoint = {
let mut endpoint = quinn::Endpoint::new(
endpoint_config,
(general.mode == PeerMode::Server).then_some(server_config),
socket,
Arc::new(quinn::TokioRuntime),
)
.map_err(wrap("creating QUIC endpoint"))?;
if general.mode == PeerMode::Client {
endpoint.set_default_client_config(client_config);
}
endpoint
};
let add_forwarded = config.system.add_forwarded.unwrap_or(config::default_add_forwarded());
let ready_sent = Arc::new(AtomicBool::new(false));
let send_status = move |message: &str, is_ready: bool| {
let states = [
NotifyState::Status(message),
NotifyState::Ready,
];
let add_ready = is_ready && !ready_sent.swap(true, SeqCst);
let states = &states[0..(1 + add_ready as usize)];
let _ = notify(false, &states);
};
let stop_token = CancellationToken::new();
{
let stop_token = stop_token.clone();
let send_status = send_status.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap();
info!("stopping server...");
send_status("stopping server", false);
stop_token.cancel();
let _ = notify(false, &[NotifyState::Stopping]);
tokio::signal::ctrl_c().await.unwrap();
std::process::exit(130);
});
}
let watchdog_interval = {
let watchdog_factor = config.system.watchdog_factor.unwrap_or(config::default_watchdog_factor());
let watchdog_usec = {
let mut result = 0;
sd_notify::watchdog_enabled(true, &mut result).then_some(result)
};
watchdog_usec.map(|x| Duration::from_secs_f32((x as f32) / (watchdog_factor * 1e6)))
};
let watchdog_loop = async {
if let Some(watchdog_interval) = watchdog_interval {
loop {
let _ = notify(false, &[NotifyState::Watchdog]);
sleep(watchdog_interval).await;
}
}
};
let wait_for_first_attempt = config
.system
.wait_for_first_attempt
.unwrap_or(config::default_wait_for_first_attempt());
struct EstablishedConnection {
send_request: Option<h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>>,
}
struct ConnectionGuard<'a>(&'a Mutex<Option<EstablishedConnection>>);
impl<'a> Drop for ConnectionGuard<'a> {
fn drop(&mut self) {
*self.0.lock().unwrap() = None;
}
}
let current_connection = Arc::new(Mutex::new(None::<EstablishedConnection>));
let establish_client_connection = || async {
let addr = lookup_host(&client_addr)
.await?
.choose(&mut thread_rng())
.ok_or("resolution found no addresses")?;
let quinn_connection = endpoint.connect(addr, &general.peer_hostname)?.await?;
info!("connection {} established", quinn_connection.stable_id());
let h3_connection = h3_quinn::Connection::new(quinn_connection);
let connection = h3::client::new(h3_connection).await?;
Ok::<_, Box<dyn Error>>(connection)
};
let client_iteration = || async {
let (mut driver, send_request) = select! {
() = stop_token.cancelled() => return Ok(()),
value = establish_client_connection() => value,
}?;
let state_guard = ConnectionGuard(¤t_connection);
*state_guard.0.lock().unwrap() = Some(EstablishedConnection {
send_request: Some(send_request),
});
info!("tunnel established");
send_status("tunnel established", true);
let mut have_closed = false;
with_background(driver.wait_idle(), async {
stop_token.cancelled().await;
if !have_closed {
have_closed = true;
state_guard.0.lock().unwrap().as_mut().unwrap().send_request = None;
}
})
.await?;
if !have_closed {
Err("server closed the connection")?
}
Ok::<(), Box<dyn Error>>(())
};
let client_loop = || async {
send_status("attempting first connection", !wait_for_first_attempt);
while let Err(error) = client_iteration().await {
error!("client connection failed: {}", error);
send_status(&format!("client connection failed: {}", error), true);
sleep(connect_interval).await;
}
Ok::<(), Box<dyn Error>>(())
};
let handle_request_client = {
let current_connection = current_connection.clone();
move |fwd: Option<HeaderValue>, mut request: Request<Body>| {
let current_connection = current_connection.clone();
async move {
if request.method() == Method::CONNECT {
return Ok(Response::builder()
.header(header::SERVER, "ptproxy client")
.header(header::CONTENT_TYPE, "text/plain;charset=UTF-8")
.status(StatusCode::METHOD_NOT_ALLOWED)
.body("CONNECT requests not implemented yet\n".into())
.unwrap());
}
let chunked = match is_chunked_message(request.headers()) {
Some(x) => x,
None => return Ok(Response::builder()
.header(header::SERVER, "ptproxy client")
.header(header::CONTENT_TYPE, "text/plain;charset=UTF-8")
.status(StatusCode::BAD_REQUEST)
.body("invalid Transfer-Encoding value: only chunked transfer coding supported\n".into())
.unwrap()),
};
if chunked {
request.headers_mut().remove(header::CONTENT_LENGTH);
}
remove_hop_by_hop_headers(request.headers_mut());
let send_request = {
let current_connection = current_connection.lock().unwrap();
current_connection.as_ref().and_then(|s| s.send_request.clone())
};
let mut send_request = match send_request {
Some(value) => value,
None => return Ok(Response::builder()
.header(header::SERVER, "ptproxy client")
.header(header::CONTENT_TYPE, "text/plain;charset=UTF-8")
.status(StatusCode::SERVICE_UNAVAILABLE)
.body("tunnel not established\n".into())
.unwrap()),
};
if let Some(fwd) = fwd {
request.headers_mut().append(header::FORWARDED, fwd);
}
*request.version_mut() = http::Version::HTTP_3;
let (mut body, request) = {
let (parts, body) = request.into_parts();
(body, Request::from_parts(parts, ()))
};
let mut stream = match send_request.send_request(request).await {
Ok(value) => value,
Err(err) => return Ok(Response::builder()
.header(header::SERVER, "ptproxy client")
.header(header::CONTENT_TYPE, "text/plain;charset=UTF-8")
.status(StatusCode::BAD_GATEWAY)
.body(format!("error sending request:\n{}\n", err).into())
.unwrap()),
};
let proxy_request_body = async {
while let Some(buf) = body.data().await {
let buf = buf.map_err(|err| format!("when receiving data: {}", err))?;
stream.send_data(buf).await.map_err(|err| format!("when sending data: {}", err))?;
}
stream.finish().await.map_err(|err| format!("when finishing stream: {}", err))?;
Ok::<_, Box<dyn Error>>(())
};
if let Err(err) = proxy_request_body.await {
return Ok(Response::builder()
.header(header::SERVER, "ptproxy client")
.header(header::CONTENT_TYPE, "text/plain;charset=UTF-8")
.status(StatusCode::BAD_GATEWAY)
.body(format!("error when streaming request body:\n{}\n", err).into())
.unwrap())
}
let mut response = match stream.recv_response().await {
Ok(value) => value,
Err(err) => return Ok(Response::builder()
.header(header::SERVER, "ptproxy client")
.header(header::CONTENT_TYPE, "text/plain;charset=UTF-8")
.status(StatusCode::BAD_GATEWAY)
.body(format!("error when receiving response:\n{}\n", err).into())
.unwrap())
};
*response.version_mut() = http::Version::HTTP_11;
let (mut sender, response) = {
let (sender, body) = Body::channel();
(sender, Response::from_parts(response.into_parts().0, body))
};
tokio::spawn(async move {
let result: Result<(), Box<dyn Error + Send + Sync>> = try {
while let Some(mut buf) = {
stream.recv_data().await.map_err(|err| format!("when receiving data: {}", err))?
} {
sender.send_data(buf.copy_to_bytes(buf.remaining())).await.map_err(|err| format!("when sending data: {}", err))?;
}
};
if let Err(err) = result {
sender.abort();
error!("error when proxying response: {}", err);
}
});
Ok::<_, Infallible>(response)
}
}
};
let listener_loop = || async {
let make_svc = make_service_fn(move |conn: &AddrStream| {
let fwd: HeaderValue = format!(
"for={:?};by={:?};proto=http",
conn.remote_addr().to_string(),
conn.local_addr().to_string()
).try_into().unwrap();
let fwd = add_forwarded.then_some(fwd);
let handle_request_client = handle_request_client.clone();
let handle_request_client = move |req| handle_request_client(fwd.clone(), req);
async move {
Ok::<_, Infallible>(service_fn(handle_request_client))
}
});
http_server
.unwrap()
.serve(make_svc)
.with_graceful_shutdown(stop_token.cancelled())
.await?;
Ok::<(), Box<dyn Error>>(())
};
let handle_request_server = |fwd: Option<HeaderValue>, (mut request, mut stream): (Request<()>, h3::server::RequestStream<_, _>)| {
let upstream_url = upstream_url.clone().unwrap();
let http_client = http_client.clone();
async move {
if request.uri().scheme() != Some(&Scheme::HTTPS) {
}
if !request.headers().contains_key(header::HOST) {
let authority = match request.uri().authority() {
Some(value) => value,
None => unreachable!(), };
let value = authority.as_str().try_into().unwrap();
request.headers_mut().append(header::HOST, value);
}
*request.uri_mut() = {
let mut parts = upstream_url.into_parts();
parts.path_and_query = Some(request.uri().path_and_query().unwrap().clone());
Uri::from_parts(parts).unwrap()
};
if let Some(fwd) = fwd {
request.headers_mut().append(header::FORWARDED, fwd);
}
*request.version_mut() = http::Version::HTTP_11;
let (mut sender, request) = {
let (sender, body) = Body::channel();
(sender, Request::from_parts(request.into_parts().0, body))
};
let response: Result<_, Box<dyn Error + Sync + Send>> = try_join!(
async {
Ok(http_client.request(request).await.map_err(|err| format!("when making request: {}", err))?)
},
async {
let result = try {
while let Some(mut buf) = {
stream.recv_data().await.map_err(|err| format!("when receiving data: {}", err))?
} {
sender.send_data(buf.copy_to_bytes(buf.remaining())).await.map_err(|err| format!("when sending data: {}", err))?;
}
};
if let Err(_) = result {
sender.abort();
}
result
}
);
let mut response = match response {
Ok((value, ())) => value,
Err(err) => {
let body: Bytes = format!("could not proxy request:\n{}\n", err).into();
stream.send_response(Response::builder()
.header(header::SERVER, "ptproxy server")
.header(header::CONTENT_TYPE, "text/plain;charset=UTF-8")
.header(header::CONTENT_LENGTH, body.len())
.status(StatusCode::BAD_GATEWAY)
.body(())
.unwrap()).await?;
stream.send_data(body).await?;
stream.stop_sending(Code::H3_NO_ERROR);
stream.finish().await?;
return Ok(());
},
};
let chunked = match is_chunked_message(response.headers()) {
Some(x) => x,
None => unreachable!(), };
if chunked {
response.headers_mut().remove(header::CONTENT_LENGTH);
}
remove_hop_by_hop_headers(response.headers_mut());
*response.version_mut() = http::Version::HTTP_3;
let (mut body, response) = {
let (parts, body) = response.into_parts();
(body, Response::from_parts(parts, ()))
};
if let Err(err) = stream.send_response(response).await {
error!("error sending response: {}", err);
stream.stop_stream(Code::H3_INTERNAL_ERROR);
return Ok(());
}
let proxy_response_body = async {
while let Some(buf) = body.data().await {
let buf = buf.map_err(|err| format!("when receiving data: {}", err))?;
stream.send_data(buf).await.map_err(|err| format!("when sending data: {}", err))?;
}
stream.finish().await.map_err(|err| format!("when finishing stream: {}", err))?;
Ok::<_, Box<dyn Error>>(())
};
if let Err(err) = proxy_response_body.await {
error!("error sending response body: {}", err);
stream.stop_stream(Code::H3_INTERNAL_ERROR);
return Ok(());
}
Ok::<_, Box<dyn Error + Sync + Send>>(())
}
};
let handle_established_connection_server = |quinn_connection: quinn::Connection| {
let stop_token = &stop_token;
let endpoint = &endpoint;
async move {
let h3_connection = h3_quinn::Connection::new(quinn_connection.clone());
let mut connection = select! {
() = stop_token.cancelled() => return Ok(()),
value = h3::server::Connection::<_, Bytes>::new(h3_connection) => value,
}?;
let fwd: HeaderValue = format!(
"for={:?};by={:?};proto=http3",
quinn_connection.remote_address().to_string(),
endpoint.local_addr()?.to_string()
).try_into().unwrap();
let fwd = add_forwarded.then_some(fwd);
let result = loop {
let request = select! {
() = stop_token.cancelled() => break Ok(()),
value = connection.accept() => value,
};
let request = match request {
Ok(Some(value)) => value,
Ok(None) => break Err("client closed the connection".into()),
Err(err) => {
if err.get_error_level() == ErrorLevel::StreamError {
error!(
"connection {} failed accepting: {}",
quinn_connection.stable_id(),
err
);
continue;
}
break Err(err.into());
}
};
tokio::spawn(handle_request_server(fwd.clone(), request));
};
let result = result.and(connection.shutdown(100).await.map_err(|err| err.into())); id::<Result<_, Box<dyn Error>>>(result)
}
};
let handle_connection_server = |connection| async {
let connection = select! {
() = stop_token.cancelled() => return,
value = connection => value,
};
let connection: quinn::Connection = match connection {
Ok(value) => value,
Err(_) => return,
};
info!("connection {} established ({})", connection.stable_id(), connection.remote_address());
let result = handle_established_connection_server(connection.clone()).await;
if let Err(error) = result {
error!("connection {} failed: {}", connection.stable_id(), error);
}
};
let server_loop = || async {
let mut connections = FuturesUnordered::new();
send_status("accepting connections", true);
loop {
let accept_future = with_background(
cancellable(endpoint.accept(), &stop_token),
drain_stream(&mut connections),
);
let connecting = match accept_future.await {
None => break, Some(new_conn) => new_conn.unwrap(),
};
connections.push(handle_connection_server(connecting));
}
send_status("waiting for outstanding connections to close", false);
drain_stream(&mut connections).await;
};
let main_loop = async {
info!("started endpoint at {}", endpoint.local_addr()?);
match general.mode {
PeerMode::Client => {
try_join!(listener_loop(), client_loop())?;
}
PeerMode::Server => {
server_loop().await;
}
}
endpoint.close(Code::H3_NO_ERROR.value().try_into().unwrap(), &[]);
info!("waiting for endpoint to finish...");
send_status("waiting for endpoint to finish", false);
endpoint.wait_idle().await;
Ok::<_, Box<dyn Error>>(())
};
with_background(main_loop, watchdog_loop)
.await
.map_err(wrap("main loop"))
}
/// `Keep-Alive` header name, listed as hop-by-hop in `remove_hop_by_hop_headers`.
///
/// Declared `static` (not `const`) because that function's own `static` table stores a
/// reference to it.
static HEADER_KEEP_ALIVE: HeaderName = HeaderName::from_static("keep-alive");
/// `Proxy-Connection` header name, listed as hop-by-hop in `remove_hop_by_hop_headers`.
static HEADER_PROXY_CONNECTION: HeaderName = HeaderName::from_static("proxy-connection");
/// Splits a header value on `,` and trims ASCII whitespace around each element.
///
/// Empty elements (from `a,,b` or a trailing comma) are yielded as empty slices; callers
/// are expected to filter them out.
fn split_comma(value: &HeaderValue) -> impl Iterator<Item = &[u8]> {
    let bytes = value.as_bytes();
    bytes
        .split(|&byte| byte == b',')
        .map(|element| element.trim_ascii())
}
/// Yields every comma-separated element of every `header` field in `headers`, trimmed of
/// ASCII whitespace (see `split_comma`). Multiple fields with the same name are treated as
/// one concatenated list, in their stored order.
fn split_simple_header(headers: &HeaderMap<HeaderValue>, header: HeaderName) -> impl Iterator<Item = &[u8]> {
    headers.get_all(header).iter().flat_map(split_comma)
}
/// Strips hop-by-hop header fields before a message is relayed to the other side.
///
/// Removes every field named in a `Connection` header (other than the `close` option, and
/// skipping tokens that are not valid header names), then the fixed set of hop-by-hop
/// fields listed below.
fn remove_hop_by_hop_headers(headers: &mut HeaderMap<HeaderValue>) {
    static KNOWN_HOP_BY_HOP_HEADERS: &[&HeaderName] = &[
        &header::CONNECTION,
        &header::TE,
        &header::TRANSFER_ENCODING,
        &header::TRAILER,
        &HEADER_KEEP_ALIVE,
        &header::UPGRADE,
        &HEADER_PROXY_CONNECTION,
        &header::PROXY_AUTHENTICATE,
        &header::PROXY_AUTHORIZATION,
    ];
    // Collected up front: the `Connection` field is removed in the loop below, so the
    // borrow of `headers` must end before any mutation.
    let mut listed = Vec::<HeaderName>::new();
    for token in split_simple_header(headers, header::CONNECTION) {
        if let Ok(name) = HeaderName::from_bytes(token) {
            if name != "close" {
                listed.push(name);
            }
        }
    }
    let known = KNOWN_HOP_BY_HOP_HEADERS.iter().copied();
    for name in listed.iter().chain(known) {
        headers.remove(name);
    }
}
/// Classifies a message's `Transfer-Encoding` for relaying.
///
/// Returns:
/// * `Some(false)`: there is no `Transfer-Encoding` field;
/// * `Some(true)`: there is exactly one field and its value is `chunked` (ASCII
///   case-insensitive);
/// * `None`: anything else (several fields, or another coding such as `gzip, chunked`),
///   which this proxy does not support.
fn is_chunked_message(headers: &HeaderMap<HeaderValue>) -> Option<bool> {
    // Looking at the first two fields is enough to tell "none", "exactly one" and "several"
    // apart. This avoids collecting into a Vec and lowercasing a copy of the value.
    let mut values = headers.get_all(header::TRANSFER_ENCODING).iter();
    match (values.next(), values.next()) {
        (None, _) => Some(false),
        (Some(value), None) if value.as_bytes().eq_ignore_ascii_case(b"chunked") => Some(true),
        _ => None,
    }
}