diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 6f0e8bb..7e4dcdb 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -154,7 +154,7 @@ dependencies = [ "miniz_oxide", "object", "rustc-demangle", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -346,7 +346,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -656,9 +656,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" +checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" dependencies = [ "base64", "bytes 1.10.1", @@ -672,7 +672,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.0", "system-configuration", "tokio", "tower-service 0.3.3", @@ -799,9 +799,9 @@ dependencies = [ [[package]] name = "io-uring" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" dependencies = [ "bitflags 2.9.1", "cfg-if 1.0.1", @@ -1195,9 +1195,9 @@ checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if 1.0.1", "libc", - "redox_syscall 0.5.14", + "redox_syscall 0.5.15", "smallvec 1.15.1", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -1265,9 +1265,9 @@ checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" [[package]] name = "redox_syscall" -version = "0.5.14" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3a5d9f0aba1dbcec1cc47f0ff94a4b778fe55bca98a6dfa92e4e094e57b1c4" +checksum = "7e8af0dde094006011e6a740d4879319439489813bd0bcdc7d821beaeeff48ec" dependencies = [ "bitflags 2.9.1", ] @@ -1395,7 +1395,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -1605,6 +1605,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -1722,7 +1732,7 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "slab", - "socket2", + "socket2 0.5.10", "tokio-macros", "windows-sys 0.52.0", ] @@ -1872,12 +1882,14 @@ version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" dependencies = [ + "base64", "bitflags 2.9.1", "bytes 1.10.1", "futures-util", "http 1.3.1", "http-body 1.0.1", "iri-string", + "mime", "pin-project-lite", "tokio", "tower", @@ -2229,7 +2241,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -2238,7 +2250,16 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.2", ] [[package]] @@ -2247,14 +2268,30 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", ] [[package]] @@ -2263,48 +2300,96 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "wit-bindgen-rt" version = "0.39.0" diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 123ac77..1ffe93c 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -11,7 +11,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.141" tokio = { version = "1.0", features = ["full"] } tower = { version = "0.5.2", features = ["full"] } -tower-http = { version = "0.6.6", features = ["timeout", "trace"] } +tower-http = { version = "0.6.6", features = ["timeout", "trace", "auth"] } tower-http-util = "0.1.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/backend/src/config.rs b/backend/src/config.rs new file mode 100644 index 0000000..4ae37c4 --- /dev/null +++ b/backend/src/config.rs @@ -0,0 +1,6 @@ +#[derive(Default, Clone)] +pub struct Config { + pub port: u32, + pub host: String, + pub admin_secret: Option, +} diff --git a/backend/src/lib.rs b/backend/src/lib.rs index bb6642f..b4c5374 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -1,3 +1,5 @@ +mod config; mod router; +pub use config::Config; pub use router::app; diff --git a/backend/src/main.rs b/backend/src/main.rs index bb67dfb..f1cf19f 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,6 +1,7 @@ use std::sync::mpsc; use clap::Parser; +use nuchat::Config; use nuchat::app; use tokio::net::TcpListener; use tokio::signal; @@ -17,11 +18,16 @@ struct Args { /// Host to run server on #[arg(long, default_value = "127.0.0.1")] host: String, + + /// Admin secret to use, leave blank to disable + #[arg(long)] + admin_secret: Option, } #[tokio::main] async fn main() { let args = Args::parse(); + let config = BinConfig::from_args(args); tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { @@ -31,12 +37,12 @@ async fn main() { .with(tracing_subscriber::fmt::layer().with_target(false)) .init(); - let listener = TcpListener::bind(format!("{}:{}", args.host, args.port)) + let listener = TcpListener::bind(format!("{}:{}", config.0.host, config.0.port)) .await .unwrap(); tracing::debug!("listening on {}", listener.local_addr().unwrap()); - let (app, rx) = app(); + let (app, rx) = app(&config.0); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal(rx)) .await @@ -77,3 +83,14 @@ async fn shutdown_signal(rx: mpsc::Receiver) { } info!("Shutting server down gracefully..."); } + +struct BinConfig(Config); +impl BinConfig { + fn from_args(args: Args) -> Self { + Self(Config { + port: args.port, + host: args.host, + admin_secret: args.admin_secret, + }) + } +} diff --git a/backend/src/router.rs b/backend/src/router.rs index 6c83bcf..27abce8 100644 --- a/backend/src/router.rs +++ b/backend/src/router.rs @@ -1,27 +1,25 @@ +mod admin; mod healthcheck; use std::sync::mpsc; use std::time::Duration; +use crate::config; use axum::extract::Request; -use axum::middleware::{Next, from_fn}; -use axum::response::Response; -#[allow(unused_imports)] -use axum::routing::{get, post}; +use axum::routing::get; use axum::{Router, body::Body}; -use http::StatusCode; use tower::ServiceBuilder; use tower_http::timeout::TimeoutLayer; use tower_http::trace::TraceLayer; -use tracing::{Level, warn}; +use tracing::Level; use uuid::Uuid; -pub fn app() -> (Router, mpsc::Receiver) { +pub fn app(config: &config::Config) -> (Router, mpsc::Receiver) { let (tx, rx) = mpsc::channel(); ( Router::new() .route("/healthcheck", get(healthcheck::healthcheck)) .route("/forever", get(std::future::pending::<()>)) - .nest("/admin", admin(tx)) + .nest("/admin", admin::router(tx, config)) .layer( ServiceBuilder::new() .layer( @@ -40,75 +38,3 @@ pub fn app() -> (Router, mpsc::Receiver) { rx, ) } - -fn admin(tx: mpsc::Sender) -> Router { - let r = Router::new().route("/", get(async || StatusCode::OK)); - - let r = add_shutdown_endpoint(r, tx); - r.layer(from_fn(async |req: Request, next: Next| { - if let Ok(secret) = std::env::var("ADMIN_SECRET") { - match req.headers().get("Authorization") { - Some(key) if secret == *key => (), - Some(key) => { - warn!("Unauthorized request with key: {key:?}"); - return Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(Body::empty()) - .unwrap(); - } - _ => { - warn!("Unauthorized request no key given"); - return Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(Body::empty()) - .unwrap(); - } - } - } - - next.run(req).await - })) -} - -#[cfg(feature = "shutdown")] -fn add_shutdown_endpoint(r: Router, tx: mpsc::Sender) -> Router { - r.route( - "/shutdown", - post(async move || { - let res = tx.send(true); - if res.is_ok() { - StatusCode::OK - } else { - StatusCode::INTERNAL_SERVER_ERROR - } - }), - ) -} - -#[cfg(not(feature = "shutdown"))] -fn add_shutdown_endpoint(r: Router, _: mpsc::Sender) -> Router { - r -} - -#[cfg(test)] -mod tests { - use super::*; - use tower::{self, ServiceExt}; - - #[tokio::test] - async fn test_authorization_disables_when_no_env_var_set() { - let (app, _) = app(); - - let resp = app - .oneshot( - axum::http::Request::builder() - .uri("/admin") - .body(Body::empty()) - .unwrap(), - ) - .await - .unwrap(); - - assert_eq!(resp.status(), StatusCode::OK); - } -} diff --git a/backend/src/router/admin.rs b/backend/src/router/admin.rs new file mode 100644 index 0000000..9dfce06 --- /dev/null +++ b/backend/src/router/admin.rs @@ -0,0 +1,127 @@ +use std::sync::mpsc; + +use axum::Router; +#[allow(unused_imports)] +use axum::routing::{get, post}; +use http::StatusCode; +use tower_http::validate_request::ValidateRequestHeaderLayer; +use tracing::{info, warn}; + +pub fn router(tx: mpsc::Sender, config: &crate::Config) -> Router { + let r = Router::new().route("/", get(async || StatusCode::OK)); + + let r = add_shutdown_endpoint(r, tx); + if let Some(secret) = config.admin_secret.clone() { + info!("Enabled admin authorization"); + r.layer(ValidateRequestHeaderLayer::bearer(&secret)) + } else { + warn!("Admin authorization disabled"); + r + } +} + +#[cfg(feature = "shutdown")] +fn add_shutdown_endpoint(r: Router, tx: mpsc::Sender) -> Router { + r.route( + "/shutdown", + post(async move || { + let res = tx.send(true); + if res.is_ok() { + StatusCode::OK + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + }), + ) +} + +#[cfg(not(feature = "shutdown"))] +fn add_shutdown_endpoint(r: Router, _: mpsc::Sender) -> Router { + r +} + +#[cfg(test)] +mod test { + use axum::{body::Body, http::Request}; + use http::header; + use tower::ServiceExt; + + use crate::config; + + use super::*; + + #[tokio::test] + async fn test_authorization_disables_when_no_secret_set() { + let (tx, _) = mpsc::channel(); + + let resp = router(tx, &config::Config::default()) + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn test_authorization_unauthorized_no_bearer_token() { + let (tx, _) = mpsc::channel(); + + let conf = config::Config { + admin_secret: Some(String::from("1234")), + ..Default::default() + }; + + let resp = router(tx, &conf) + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn test_authorization_unauthorized_invalid_bearer_token() { + let (tx, _) = mpsc::channel(); + + let conf = config::Config { + admin_secret: Some(String::from("1234")), + ..Default::default() + }; + + let resp = router(tx, &conf) + .oneshot( + Request::builder() + .uri("/") + .header(header::AUTHORIZATION, "bearer abcd") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn test_authorization_authorized_valid_bearer_token() { + let (tx, _) = mpsc::channel(); + + let conf = config::Config { + admin_secret: Some(String::from("1234")), + ..Default::default() + }; + + let resp = router(tx, &conf) + .oneshot( + Request::builder() + .uri("/") + .header(header::AUTHORIZATION, "Bearer 1234") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + } +}