add database connection to backend
This commit is contained in:
859
backend/Cargo.lock
generated
859
backend/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -9,6 +9,7 @@ clap = { version = "4.5.41", features = ["derive"] }
|
||||
http = "1.3.1"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0.141"
|
||||
sqlx = { version = "0.8.6", features = ["postgres", "macros", "runtime-tokio"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
tower = { version = "0.5.2", features = ["full"] }
|
||||
tower-http = { version = "0.6.6", features = ["timeout", "trace", "auth", "request-id"] }
|
||||
|
||||
2
backend/scripts/create_test_db.sql
Normal file
2
backend/scripts/create_test_db.sql
Normal file
@ -0,0 +1,2 @@
|
||||
DROP DATABASE IF EXISTS nuchat_test;
|
||||
CREATE DATABASE nuchat_test;
|
||||
@ -1,11 +1,21 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
POSTGRES_URL=${POSTGRES_URL:-"postgresql://postgres:postgres@localhost:5432"}
|
||||
|
||||
if ! command -v cargo-nextest > /dev/null 2>&1; then
|
||||
echo "Command not found cargo-nextest"
|
||||
echo "Try installing with cargo install cargo-nextest"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
psql "$POSTGRES_URL" -f ./scripts/create_test_db.sql
|
||||
|
||||
if [ "$?" -ne "0" ]; then
|
||||
echo "Unable to connect to database, make sure it is started"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -d logs ]; then
|
||||
mkdir logs
|
||||
fi
|
||||
@ -14,7 +24,7 @@ fi
|
||||
curl -s -X POST localhost:7001/admin/shutdown 2>&1 > /dev/null
|
||||
|
||||
# start server
|
||||
cargo run -- --port 7001 2>&1 > logs/nuchat.log &
|
||||
cargo run -- --port 7001 --postgres-url "$POSTGRES_URL" 2>&1 > logs/nuchat.log &
|
||||
|
||||
# run tests
|
||||
cargo nextest run --color=always 2>&1 | tee logs/test-output.log
|
||||
|
||||
@ -3,4 +3,6 @@ pub struct Config {
|
||||
pub port: u32,
|
||||
pub host: String,
|
||||
pub admin_secret: Option<String>,
|
||||
pub postgres_url: String,
|
||||
pub database_name: String,
|
||||
}
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
mod config;
|
||||
mod router;
|
||||
mod state;
|
||||
|
||||
pub use config::Config;
|
||||
pub use router::app;
|
||||
pub use state::{AppState, NuState};
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
use std::sync::mpsc;
|
||||
|
||||
use clap::Parser;
|
||||
use nuchat::AppState;
|
||||
use nuchat::Config;
|
||||
use nuchat::NuState;
|
||||
use nuchat::app;
|
||||
use sqlx::Pool;
|
||||
use sqlx::Postgres;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tracing::info;
|
||||
@ -22,6 +26,14 @@ struct Args {
|
||||
/// Admin secret to use, leave blank to disable
|
||||
#[arg(long)]
|
||||
admin_secret: Option<String>,
|
||||
|
||||
/// postgres base url, should container users and host info
|
||||
#[arg(long, default_value = "postgres://postgres:postgres@localhost:5432/")]
|
||||
postgres_url: String,
|
||||
|
||||
/// name of database to use
|
||||
#[arg(long, default_value = "nuchat_dev")]
|
||||
database: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -37,16 +49,24 @@ async fn main() {
|
||||
.with(tracing_subscriber::fmt::layer().with_target(false))
|
||||
.init();
|
||||
|
||||
let pool = Pool::<Postgres>::connect(&config.0.postgres_url)
|
||||
.await
|
||||
.expect("Could not connect to database");
|
||||
|
||||
let listener = TcpListener::bind(format!("{}:{}", config.0.host, config.0.port))
|
||||
.await
|
||||
.unwrap();
|
||||
tracing::debug!("listening on {}", listener.local_addr().unwrap());
|
||||
|
||||
let (app, rx) = app(&config.0);
|
||||
axum::serve(listener, app)
|
||||
let state = AppState::new(NuState::new(pool.clone(), config.0));
|
||||
let (app, rx) = app(&state);
|
||||
axum::serve(listener, app.with_state(state))
|
||||
.with_graceful_shutdown(shutdown_signal(rx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
pool.close().await;
|
||||
|
||||
info!("Server stopped");
|
||||
}
|
||||
|
||||
@ -91,6 +111,8 @@ impl BinConfig {
|
||||
port: args.port,
|
||||
host: args.host,
|
||||
admin_secret: args.admin_secret,
|
||||
postgres_url: args.postgres_url,
|
||||
database_name: args.database,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,11 +3,11 @@ mod healthcheck;
|
||||
use std::sync::mpsc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config;
|
||||
use crate::AppState;
|
||||
use axum::extract::Request;
|
||||
use axum::routing::get;
|
||||
use axum::{Router, body::Body};
|
||||
use http::HeaderName;
|
||||
use http::{HeaderName, HeaderValue};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::request_id::{MakeRequestId, RequestId, SetRequestIdLayer};
|
||||
use tower_http::timeout::TimeoutLayer;
|
||||
@ -26,13 +26,14 @@ impl MakeRequestId for RequestIdLayer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn app(config: &config::Config) -> (Router, mpsc::Receiver<bool>) {
|
||||
pub fn app(state: &AppState) -> (Router<AppState>, mpsc::Receiver<bool>) {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
(
|
||||
Router::new()
|
||||
.with_state(state.clone())
|
||||
.route("/healthcheck", get(healthcheck::healthcheck))
|
||||
.route("/forever", get(std::future::pending::<()>))
|
||||
.nest("/admin", admin::router(tx, config))
|
||||
.nest("/admin", admin::router(tx, state))
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(SetRequestIdLayer::new(
|
||||
@ -41,7 +42,8 @@ pub fn app(config: &config::Config) -> (Router, mpsc::Receiver<bool>) {
|
||||
))
|
||||
.layer(
|
||||
TraceLayer::new_for_http().make_span_with(|req: &Request<Body>| {
|
||||
if let Some(req_id) = req.headers().get("x-request-id") {
|
||||
let default = HeaderValue::from_static("<missing>");
|
||||
let req_id = req.headers().get("x-request-id").unwrap_or(&default);
|
||||
tracing::span!(
|
||||
Level::DEBUG,
|
||||
"request",
|
||||
@ -49,15 +51,6 @@ pub fn app(config: &config::Config) -> (Router, mpsc::Receiver<bool>) {
|
||||
method = format!("{}", req.method()),
|
||||
uri = format!("{}", req.uri()),
|
||||
)
|
||||
} else {
|
||||
tracing::span!(
|
||||
Level::DEBUG,
|
||||
"request",
|
||||
req_id = "<missing>",
|
||||
method = format!("{}", req.method()),
|
||||
uri = format!("{}", req.uri()),
|
||||
)
|
||||
}
|
||||
}),
|
||||
)
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(10))),
|
||||
|
||||
@ -7,13 +7,17 @@ use http::StatusCode;
|
||||
use tower_http::validate_request::ValidateRequestHeaderLayer;
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub fn router(tx: mpsc::Sender<bool>, config: &crate::Config) -> Router {
|
||||
let r = Router::new().route("/", get(async || StatusCode::OK));
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router(tx: mpsc::Sender<bool>, state: &AppState) -> Router<AppState> {
|
||||
let r = Router::new()
|
||||
.with_state(state.clone())
|
||||
.route("/", get(async || StatusCode::OK));
|
||||
|
||||
let r = add_shutdown_endpoint(r, tx);
|
||||
if let Some(secret) = config.admin_secret.clone() {
|
||||
if let Some(secret) = &state.config.admin_secret {
|
||||
info!("Enabled admin authorization");
|
||||
r.layer(ValidateRequestHeaderLayer::bearer(&secret))
|
||||
r.layer(ValidateRequestHeaderLayer::bearer(secret))
|
||||
} else {
|
||||
warn!("Admin authorization disabled");
|
||||
r
|
||||
@ -21,7 +25,7 @@ pub fn router(tx: mpsc::Sender<bool>, config: &crate::Config) -> Router {
|
||||
}
|
||||
|
||||
#[cfg(feature = "shutdown")]
|
||||
fn add_shutdown_endpoint(r: Router, tx: mpsc::Sender<bool>) -> Router {
|
||||
fn add_shutdown_endpoint(r: Router<AppState>, tx: mpsc::Sender<bool>) -> Router<AppState> {
|
||||
r.route(
|
||||
"/shutdown",
|
||||
post(async move || {
|
||||
@ -44,17 +48,21 @@ fn add_shutdown_endpoint(r: Router, _: mpsc::Sender<bool>) -> Router {
|
||||
mod test {
|
||||
use axum::{body::Body, http::Request};
|
||||
use http::header;
|
||||
use sqlx::PgPool;
|
||||
use tower::ServiceExt;
|
||||
|
||||
use crate::config;
|
||||
use crate::{config, state::NuState};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_disables_when_no_secret_set() {
|
||||
#[sqlx::test]
|
||||
async fn test_authorization_disables_when_no_secret_set(pool: PgPool) {
|
||||
let (tx, _) = mpsc::channel();
|
||||
|
||||
let resp = router(tx, &config::Config::default())
|
||||
let state = AppState::new(NuState::new(pool, config::Config::default()));
|
||||
|
||||
let resp = router(tx, &state)
|
||||
.with_state(state)
|
||||
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
@ -62,8 +70,8 @@ mod test {
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_unauthorized_no_bearer_token() {
|
||||
#[sqlx::test]
|
||||
async fn test_authorization_unauthorized_no_bearer_token(pool: PgPool) {
|
||||
let (tx, _) = mpsc::channel();
|
||||
|
||||
let conf = config::Config {
|
||||
@ -71,7 +79,10 @@ mod test {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let resp = router(tx, &conf)
|
||||
let state = AppState::new(NuState::new(pool, conf));
|
||||
|
||||
let resp = router(tx, &state)
|
||||
.with_state(state)
|
||||
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
@ -79,8 +90,8 @@ mod test {
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_unauthorized_invalid_bearer_token() {
|
||||
#[sqlx::test]
|
||||
async fn test_authorization_unauthorized_invalid_bearer_token(pool: PgPool) {
|
||||
let (tx, _) = mpsc::channel();
|
||||
|
||||
let conf = config::Config {
|
||||
@ -88,7 +99,10 @@ mod test {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let resp = router(tx, &conf)
|
||||
let state = AppState::new(NuState::new(pool, conf));
|
||||
|
||||
let resp = router(tx, &state)
|
||||
.with_state(state)
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/")
|
||||
@ -102,8 +116,8 @@ mod test {
|
||||
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorization_authorized_valid_bearer_token() {
|
||||
#[sqlx::test]
|
||||
async fn test_authorization_authorized_valid_bearer_token(pool: PgPool) {
|
||||
let (tx, _) = mpsc::channel();
|
||||
|
||||
let conf = config::Config {
|
||||
@ -111,7 +125,10 @@ mod test {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let resp = router(tx, &conf)
|
||||
let state = AppState::new(NuState::new(pool, conf));
|
||||
|
||||
let resp = router(tx, &state)
|
||||
.with_state(state)
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/")
|
||||
|
||||
@ -1,6 +1,78 @@
|
||||
use axum::extract::{self, State};
|
||||
use axum::response::Json;
|
||||
use http::StatusCode;
|
||||
use serde_json::{Value, json};
|
||||
use tracing::error;
|
||||
|
||||
pub async fn healthcheck() -> Json<Value> {
|
||||
Json(json!({"healthy": true}))
|
||||
use crate::AppState;
|
||||
|
||||
pub async fn healthcheck(State(s): extract::State<AppState>) -> Result<Json<Value>, StatusCode> {
|
||||
sqlx::query!(
|
||||
"select exists(SELECT datname FROM pg_catalog.pg_database WHERE datname = $1);",
|
||||
s.config.database_name
|
||||
)
|
||||
.fetch_one(&s.db)
|
||||
.await
|
||||
.and_then(|x| x.exists.ok_or(sqlx::Error::RowNotFound))
|
||||
.and_then(|db| {
|
||||
if db {
|
||||
Ok(Json(json!({"healthy": db})))
|
||||
} else {
|
||||
error!("Could not find configured database in postgres");
|
||||
Err(sqlx::Error::RowNotFound)
|
||||
}
|
||||
})
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use axum::{Router, body::Body, routing::get};
|
||||
use http::Request;
|
||||
use sqlx::PgPool;
|
||||
use tower::ServiceExt;
|
||||
|
||||
use crate::{Config, NuState};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[sqlx::test]
|
||||
async fn healthcheck_passes_with_db_connection(pool: PgPool) {
|
||||
let state = AppState::new(NuState::new(
|
||||
pool,
|
||||
Config {
|
||||
database_name: String::from("nuchat_test"),
|
||||
..Config::default()
|
||||
},
|
||||
));
|
||||
|
||||
let resp = Router::new()
|
||||
.route("/", get(healthcheck))
|
||||
.with_state(state)
|
||||
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[sqlx::test]
|
||||
async fn healthcheck_fails_db_doesnt_exist(pool: PgPool) {
|
||||
let state = AppState::new(NuState::new(
|
||||
pool,
|
||||
Config {
|
||||
database_name: String::from("asdfasdfasdf"),
|
||||
..Config::default()
|
||||
},
|
||||
));
|
||||
|
||||
let resp = Router::new()
|
||||
.route("/", get(healthcheck))
|
||||
.with_state(state)
|
||||
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
18
backend/src/state.rs
Normal file
18
backend/src/state.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::Config;
|
||||
|
||||
pub type AppState = Arc<NuState>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NuState {
|
||||
pub db: sqlx::PgPool,
|
||||
pub config: Config,
|
||||
}
|
||||
impl NuState {
|
||||
pub fn new(db: PgPool, config: Config) -> Self {
|
||||
Self { db, config }
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user