refactor admin router into separate file
also set admin secret using command line args
This commit is contained in:
6
backend/src/config.rs
Normal file
6
backend/src/config.rs
Normal file
@ -0,0 +1,6 @@
|
||||
#[derive(Default, Clone)]
|
||||
pub struct Config {
|
||||
pub port: u32,
|
||||
pub host: String,
|
||||
pub admin_secret: Option<String>,
|
||||
}
|
||||
@ -1,3 +1,5 @@
|
||||
mod config;
|
||||
mod router;
|
||||
|
||||
pub use config::Config;
|
||||
pub use router::app;
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
use std::sync::mpsc;
|
||||
|
||||
use clap::Parser;
|
||||
use nuchat::Config;
|
||||
use nuchat::app;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
@ -17,11 +18,16 @@ struct Args {
|
||||
/// Host to run server on
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
|
||||
/// Admin secret to use, leave blank to disable
|
||||
#[arg(long)]
|
||||
admin_secret: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let args = Args::parse();
|
||||
let config = BinConfig::from_args(args);
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
||||
@ -31,12 +37,12 @@ async fn main() {
|
||||
.with(tracing_subscriber::fmt::layer().with_target(false))
|
||||
.init();
|
||||
|
||||
let listener = TcpListener::bind(format!("{}:{}", args.host, args.port))
|
||||
let listener = TcpListener::bind(format!("{}:{}", config.0.host, config.0.port))
|
||||
.await
|
||||
.unwrap();
|
||||
tracing::debug!("listening on {}", listener.local_addr().unwrap());
|
||||
|
||||
let (app, rx) = app();
|
||||
let (app, rx) = app(&config.0);
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal(rx))
|
||||
.await
|
||||
@ -77,3 +83,14 @@ async fn shutdown_signal(rx: mpsc::Receiver<bool>) {
|
||||
}
|
||||
info!("Shutting server down gracefully...");
|
||||
}
|
||||
|
||||
struct BinConfig(Config);
|
||||
impl BinConfig {
|
||||
fn from_args(args: Args) -> Self {
|
||||
Self(Config {
|
||||
port: args.port,
|
||||
host: args.host,
|
||||
admin_secret: args.admin_secret,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,27 +1,25 @@
|
||||
mod admin;
|
||||
mod healthcheck;
|
||||
use std::sync::mpsc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config;
|
||||
use axum::extract::Request;
|
||||
use axum::middleware::{Next, from_fn};
|
||||
use axum::response::Response;
|
||||
#[allow(unused_imports)]
|
||||
use axum::routing::{get, post};
|
||||
use axum::routing::get;
|
||||
use axum::{Router, body::Body};
|
||||
use http::StatusCode;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::timeout::TimeoutLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::{Level, warn};
|
||||
use tracing::Level;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn app() -> (Router, mpsc::Receiver<bool>) {
|
||||
pub fn app(config: &config::Config) -> (Router, mpsc::Receiver<bool>) {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
(
|
||||
Router::new()
|
||||
.route("/healthcheck", get(healthcheck::healthcheck))
|
||||
.route("/forever", get(std::future::pending::<()>))
|
||||
.nest("/admin", admin(tx))
|
||||
.nest("/admin", admin::router(tx, config))
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(
|
||||
@ -40,75 +38,3 @@ pub fn app() -> (Router, mpsc::Receiver<bool>) {
|
||||
rx,
|
||||
)
|
||||
}
|
||||
|
||||
fn admin(tx: mpsc::Sender<bool>) -> Router {
|
||||
let r = Router::new().route("/", get(async || StatusCode::OK));
|
||||
|
||||
let r = add_shutdown_endpoint(r, tx);
|
||||
r.layer(from_fn(async |req: Request, next: Next| {
|
||||
if let Ok(secret) = std::env::var("ADMIN_SECRET") {
|
||||
match req.headers().get("Authorization") {
|
||||
Some(key) if secret == *key => (),
|
||||
Some(key) => {
|
||||
warn!("Unauthorized request with key: {key:?}");
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
_ => {
|
||||
warn!("Unauthorized request no key given");
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(feature = "shutdown")]
|
||||
fn add_shutdown_endpoint(r: Router, tx: mpsc::Sender<bool>) -> Router {
|
||||
r.route(
|
||||
"/shutdown",
|
||||
post(async move || {
|
||||
let res = tx.send(true);
|
||||
if res.is_ok() {
|
||||
StatusCode::OK
|
||||
} else {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "shutdown"))]
|
||||
fn add_shutdown_endpoint(r: Router, _: mpsc::Sender<bool>) -> Router {
|
||||
r
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tower::{self, ServiceExt};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_disables_when_no_env_var_set() {
|
||||
let (app, _) = app();
|
||||
|
||||
let resp = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.uri("/admin")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
}
|
||||
|
||||
127
backend/src/router/admin.rs
Normal file
127
backend/src/router/admin.rs
Normal file
@ -0,0 +1,127 @@
|
||||
use std::sync::mpsc;
|
||||
|
||||
use axum::Router;
|
||||
#[allow(unused_imports)]
|
||||
use axum::routing::{get, post};
|
||||
use http::StatusCode;
|
||||
use tower_http::validate_request::ValidateRequestHeaderLayer;
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub fn router(tx: mpsc::Sender<bool>, config: &crate::Config) -> Router {
|
||||
let r = Router::new().route("/", get(async || StatusCode::OK));
|
||||
|
||||
let r = add_shutdown_endpoint(r, tx);
|
||||
if let Some(secret) = config.admin_secret.clone() {
|
||||
info!("Enabled admin authorization");
|
||||
r.layer(ValidateRequestHeaderLayer::bearer(&secret))
|
||||
} else {
|
||||
warn!("Admin authorization disabled");
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "shutdown")]
|
||||
fn add_shutdown_endpoint(r: Router, tx: mpsc::Sender<bool>) -> Router {
|
||||
r.route(
|
||||
"/shutdown",
|
||||
post(async move || {
|
||||
let res = tx.send(true);
|
||||
if res.is_ok() {
|
||||
StatusCode::OK
|
||||
} else {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "shutdown"))]
|
||||
fn add_shutdown_endpoint(r: Router, _: mpsc::Sender<bool>) -> Router {
|
||||
r
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use axum::{body::Body, http::Request};
|
||||
use http::header;
|
||||
use tower::ServiceExt;
|
||||
|
||||
use crate::config;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_disables_when_no_secret_set() {
|
||||
let (tx, _) = mpsc::channel();
|
||||
|
||||
let resp = router(tx, &config::Config::default())
|
||||
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_unauthorized_no_bearer_token() {
|
||||
let (tx, _) = mpsc::channel();
|
||||
|
||||
let conf = config::Config {
|
||||
admin_secret: Some(String::from("1234")),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let resp = router(tx, &conf)
|
||||
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_unauthorized_invalid_bearer_token() {
|
||||
let (tx, _) = mpsc::channel();
|
||||
|
||||
let conf = config::Config {
|
||||
admin_secret: Some(String::from("1234")),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let resp = router(tx, &conf)
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/")
|
||||
.header(header::AUTHORIZATION, "bearer abcd")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_authorized_valid_bearer_token() {
|
||||
let (tx, _) = mpsc::channel();
|
||||
|
||||
let conf = config::Config {
|
||||
admin_secret: Some(String::from("1234")),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let resp = router(tx, &conf)
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/")
|
||||
.header(header::AUTHORIZATION, "Bearer 1234")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user