use axum::{ extract::{Request, State}, http::{header::HeaderMap, StatusCode}, middleware::Next, response::Response, }; use sha2::{Digest, Sha256}; use std::sync::Arc; use crate::core::db::postgres_db::ApiKeyRecord; use crate::core::db::PostgresDb; #[derive(Clone)] pub struct ApiKeyAuth { pub key_id: String, pub record: ApiKeyRecord, } #[derive(Clone)] pub struct ApiState { pub db: Arc, } const PUBLIC_PATHS: &[&str] = &[ "/api/v1/faces/", // Thumbnail paths (partial match) ]; fn is_public_path(path: &str) -> bool { PUBLIC_PATHS.iter().any(|prefix| path.starts_with(prefix)) && path.ends_with("/thumbnail") } pub async fn api_key_validation( State(state): State, request: Request, next: Next, ) -> Response { let path = request.uri().path(); tracing::info!("[MIDDLEWARE] Starting API key validation"); tracing::info!("[MIDDLEWARE] Path: {:?}", path); if is_public_path(path) { tracing::info!("[MIDDLEWARE] Public path, skipping auth: {}", path); return next.run(request).await; } let headers = request.headers(); tracing::info!("[MIDDLEWARE] All headers: {:?}", headers); let uri = request.uri().clone(); let api_key = match extract_api_key(headers, &uri) { Ok(key) => { tracing::info!("[MIDDLEWARE] API key extracted, length: {}", key.len()); if key.len() > 8 { tracing::info!( "[MIDDLEWARE] Key value: {}...{}", &key[..4], &key[key.len() - 4..] ); } else { tracing::info!("[MIDDLEWARE] Key value: ****"); } key } Err(status) => { tracing::warn!("[MIDDLEWARE] API key extraction failed: {:?}", status); return Response::builder() .status(status) .body(axum::body::Body::empty()) .unwrap(); } }; let key_hash = hash_key(&api_key); tracing::info!("[MIDDLEWARE] Key hash: {}", &key_hash[..16]); tracing::info!("[MIDDLEWARE] Querying database for key..."); let record = match state.db.get_api_key_by_hash(&key_hash).await { Ok(Some(r)) => { tracing::info!("[MIDDLEWARE] API key found: {}", r.key_id); r } Ok(None) => { tracing::warn!( "[MIDDLEWARE] API key NOT FOUND in database for hash: {}", &key_hash[..16] ); return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(axum::body::Body::empty()) .unwrap(); } Err(e) => { tracing::error!("[MIDDLEWARE] DB error: {}", e); return Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(axum::body::Body::empty()) .unwrap(); } }; if record.status != "active" { tracing::warn!("[MIDDLEWARE] API key not active: {}", record.status); return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(axum::body::Body::empty()) .unwrap(); } tracing::info!( "[MIDDLEWARE] API key validated successfully: {}", record.key_id ); let auth = ApiKeyAuth { key_id: record.key_id.clone(), record, }; if let Err(e) = state.db.update_api_key_usage(&auth.key_id, None).await { tracing::warn!("[MIDDLEWARE] Failed to update API key usage: {}", e); } let mut request = request; request.extensions_mut().insert(auth); tracing::info!("[MIDDLEWARE] Passing request to handler"); let response = next.run(request).await; tracing::info!("[MIDDLEWARE] Handler returned response"); response } fn extract_api_key(headers: &HeaderMap, uri: &axum::http::Uri) -> Result { // 1. X-API-Key header if let Some(key) = headers .get("X-API-Key") .and_then(|v| v.to_str().ok()) { return Ok(key.to_string()); } // 2. Authorization: Bearer if let Some(auth) = headers .get("Authorization") .and_then(|v| v.to_str().ok()) { if let Some(key) = auth.strip_prefix("Bearer ") { return Ok(key.to_string()); } } // 3. ?api_key= query parameter if let Some(query) = uri.query() { for pair in query.split('&') { let mut parts = pair.splitn(2, '='); if let (Some(k), Some(v)) = (parts.next(), parts.next()) { if k == "api_key" { return Ok(percent_decode(v)); } } } } Err(StatusCode::UNAUTHORIZED) } fn percent_decode(s: &str) -> String { let mut result = String::new(); let mut chars = s.bytes(); while let Some(b) = chars.next() { match b { b'%' => { let hi = chars.next().and_then(|c| hex_val(c)).unwrap_or(0); let lo = chars.next().and_then(|c| hex_val(c)).unwrap_or(0); result.push((hi << 4 | lo) as char); } b'+' => result.push(' '), _ => result.push(b as char), } } result } fn hex_val(c: u8) -> Option { match c { b'0'..=b'9' => Some(c - b'0'), b'a'..=b'f' => Some(c - b'a' + 10), b'A'..=b'F' => Some(c - b'A' + 10), _ => None, } } fn hash_key(key: &str) -> String { let mut hasher = Sha256::new(); hasher.update(key.as_bytes()); format!("{:x}", hasher.finalize()) }