diff --git a/src/main.rs b/src/main.rs index ad9cdbc..f3e937c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,5 @@ -use std::{net::SocketAddr, convert::Infallible}; - -use hyper::{Client, Server, service::{make_service_fn, service_fn}, body::HttpBody, Body, Response, Request, Method, http, upgrade::{self, Upgraded}}; -use tokio::net::TcpStream; - -// https://github.com/hyperium/hyper/blob/master/examples/gateway.rs -// https://en.wikipedia.org/wiki/Gateway_(telecommunications) +use std::net::SocketAddr; +use hyper::{service::{make_service_fn, service_fn}, Client, Error, Server, Response, StatusCode, Body, Request, Uri}; #[tokio::main(flavor = "current_thread")] async fn main() { @@ -13,12 +8,12 @@ async fn main() { let config = { // Prefer ./config over /etc/proxima - let file = unsafe { + let file = { ["./config", "/etc/proxima"] .iter() .map(load) .reduce(Result::or) - .unwrap_unchecked() + .expect("Impossible") }; file.map(parse) @@ -27,31 +22,135 @@ async fn main() { }; let addr = SocketAddr::from(([127, 0, 0, 1], 8100)); + + let client = Client::new(); + + let make_service = make_service_fn(move |_| { + let client = client.clone(); + let config = config.clone(); + + async move { + // This is the `Service` that will handle the connection. + // `service_fn` is a helper to convert a function that + // returns a Response into a `Service`. + Ok::<_, Error>(service_fn(move |mut req| { + let config = config.clone(); + let client = client.clone(); + + async move { + for Rule (pattern, effect) in config.rules() { + println!("{} {}", req.method(), req.uri()); + if pattern.matches(&req) { + return match effect { + Effect::Proxy { port, .. } => { + let host = "0.0.0.0"; // Support for custom hosts added later + let path = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or(""); + let target = format!("http://{host}:{port}{path}"); + + let uri = target.parse().unwrap(); + *req.uri_mut() = uri; + + println!("Proxying to {target}"); + + client.request(req).await + }, + Effect::Redirect (uri) => Ok ({ + println!("Redirecting to {uri}"); + Response::builder() + .status(StatusCode::PERMANENT_REDIRECT) + .header("Location", uri) + .body(Body::empty()) + .unwrap() + }), + } + } + } + + Ok (Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::empty()) + .unwrap()) + } + })) + } + }); + + let server = Server::bind(&addr).serve(make_service); + + if let Err(e) = server.await { + eprintln!("server error: {}", e); + } + } +#[derive(Clone, Debug)] pub struct Rule (Pattern, Effect); impl Rule { /// Get the domain of the pattern. - pub fn domain (&self) -> &str { - &self.0.domain - } + pub fn domain (&self) -> &str { + &self.0.domain + } - /// Get the portspec - pub fn ports (&self) -> &Ports { - &self.0.ports - } + /// Get the portspec + pub fn ports (&self) -> &Ports { + &self.0.ports + } - pub fn effect (&self) -> &Effect { - &self.1 - } + pub fn effect (&self) -> &Effect { + &self.1 + } } +#[derive(Clone, Debug)] pub struct Pattern { domain: String, ports: Ports, } +impl Pattern { + pub fn matches (&self, req: &Request) -> bool { + let uri = req.uri(); + let (host, port) = { + let host = req + .headers() + .get("host") + .and_then(|x| x.to_str().ok()) + .and_then(|x| x.parse::().ok()); + + let h = uri + .host() + .map(|x| x.to_string()) + .or_else(|| { + host.clone().and_then(|x| { + x.host().map(|x| x.to_string()) + }) + }); + + let p = uri + .port_u16() + .or_else(|| { + host.and_then(|x| x.port_u16()) + }); + + (h, p) + }; + + match host { + Some (h) if &h == &self.domain => match &self.ports { + Ports::Any => true, + spec => match port { + Some (p) => spec.includes(p), + None => false, + } + }, + _ => false, + } + + } +} + +#[derive(Clone, Debug)] pub enum Effect { Redirect (String), Proxy { @@ -61,11 +160,12 @@ pub enum Effect { } impl Effect { - pub async fn perform (&self) -> std::io::Result<()> { - todo!() - } + pub async fn perform (&self) -> Response { + todo!() + } } +#[derive(Clone, Debug)] pub enum Ports { Single (u16), Either (Vec), @@ -73,16 +173,55 @@ pub enum Ports { } impl Ports { - /// Whether this set of ports includes the given port. + /// Whether this set of ports includes the given port. pub fn includes (&self, p: u16) -> bool { - match self { - Ports::Single (x) => *x == p, - Ports::Either (l) => l.contains(&p), - Ports::Any => true, - } - } + match self { + Ports::Single (x) => *x == p, + Ports::Either (l) => l.contains(&p), + Ports::Any => true, + } + } } +/// A config consists of a set if [`Rule`]. +#[derive(Clone, Debug)] +pub struct Config (Vec); + +impl Config { + pub fn rules (&self) -> impl Iterator { + self.0.iter() + } +} + +/// Load a config from a path. +pub fn load (p: impl AsRef) -> std::io::Result { + std::fs::read_to_string(p.as_ref()) +} + +/// Parse a config string. +/// +/// Example config string: +/// +/// ```text +/// hmt.riley.lgbt : (80 | 443) --> 6000 # --> is proxy_pass +/// riley.lgbt : (80 | 443) --> 3000 [ssl] # add [ssl] to automate ssl for this domain +/// rly.cx : any ==> riley.lgbt # ==> is HTTP redirect +/// ``` +pub fn parse (data: String) -> Config { + let rules = data + .lines() + .map(parse::rule) + .filter_map(|x| match x { + Ok ((_, rule)) => Some (rule), + Err (e) => { + eprintln!("Error parsing rule: {:?}", e); + None + }, + }) + .collect(); + + Config (rules) +} pub mod parse { use super::{ Ports, Effect, Pattern, Rule }; @@ -126,7 +265,7 @@ pub mod parse { .parse(s) } - /// Parse an [`Effect`]. + /// Parse an [`Effect`]. pub fn effect (s: &str) -> PResult<'_, Effect> { let redirect = { @@ -165,24 +304,24 @@ pub mod parse { .parse(s) } - /// Parse a [`Pattern`]. - /// - /// ``` - /// use proxima::parse; - /// - /// # fn main () -> parse::PResult<'static, ()> { - /// let (_, pattern) = parse::pattern("example.com : any")?; - /// # Ok ("", ()) - /// # } - /// ``` + /// Parse a [`Pattern`]. + /// + /// ``` + /// use proxima::parse; + /// + /// # fn main () -> parse::PResult<'static, ()> { + /// let (_, pattern) = parse::pattern("example.com : any")?; + /// # Ok ("", ()) + /// # } + /// ``` pub fn pattern (s: &str) -> PResult<'_, Pattern> { - let spaced = |x| seq::delimited(chr::char(' '), x, chr::char(' ')); - seq::separated_pair(domain, spaced(chr::char(':')), portspec) - .map(|(domain, ports)| Pattern { domain, ports }) - .parse(s) + let spaced = |x| seq::delimited(chr::space1, x, chr::space1); + seq::separated_pair(domain, spaced(chr::char(':')), portspec) + .map(|(domain, ports)| Pattern { domain, ports }) + .parse(s) } - /// Parse a [`Rule`]. + /// Parse a [`Rule`]. pub fn rule (s: &str) -> PResult<'_, Rule> { pattern.and(effect) .map(|(p, e)| Rule (p, e)) @@ -194,82 +333,50 @@ pub mod parse { use super::*; - /// Test whether a pattern containing an Any portspec gets parsed - /// correctly. - #[test] - fn simple_pattern () { - let input = "example.com : any"; - let (_, Pattern { domain, ports }) = pattern(input).unwrap(); - assert!(domain == "example.com"); - assert!(match ports { - Ports::Any => true, - _ => false, - }) - } + /// Test whether a pattern containing an Any portspec gets parsed + /// correctly. + #[test] + fn simple_pattern () { + let input = "example.com : any"; + let (_, Pattern { domain, ports }) = pattern(input).unwrap(); + assert!(domain == "example.com"); + assert!(match ports { + Ports::Any => true, + _ => false, + }) + } - /// Test whether an Either portspec is parsed correctly. - #[test] - fn either_portspec () { - let input = "(69 | 420)"; - assert!(match portspec(input).unwrap() { - ("", Ports::Either(p)) => p == [69, 420], - _ => false, - }) - } + /// Test whether an Either portspec is parsed correctly. + #[test] + fn either_portspec () { + let input = "(69 | 420)"; + assert!(match portspec(input).unwrap() { + ("", Ports::Either(p)) => p == [69, 420], + _ => false, + }) + } - /// Test whether domain names are parsed correctly. - #[test] - fn domains () { - let inputs = ["example.com", "im.badat.dev", "riley.lgbt", "toot.site", "a.b.c.d.e.f.g.h"]; - // Each of these should be considered valid - for input in inputs { - domain(input).unwrap(); - } - } + /// Test whether domain names are parsed correctly. + #[test] + fn domains () { + let inputs = ["example.com", "im.badat.dev", "riley.lgbt", "toot.site", "a.b.c.d.e.f.g.h"]; + // Each of these should be considered valid + for input in inputs { + domain(input).unwrap(); + } + } - /// Test whether a simple rule gets parsed correctly. - #[test] - fn simple_rule () { - let input = "example.gay : any --> 3000"; - let (_, Rule (_, effect)) = rule(input).unwrap(); - assert!(match effect { - Effect::Proxy { port: 3000, ssl } if !ssl => true, - _ => false, - }); - } + /// Test whether a simple rule gets parsed correctly. + #[test] + fn simple_rule () { + let input = "example.gay : any --> 3000"; + let (_, Rule (_, effect)) = rule(input).unwrap(); + assert!(match effect { + Effect::Proxy { port: 3000, ssl } if !ssl => true, + _ => false, + }); + } } } -/// A config consists of a set if [`Rule`]. -pub struct Config (Vec); - -/// Load a config from a path. -pub fn load (p: impl AsRef) -> std::io::Result { - std::fs::read_to_string(p.as_ref()) -} - -/// Parse a config string. -/// -/// Example config string: -/// -/// ```text -/// hmt.riley.lgbt : (80 | 443) --> 6000 # --> is proxy_pass -/// riley.lgbt : (80 | 443) --> 3000 [ssl] # add [ssl] to automate ssl for this domain -/// rly.cx : any ==> riley.lgbt # ==> is HTTP redirect -/// ``` -pub fn parse (data: String) -> Config { - let rules = data - .lines() - .map(parse::rule) - .filter_map(|x| match x { - Ok ((_, rule)) => Some (rule), - Err (e) => { - eprintln!("Error parsing rule: {:?}", e); - None - }, - }) - .collect(); - - Config (rules) -}