use axum::{ extract::{Query, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{get, post}, Json, Router, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use crate::core::processor::{classify_image, classify_images, detect_objects, ClipPrediction}; use crate::api::types::AppState; #[derive(Debug, Deserialize)] pub struct ClassifyRequest { image_path: String, labels: String, #[serde(default = "default_top_k")] top_k: usize, #[serde(default)] model: Option, } fn default_top_k() -> usize { 5 } #[derive(Debug, Deserialize)] pub struct DetectRequest { image_path: String, objects: String, #[serde(default = "default_threshold")] threshold: f32, #[serde(default)] model: Option, } fn default_threshold() -> f32 { 0.15 } #[derive(Debug, Deserialize)] pub struct BatchClassifyRequest { image_paths: String, labels: String, #[serde(default = "default_top_k")] top_k: usize, #[serde(default)] model: Option, } #[derive(Debug, Serialize)] pub struct ClassifyResponse { success: bool, predictions: Vec, } #[derive(Debug, Serialize)] pub struct DetectResponse { success: bool, detected: Vec, } #[derive(Debug, Serialize)] pub struct BatchClassifyResponse { success: bool, results: HashMap>, } #[derive(Debug, Serialize)] pub struct ErrorResponse { success: bool, error: String, } pub fn clip_routes() -> Router { Router::new() .route("/api/v1/clip/classify", post(classify_image_endpoint)) .route("/api/v1/clip/detect", post(detect_objects_endpoint)) .route("/api/v1/clip/batch", post(batch_classify_endpoint)) } async fn classify_image_endpoint( State(_state): State, Json(req): Json, ) -> Response { let labels: Vec<&str> = req.labels.split(',').map(|s| s.trim()).collect(); let result = classify_image( &req.image_path, &labels, Some(req.top_k), req.model.as_deref(), ).await; match result { Ok(predictions) => { tracing::info!( "[CLIP_API] Classified {} -> top: {} ({:.3})", req.image_path, predictions.first().map(|p| p.label.as_str()).unwrap_or("none"), predictions.first().map(|p| p.confidence).unwrap_or(0.0) ); Json(ClassifyResponse { success: true, predictions, }).into_response() } Err(e) => { tracing::error!("[CLIP_API] Classification failed: {}", e); Json(ErrorResponse { success: false, error: e.to_string(), }).into_response() } } } async fn detect_objects_endpoint( State(_state): State, Json(req): Json, ) -> Response { let objects: Vec<&str> = req.objects.split(',').map(|s| s.trim()).collect(); let result = detect_objects( &req.image_path, &objects, Some(req.threshold), req.model.as_deref(), ).await; match result { Ok(detected) => { if !detected.is_empty() { tracing::info!( "[CLIP_API] Detected {} objects in {}: {}", detected.len(), req.image_path, detected.iter().map(|p| p.label.as_str()).collect::>().join(", ") ); } else { tracing::info!("[CLIP_API] No objects detected in {} (threshold: {:.2})", req.image_path, req.threshold); } Json(DetectResponse { success: true, detected, }).into_response() } Err(e) => { tracing::error!("[CLIP_API] Detection failed: {}", e); Json(ErrorResponse { success: false, error: e.to_string(), }).into_response() } } } async fn batch_classify_endpoint( State(_state): State, Json(req): Json, ) -> Response { let image_paths: Vec<&str> = req.image_paths.split(',').map(|s| s.trim()).collect(); let labels: Vec<&str> = req.labels.split(',').map(|s| s.trim()).collect(); let result = classify_images( &image_paths, &labels, Some(req.top_k), req.model.as_deref(), ).await; match result { Ok(results_vec) => { let results: HashMap> = results_vec .into_iter() .map(|r| (r.image_path, r.predictions)) .collect(); tracing::info!("[CLIP_API] Batch classified {} images", results.len()); Json(BatchClassifyResponse { success: true, results, }).into_response() } Err(e) => { tracing::error!("[CLIP_API] Batch classification failed: {}", e); Json(ErrorResponse { success: false, error: e.to_string(), }).into_response() } } }