diff --git a/Cargo.toml b/Cargo.toml
index 8fad34d..07990d4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,9 +10,12 @@ license = "AGPL-3.0+"
clap = { version = "4.3", features = [ "cargo", "derive" ] }
gitlab = "0.1511.0"
hyper = { version = "0.14", features = [ "full" ] }
+jid = { version = "*", features = [ "serde" ] }
log = "0.4"
tokio = { version = "1", features = [ "full" ] }
pretty_env_logger = "0.5"
+serde = { version = "1.0", features = [ "derive" ] }
serde_json = "1.0"
+toml = "0.7"
xmpp = "0.4"
xmpp-parsers = "0.19"
diff --git a/src/error.rs b/src/error.rs
index 1ec5813..b7ce3c0 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -13,7 +13,9 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
+use std::env::VarError;
use std::error::Error as StdError;
+use std::io::Error as IoError;
use std::str::Utf8Error;
#[derive(Debug)]
@@ -22,8 +24,11 @@ pub(crate) enum Error {
InvalidToken,
InvalidContentType,
Hyper(hyper::Error),
+ Io(IoError),
SerdeJson(serde_json::Error),
+ Toml(toml::de::Error),
Utf8(Utf8Error),
+ Var(VarError),
}
impl StdError for Error {}
@@ -35,8 +40,11 @@ impl std::fmt::Display for Error {
Error::InvalidToken => write!(fmt, "the token is invalid"),
Error::InvalidContentType => write!(fmt, "the content-type is invalid"),
Error::Hyper(e) => write!(fmt, "hyper error: {}", e),
+ Error::Io(e) => write!(fmt, "Io error: {}", e),
Error::SerdeJson(e) => write!(fmt, "serde_json error: {}", e),
+ Error::Toml(e) => write!(fmt, "toml deserialization error: {}", e),
Error::Utf8(e) => write!(fmt, "Utf8 error: {}", e),
+ Error::Var(e) => write!(fmt, "Var error: {}", e),
}
}
}
@@ -47,14 +55,32 @@ impl From for Error {
}
}
+impl From for Error {
+ fn from(err: IoError) -> Error {
+ Error::Io(err)
+ }
+}
+
impl From for Error {
fn from(err: serde_json::Error) -> Error {
Error::SerdeJson(err)
}
}
+impl From for Error {
+ fn from(err: toml::de::Error) -> Error {
+ Error::Toml(err)
+ }
+}
+
impl From for Error {
fn from(err: Utf8Error) -> Error {
Error::Utf8(err)
}
}
+
+impl From for Error {
+ fn from(err: VarError) -> Error {
+ Error::Var(err)
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 2cef85f..2c8680b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -14,18 +14,23 @@
// along with this program. If not, see .
#![feature(let_chains)]
+#![feature(never_type)]
mod error;
mod web;
mod webhook;
mod xmpp;
+use crate::error::Error;
use crate::web::webhooks;
use crate::webhook::WebHook;
use crate::xmpp::XmppClient;
use std::convert::Infallible;
-use std::net::SocketAddr;
+use std::fs::File;
+use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read};
+use std::net::{IpAddr, Ipv6Addr, SocketAddr};
+use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use clap::Parser;
@@ -33,45 +38,100 @@ use hyper::{
service::{make_service_fn, service_fn},
Server,
};
+use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use xmpp_parsers::BareJid;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
- /// Account address
+ /// Config file path
#[arg(short, long)]
+ config: Option,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Config {
+ /// Account address
jid: BareJid,
/// Account password
- #[arg(short, long)]
password: String,
/// Rooms to join, e.g., room@chat.example.org
- #[arg(short, long = "room", value_name = "ROOM")]
+ #[serde(default = "Vec::new")]
rooms: Vec,
/// Nickname to use in rooms
- #[arg(short, long, default_value = "bot")]
+ #[serde(default = "default_nickname")]
nickname: String,
/// Token to match the one provided by the Webhook service
- #[arg(short, long)]
+ #[serde(rename = "webhook-token")]
webhook_token: Option,
/// HTTP Webhook listening address and port, e.g., 127.0.0.1:1234 or [::1]:1234
- #[arg(long, default_value = "[::1]:3000")]
+ #[serde(default = "default_addr")]
addr: SocketAddr,
}
+fn default_nickname() -> String {
+ String::from("cusku")
+}
+
+fn default_addr() -> SocketAddr {
+ SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 3000)
+}
+
+fn config_from_file(file: PathBuf) -> Result {
+ if file.try_exists().is_err() {
+ let err = IoError::new(IoErrorKind::NotFound, format!("{:?} not found", file));
+ return Err(Error::Io(err));
+ }
+
+ let mut buf = String::new();
+ let mut f = File::open(file)?;
+ f.read_to_string(&mut buf)?;
+
+ Ok(toml::from_str(&buf)?)
+}
+
#[tokio::main]
-async fn main() {
+async fn main() -> Result {
pretty_env_logger::init();
let args = Args::parse();
+ let config = {
+ let path = match args.config {
+ Some(path) => {
+ if !path.starts_with("/") {
+ std::env::current_dir()?.join(path)
+ } else {
+ path
+ }
+ }
+ None => {
+ let confdir: PathBuf = match std::env::var("XDG_CONFIG_HOME") {
+ Ok(ref dir) => Path::new(dir).to_path_buf(),
+ Err(_) => {
+ let home = std::env::var("HOME")?;
+ Path::new(home.as_str()).join(".config")
+ }
+ };
+
+ confdir.join("cusku/config.toml")
+ }
+ };
+
+ match config_from_file(path) {
+ Ok(config) => config,
+ Err(err) => return Err(err),
+ }
+ };
+
let (value_tx, mut value_rx) = mpsc::unbounded_channel::();
- if let Some(token) = args.webhook_token {
+ if let Some(token) = config.webhook_token {
let value_tx = Arc::new(Mutex::new(value_tx));
let make_svc = make_service_fn(move |_conn| {
let value_tx = value_tx.clone();
@@ -84,17 +144,17 @@ async fn main() {
}))
}
});
- let server = Server::bind(&args.addr).serve(make_svc);
- println!("Listening on http://{}", &args.addr);
+ let server = Server::bind(&config.addr).serve(make_svc);
+ println!("Listening on http://{}", &config.addr);
let _join = tokio::spawn(server);
}
let mut client = XmppClient::new(
- &String::from(args.jid),
- args.password.as_str(),
- args.rooms,
- args.nickname,
+ &String::from(config.jid),
+ config.password.as_str(),
+ config.rooms,
+ config.nickname,
);
loop {