add database connection to backend

This commit is contained in:
2025-07-28 00:56:41 +01:00
parent 20f64cd35d
commit 79e43f19df
11 changed files with 1040 additions and 46 deletions

859
backend/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,7 @@ clap = { version = "4.5.41", features = ["derive"] }
http = "1.3.1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.141"
sqlx = { version = "0.8.6", features = ["postgres", "macros", "runtime-tokio"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.5.2", features = ["full"] }
tower-http = { version = "0.6.6", features = ["timeout", "trace", "auth", "request-id"] }

View File

@ -0,0 +1,2 @@
DROP DATABASE IF EXISTS nuchat_test;
CREATE DATABASE nuchat_test;

View File

@ -1,11 +1,21 @@
#!/usr/bin/env bash
POSTGRES_URL=${POSTGRES_URL:-"postgresql://postgres:postgres@localhost:5432"}
if ! command -v cargo-nextest > /dev/null 2>&1; then
echo "Command not found cargo-nextest"
echo "Try installing with cargo install cargo-nextest"
exit 1
fi
psql "$POSTGRES_URL" -f ./scripts/create_test_db.sql
if [ "$?" -ne "0" ]; then
echo "Unable to connect to database, make sure it is started"
exit 1
fi
if [ ! -d logs ]; then
mkdir logs
fi
@ -14,7 +24,7 @@ fi
curl -s -X POST localhost:7001/admin/shutdown 2>&1 > /dev/null
# start server
cargo run -- --port 7001 2>&1 > logs/nuchat.log &
cargo run -- --port 7001 --postgres-url "$POSTGRES_URL" 2>&1 > logs/nuchat.log &
# run tests
cargo nextest run --color=always 2>&1 | tee logs/test-output.log

View File

@ -3,4 +3,6 @@ pub struct Config {
pub port: u32,
pub host: String,
pub admin_secret: Option<String>,
pub postgres_url: String,
pub database_name: String,
}

View File

@ -1,5 +1,7 @@
mod config;
mod router;
mod state;
pub use config::Config;
pub use router::app;
pub use state::{AppState, NuState};

View File

@ -1,8 +1,12 @@
use std::sync::mpsc;
use clap::Parser;
use nuchat::AppState;
use nuchat::Config;
use nuchat::NuState;
use nuchat::app;
use sqlx::Pool;
use sqlx::Postgres;
use tokio::net::TcpListener;
use tokio::signal;
use tracing::info;
@ -22,6 +26,14 @@ struct Args {
/// Admin secret to use, leave blank to disable
#[arg(long)]
admin_secret: Option<String>,
/// postgres base url, should container users and host info
#[arg(long, default_value = "postgres://postgres:postgres@localhost:5432/")]
postgres_url: String,
/// name of database to use
#[arg(long, default_value = "nuchat_dev")]
database: String,
}
#[tokio::main]
@ -37,16 +49,24 @@ async fn main() {
.with(tracing_subscriber::fmt::layer().with_target(false))
.init();
let pool = Pool::<Postgres>::connect(&config.0.postgres_url)
.await
.expect("Could not connect to database");
let listener = TcpListener::bind(format!("{}:{}", config.0.host, config.0.port))
.await
.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
let (app, rx) = app(&config.0);
axum::serve(listener, app)
let state = AppState::new(NuState::new(pool.clone(), config.0));
let (app, rx) = app(&state);
axum::serve(listener, app.with_state(state))
.with_graceful_shutdown(shutdown_signal(rx))
.await
.unwrap();
pool.close().await;
info!("Server stopped");
}
@ -91,6 +111,8 @@ impl BinConfig {
port: args.port,
host: args.host,
admin_secret: args.admin_secret,
postgres_url: args.postgres_url,
database_name: args.database,
})
}
}

View File

@ -3,11 +3,11 @@ mod healthcheck;
use std::sync::mpsc;
use std::time::Duration;
use crate::config;
use crate::AppState;
use axum::extract::Request;
use axum::routing::get;
use axum::{Router, body::Body};
use http::HeaderName;
use http::{HeaderName, HeaderValue};
use tower::ServiceBuilder;
use tower_http::request_id::{MakeRequestId, RequestId, SetRequestIdLayer};
use tower_http::timeout::TimeoutLayer;
@ -26,13 +26,14 @@ impl MakeRequestId for RequestIdLayer {
}
}
pub fn app(config: &config::Config) -> (Router, mpsc::Receiver<bool>) {
pub fn app(state: &AppState) -> (Router<AppState>, mpsc::Receiver<bool>) {
let (tx, rx) = mpsc::channel();
(
Router::new()
.with_state(state.clone())
.route("/healthcheck", get(healthcheck::healthcheck))
.route("/forever", get(std::future::pending::<()>))
.nest("/admin", admin::router(tx, config))
.nest("/admin", admin::router(tx, state))
.layer(
ServiceBuilder::new()
.layer(SetRequestIdLayer::new(
@ -41,23 +42,15 @@ pub fn app(config: &config::Config) -> (Router, mpsc::Receiver<bool>) {
))
.layer(
TraceLayer::new_for_http().make_span_with(|req: &Request<Body>| {
if let Some(req_id) = req.headers().get("x-request-id") {
tracing::span!(
Level::DEBUG,
"request",
req_id = req_id.to_str().unwrap(),
method = format!("{}", req.method()),
uri = format!("{}", req.uri()),
)
} else {
tracing::span!(
Level::DEBUG,
"request",
req_id = "<missing>",
method = format!("{}", req.method()),
uri = format!("{}", req.uri()),
)
}
let default = HeaderValue::from_static("<missing>");
let req_id = req.headers().get("x-request-id").unwrap_or(&default);
tracing::span!(
Level::DEBUG,
"request",
req_id = req_id.to_str().unwrap(),
method = format!("{}", req.method()),
uri = format!("{}", req.uri()),
)
}),
)
.layer(TimeoutLayer::new(Duration::from_secs(10))),

View File

@ -7,13 +7,17 @@ use http::StatusCode;
use tower_http::validate_request::ValidateRequestHeaderLayer;
use tracing::{info, warn};
pub fn router(tx: mpsc::Sender<bool>, config: &crate::Config) -> Router {
let r = Router::new().route("/", get(async || StatusCode::OK));
use crate::AppState;
pub fn router(tx: mpsc::Sender<bool>, state: &AppState) -> Router<AppState> {
let r = Router::new()
.with_state(state.clone())
.route("/", get(async || StatusCode::OK));
let r = add_shutdown_endpoint(r, tx);
if let Some(secret) = config.admin_secret.clone() {
if let Some(secret) = &state.config.admin_secret {
info!("Enabled admin authorization");
r.layer(ValidateRequestHeaderLayer::bearer(&secret))
r.layer(ValidateRequestHeaderLayer::bearer(secret))
} else {
warn!("Admin authorization disabled");
r
@ -21,7 +25,7 @@ pub fn router(tx: mpsc::Sender<bool>, config: &crate::Config) -> Router {
}
#[cfg(feature = "shutdown")]
fn add_shutdown_endpoint(r: Router, tx: mpsc::Sender<bool>) -> Router {
fn add_shutdown_endpoint(r: Router<AppState>, tx: mpsc::Sender<bool>) -> Router<AppState> {
r.route(
"/shutdown",
post(async move || {
@ -44,17 +48,21 @@ fn add_shutdown_endpoint(r: Router, _: mpsc::Sender<bool>) -> Router {
mod test {
use axum::{body::Body, http::Request};
use http::header;
use sqlx::PgPool;
use tower::ServiceExt;
use crate::config;
use crate::{config, state::NuState};
use super::*;
#[tokio::test]
async fn test_authorization_disables_when_no_secret_set() {
#[sqlx::test]
async fn test_authorization_disables_when_no_secret_set(pool: PgPool) {
let (tx, _) = mpsc::channel();
let resp = router(tx, &config::Config::default())
let state = AppState::new(NuState::new(pool, config::Config::default()));
let resp = router(tx, &state)
.with_state(state)
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();
@ -62,8 +70,8 @@ mod test {
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_authorization_unauthorized_no_bearer_token() {
#[sqlx::test]
async fn test_authorization_unauthorized_no_bearer_token(pool: PgPool) {
let (tx, _) = mpsc::channel();
let conf = config::Config {
@ -71,7 +79,10 @@ mod test {
..Default::default()
};
let resp = router(tx, &conf)
let state = AppState::new(NuState::new(pool, conf));
let resp = router(tx, &state)
.with_state(state)
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();
@ -79,8 +90,8 @@ mod test {
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_authorization_unauthorized_invalid_bearer_token() {
#[sqlx::test]
async fn test_authorization_unauthorized_invalid_bearer_token(pool: PgPool) {
let (tx, _) = mpsc::channel();
let conf = config::Config {
@ -88,7 +99,10 @@ mod test {
..Default::default()
};
let resp = router(tx, &conf)
let state = AppState::new(NuState::new(pool, conf));
let resp = router(tx, &state)
.with_state(state)
.oneshot(
Request::builder()
.uri("/")
@ -102,8 +116,8 @@ mod test {
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_authorization_authorized_valid_bearer_token() {
#[sqlx::test]
async fn test_authorization_authorized_valid_bearer_token(pool: PgPool) {
let (tx, _) = mpsc::channel();
let conf = config::Config {
@ -111,7 +125,10 @@ mod test {
..Default::default()
};
let resp = router(tx, &conf)
let state = AppState::new(NuState::new(pool, conf));
let resp = router(tx, &state)
.with_state(state)
.oneshot(
Request::builder()
.uri("/")

View File

@ -1,6 +1,78 @@
use axum::extract::{self, State};
use axum::response::Json;
use http::StatusCode;
use serde_json::{Value, json};
use tracing::error;
pub async fn healthcheck() -> Json<Value> {
Json(json!({"healthy": true}))
use crate::AppState;
pub async fn healthcheck(State(s): extract::State<AppState>) -> Result<Json<Value>, StatusCode> {
sqlx::query!(
"select exists(SELECT datname FROM pg_catalog.pg_database WHERE datname = $1);",
s.config.database_name
)
.fetch_one(&s.db)
.await
.and_then(|x| x.exists.ok_or(sqlx::Error::RowNotFound))
.and_then(|db| {
if db {
Ok(Json(json!({"healthy": db})))
} else {
error!("Could not find configured database in postgres");
Err(sqlx::Error::RowNotFound)
}
})
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
#[cfg(test)]
mod test {
use axum::{Router, body::Body, routing::get};
use http::Request;
use sqlx::PgPool;
use tower::ServiceExt;
use crate::{Config, NuState};
use super::*;
#[sqlx::test]
async fn healthcheck_passes_with_db_connection(pool: PgPool) {
let state = AppState::new(NuState::new(
pool,
Config {
database_name: String::from("nuchat_test"),
..Config::default()
},
));
let resp = Router::new()
.route("/", get(healthcheck))
.with_state(state)
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[sqlx::test]
async fn healthcheck_fails_db_doesnt_exist(pool: PgPool) {
let state = AppState::new(NuState::new(
pool,
Config {
database_name: String::from("asdfasdfasdf"),
..Config::default()
},
));
let resp = Router::new()
.route("/", get(healthcheck))
.with_state(state)
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
}

18
backend/src/state.rs Normal file
View File

@ -0,0 +1,18 @@
use std::sync::Arc;
use sqlx::PgPool;
use crate::Config;
pub type AppState = Arc<NuState>;
#[derive(Clone)]
pub struct NuState {
pub db: sqlx::PgPool,
pub config: Config,
}
impl NuState {
pub fn new(db: PgPool, config: Config) -> Self {
Self { db, config }
}
}