diff --git a/Cargo.toml b/Cargo.toml index 11e23e4..47d51bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,10 @@ license = "AGPL-3.0+" [dependencies] clap = { version = "4.5", features = [ "cargo" ] } gitlab = "0.1610" -hyper = { version = "0.14", features = [ "full" ] } +hyper = { version = "1.4", features = [ "full" ] } +hyper-util = { version = "0.1", features = [ "tokio" ] } +http-body-util = "0.1" +bytes = "1.2" jid = { version = "*", features = [ "serde" ] } log = "0.4" tokio = { version = "1", features = [ "full" ] } diff --git a/src/main.rs b/src/main.rs index 964be89..07813dd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,7 +26,6 @@ use crate::error::Error; use crate::web::webhooks; use crate::webhook::WebHook; -use std::convert::Infallible; use std::fs::File; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read}; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; @@ -34,13 +33,11 @@ use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; use clap::{command, value_parser, Arg}; -use hyper::{ - service::{make_service_fn, service_fn}, - Server, -}; +use hyper::{server::conn::http1, service::service_fn}; +use hyper_util::rt::tokio::{TokioIo, TokioTimer}; use log::debug; use serde::{Deserialize, Serialize}; -use tokio::sync::mpsc; +use tokio::{net::TcpListener, sync::mpsc}; use xmpp::BareJid; #[derive(Debug, Serialize, Deserialize)] @@ -135,25 +132,6 @@ async fn main() -> Result { let (value_tx, mut value_rx) = mpsc::unbounded_channel::(); - if let Some(token) = config.webhook_token { - let value_tx = Arc::new(Mutex::new(value_tx)); - let make_svc = make_service_fn(move |_conn| { - let value_tx = value_tx.clone(); - let token = token.clone(); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - let value_tx = value_tx.clone(); - let token = token.clone(); - webhooks(req, token, value_tx) - })) - } - }); - let server = Server::bind(&config.addr).serve(make_svc); - println!("Listening on http://{}", &config.addr); - - let _join = tokio::spawn(server); - } - let mut client = XmppClient::new( config.jid, config.password.as_str(), @@ -161,9 +139,35 @@ async fn main() -> Result { config.nickname, ); + let tcp_bind = TcpListener::bind(config.addr).await?; + let token: Option<&'static String> = + unsafe { core::mem::transmute(config.webhook_token.as_ref()) }; + let value_tx = Arc::new(Mutex::new(value_tx)); + loop { + let value_tx = value_tx.clone(); + tokio::select! { _ = client.next() => (), + accept = tcp_bind.accept() => { + if let Ok((tcp, _)) = accept { + let io = TokioIo::new(tcp); + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new() + .timer(TokioTimer::new()) + .serve_connection(io, service_fn(|request| { + let value_tx = value_tx.clone(); + async move { + webhooks(request, token, value_tx).await + } + })) + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } + } wh = value_rx.recv() => { if let Some(wh) = wh { client.webhook(wh).await diff --git a/src/web.rs b/src/web.rs index 5cb843b..1714a78 100644 --- a/src/web.rs +++ b/src/web.rs @@ -17,30 +17,36 @@ use crate::error::Error; use crate::webhook::WebHook; use std::convert::Infallible; -use std::str::from_utf8; use std::sync::{Arc, Mutex}; -use hyper::{body, header, Body, Method, Request, Response}; +use bytes::{Buf, Bytes}; +use http_body_util::{BodyExt, Full}; +use hyper::{body::Incoming, header, Method, Request, Response}; use log::{debug, error}; use tokio::sync::mpsc::UnboundedSender; -fn error_res(e: E) -> Result, Infallible> { +fn error_res(e: E) -> Result>, Infallible> { error!("error response: {:?}", e); let text = format!("{:?}", e); let res = Response::builder() .status(200) - .body(Body::from(Vec::from(text.as_bytes()))) + .body(Full::new(Bytes::from(text))) .unwrap(); Ok(res) } -async fn webhooks_inner(req: Request, token: &str) -> Result { +async fn webhooks_inner(req: Request, token: Option<&String>) -> Result { match req.method() { &Method::POST => (), _ => return Err(Error::MethodMismatch), } + if token.is_none() { + return Err(Error::InvalidToken); + } + let token: &str = token.unwrap(); + debug!("Headers: {:?}", req.headers()); let headers = req.headers(); @@ -56,23 +62,22 @@ async fn webhooks_inner(req: Request, token: &str) -> Result, - token: String, + req: Request, + token: Option<&String>, value_tx: Arc>>, -) -> Result, Infallible> { - match webhooks_inner(req, token.as_ref()).await { +) -> Result>, Infallible> { + match webhooks_inner(req, token).await { Ok(wh) => { debug!("Passed: {:?}", wh); value_tx.lock().unwrap().send(wh).unwrap(); - Ok(Response::new("Hello world".into())) + Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) } Err(err) => error_res(err), }