use axum::{ extract::State, http::StatusCode, response::Json, routing::post, Router, }; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Mutex; use std::time::Instant; use crate::api::types::AppState; use crate::core::db::schema; use crate::core::llm::function_calling::{self, ChatMessage, LlmResponse, ToolCall, ToolDef}; // ── Conversation Manager ───────────────────────────────────────── struct Conversation { messages: Vec, created_at: Instant, last_active: Instant, } static CONVERSATIONS: Lazy>> = Lazy::new(|| { // Spawn cleanup task std::thread::spawn(|| loop { std::thread::sleep(std::time::Duration::from_secs(60)); let mut map = CONVERSATIONS.lock().unwrap(); let now = Instant::now(); map.retain(|_, conv| now.duration_since(conv.last_active).as_secs() < 1800); }); Mutex::new(HashMap::new()) }); fn get_or_create_conv(conv_id: Option<&str>) -> (String, Vec) { let mut map = CONVERSATIONS.lock().unwrap(); if let Some(cid) = conv_id { if let Some(conv) = map.get_mut(cid) { conv.last_active = Instant::now(); return (cid.to_string(), conv.messages.clone()); } } let id = uuid::Uuid::new_v4().to_string().replace('-', "")[..16].to_string(); map.insert(id.clone(), Conversation { messages: Vec::new(), created_at: Instant::now(), last_active: Instant::now(), }); (id, Vec::new()) } fn save_messages(conv_id: &str, messages: &[ChatMessage]) { if let Some(conv) = CONVERSATIONS.lock().unwrap().get_mut(conv_id) { conv.messages = messages.to_vec(); conv.last_active = Instant::now(); } } // ── Request / Response ─────────────────────────────────────────── #[derive(Debug, Deserialize)] pub struct AgentSearchRequest { pub query: String, pub conversation_id: Option, pub file_uuid: Option, } #[derive(Debug, Serialize)] pub struct AgentSearchResponse { pub success: bool, pub conversation_id: String, pub answer: String, #[serde(skip_serializing_if = "Option::is_none")] pub suggestions: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub sources: Option>, } // ── Tool Definitions ────────────────────────────────────────────── const SYSTEM_PROMPT: &str = r#"你是 Momentry 影片分析助手。回答用戶關於影片內容的問題。 ## 工具使用規則 1. 先確認用戶在問哪部影片 — 使用 find_file 或 list_files 2. 人物問題優先使用 tkg_query 3. 語意/內容問題使用 smart_search 或 universal_search 4. 可以同時呼叫多個工具 ## 引導規則 - 如果用戶沒說片名 → 用 find_file 搜尋,如果名稱不明確就反問 - 反問時提供 suggestions,例如演員名、年代 - **如果影片的 has_data 為 false,代表尚未完成處理,不要推薦用戶使用。引導用戶選擇 has_data=true 的影片** - 不要輸出 JSON,用自然語言回答 - 引用資料時附上具體數字(frame 編號、時間秒數) ## 回答規則 - 回答要簡潔但完整 - 如果找到影片,附上 file_uuid(用戶之後可能需要) - 對於人物問題,說出角色名和演員名"#; fn make_tools(pool: &sqlx::PgPool) -> Vec { vec![ function_calling::make_tool( "find_file", "透過關鍵字搜尋影片(片名、演員、年份)。回傳符合的影片列表。", serde_json::json!({ "query": {"type": "string", "description": "搜尋關鍵字(片名、演員名、年份)"} }), vec!["query"], ), function_calling::make_tool( "list_files", "列出近期註冊的影片。", serde_json::json!({ "limit": {"type": "integer", "description": "回傳筆數上限", "default": 10} }), vec![], ), function_calling::make_tool( "tkg_query", "查詢影片的人物互動、配對、同框資料。query_type 包括:top_identities(人物排名)、first_cooccurrence(第一次同框)、identity_details(人物詳細)、mutual_gaze(互看)、interaction_network(互動網絡)、identity_traces(出場片段)、file_info(影片資訊)。", serde_json::json!({ "file_uuid": {"type": "string", "description": "影片 UUID"}, "query_type": { "type": "string", "enum": ["top_identities", "first_cooccurrence", "identity_details", "mutual_gaze", "interaction_network", "identity_traces", "file_info"], "description": "查詢類型" }, "identity_name": {"type": "string", "description": "人物名稱(配合 identity_details / identity_traces)"}, "identity_b": {"type": "string", "description": "第二人物名稱(配合 first_cooccurrence / mutual_gaze)"}, "limit": {"type": "integer", "default": 5} }), vec!["file_uuid", "query_type"], ), function_calling::make_tool( "smart_search", "語意搜尋 chunk 文字內容。適合需要理解意圖的查詢。", serde_json::json!({ "file_uuid": {"type": "string", "description": "限制搜尋範圍(可選)"}, "query": {"type": "string", "description": "搜尋關鍵字"}, "limit": {"type": "integer", "default": 5} }), vec!["query"], ), function_calling::make_tool( "get_identity_detail", "查詢單一身份的詳細資料(名字、角色、TMDb 資訊)。", serde_json::json!({ "name": {"type": "string", "description": "人物名稱"} }), vec!["name"], ), function_calling::make_tool( "get_file_info", "查詢影片基本資訊(片名、長度、解析度)。", serde_json::json!({ "file_uuid": {"type": "string", "description": "影片 UUID"} }), vec!["file_uuid"], ), function_calling::make_tool( "get_representative_frame", "查詢影片最具代表性的 frame 資訊(frame 編號、時間、人物)。", serde_json::json!({ "file_uuid": {"type": "string", "description": "影片 UUID"} }), vec!["file_uuid"], ), ] } // ── Tool Executors ─────────────────────────────────────────────── async fn exec_find_file(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { let query = args.get("query").and_then(|v| v.as_str()).unwrap_or(""); let videos = schema::table_name("videos"); let fd_table = schema::table_name("face_detections"); let like = format!("%{}%", query); let rows: Vec<(String, String, bool)> = sqlx::query_as(&format!( "SELECT v.file_uuid::text, v.file_name, \ (SELECT COUNT(*) FROM {} fd WHERE fd.file_uuid = v.file_uuid) > 0 AS has_data \ FROM {} v WHERE v.file_name ILIKE $1 \ ORDER BY v.created_at DESC LIMIT 10", fd_table, videos )) .bind(&like) .fetch_all(pool) .await .map_err(|e| e.to_string())?; if rows.is_empty() { return Ok(serde_json::json!({"found": false, "message": "No files match the query. Try different keywords."}).to_string()); } let files: Vec = rows.into_iter().map(|(u, n, hd)| { serde_json::json!({"file_uuid": u, "file_name": n, "has_data": hd}) }).collect(); Ok(serde_json::json!({"found": true, "files": files}).to_string()) } async fn exec_list_files(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10); let videos = schema::table_name("videos"); let fd_table = schema::table_name("face_detections"); let rows: Vec<(String, String, bool)> = sqlx::query_as(&format!( "SELECT v.file_uuid::text, v.file_name, \ (SELECT COUNT(*) FROM {} fd WHERE fd.file_uuid = v.file_uuid) > 0 AS has_data \ FROM {} v ORDER BY v.created_at DESC LIMIT $1", fd_table, videos )) .bind(limit) .fetch_all(pool) .await .map_err(|e| e.to_string())?; let files: Vec = rows.into_iter().map(|(u, n, hd)| { serde_json::json!({"file_uuid": u, "file_name": n, "has_data": hd}) }).collect(); Ok(serde_json::json!({"files": files}).to_string()) } async fn exec_tkg_query(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or(""); let query_type = args.get("query_type").and_then(|v| v.as_str()).unwrap_or(""); let identity_name = args.get("identity_name").and_then(|v| v.as_str()); let identity_b = args.get("identity_b").and_then(|v| v.as_str()); let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5); let id_table = schema::table_name("identities"); let fd_table = schema::table_name("face_detections"); let videos = schema::table_name("videos"); let nodes = schema::table_name("tkg_nodes"); let edges = schema::table_name("tkg_edges"); match query_type { "top_identities" => { let rows: Vec<(String, String, i64)> = sqlx::query_as(&format!( "SELECT i.uuid::text, i.name, COUNT(fd.id)::bigint AS face_count \ FROM {} fd JOIN {} i ON i.id = fd.identity_id \ WHERE fd.file_uuid = $1 AND fd.identity_id IS NOT NULL AND i.source = 'tmdb' \ GROUP BY i.uuid, i.name ORDER BY face_count DESC LIMIT $2", fd_table, id_table )) .bind(file_uuid).bind(limit) .fetch_all(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"identities": rows}).to_string()) } "first_cooccurrence" => { let name_a = identity_name.unwrap_or(""); let name_b = identity_b.unwrap_or(""); let row: Option<(i64, f64)> = sqlx::query_as(&format!( "SELECT MIN(fd_a.frame_number)::bigint, \ ROUND(MIN(fd_a.frame_number)::numeric / GREATEST(MAX(v.fps)::numeric, 25.0), 2)::float8 \ FROM {} fd_a JOIN {} fd_b ON fd_a.frame_number = fd_b.frame_number \ JOIN {} v ON v.file_uuid = $1 \ WHERE fd_a.file_uuid = $1 \ AND fd_a.identity_id = (SELECT id FROM {} WHERE name ILIKE $2 LIMIT 1) \ AND fd_b.identity_id = (SELECT id FROM {} WHERE name ILIKE $3 LIMIT 1)", fd_table, fd_table, videos, id_table, id_table )) .bind(file_uuid).bind(name_a).bind(name_b) .fetch_optional(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"first_cooccurrence": row.map(|(f, t)| serde_json::json!({"frame": f, "timestamp_secs": t}))}).to_string()) } "identity_details" => { let name = identity_name.unwrap_or(""); let row: Option<(String, String, Option, i64)> = sqlx::query_as(&format!( "SELECT i.uuid::text, i.name, i.tmdb_id, \ (SELECT COUNT(*) FROM {} fd WHERE fd.identity_id = i.id AND fd.file_uuid = $1)::bigint \ FROM {} i WHERE i.name ILIKE $2 LIMIT 1", fd_table, id_table )) .bind(file_uuid).bind(name) .fetch_optional(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"identity": row.map(|(u, n, tid, fc)| serde_json::json!({"uuid": u, "name": n, "tmdb_id": tid, "face_count": fc}))}).to_string()) } "mutual_gaze" => { let name_a = identity_name.unwrap_or(""); let name_b = identity_b.unwrap_or(""); let row: Option<(i64, i64, f64, f64)> = sqlx::query_as(&format!( "SELECT (e.properties->>'first_frame')::bigint, \ (e.properties->>'gaze_frame_count')::int::bigint, \ (e.properties->>'yaw_a_avg')::float8, \ (e.properties->>'yaw_b_avg')::float8 \ FROM {} e \ JOIN {} a ON a.id = e.source_node_id \ JOIN {} b ON b.id = e.target_node_id \ JOIN {} fd_a ON fd_a.file_uuid = $1 AND fd_a.trace_id = REPLACE(a.external_id, 'trace_', '')::int \ JOIN {} fd_b ON fd_b.file_uuid = $1 AND fd_b.trace_id = REPLACE(b.external_id, 'trace_', '')::int \ JOIN {} ia ON ia.id = fd_a.identity_id \ JOIN {} ib ON ib.id = fd_b.identity_id \ WHERE e.file_uuid = $1 AND ia.name ILIKE $2 AND ib.name ILIKE $3 \ AND e.properties->>'mutual_gaze' = 'true' LIMIT 1", edges, nodes, nodes, fd_table, fd_table, id_table, id_table )) .bind(file_uuid).bind(name_a).bind(name_b) .fetch_optional(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"mutual_gaze": row.map(|(f, gc, ya, yb)| serde_json::json!({"first_frame": f, "gaze_frame_count": gc, "yaw_a": ya, "yaw_b": yb}))}).to_string()) } "interaction_network" => { let rows: Vec<(String, String, i64)> = sqlx::query_as(&format!( "SELECT ia.name, ib.name, COUNT(*)::bigint \ FROM {} e \ JOIN {} a ON a.id = e.source_node_id \ JOIN {} b ON b.id = e.target_node_id \ JOIN {} fd_a ON fd_a.trace_id = REPLACE(a.external_id, 'trace_', '')::int AND fd_a.file_uuid = $1 \ JOIN {} fd_b ON fd_b.trace_id = REPLACE(b.external_id, 'trace_', '')::int AND fd_b.file_uuid = $1 \ JOIN {} ia ON ia.id = fd_a.identity_id \ JOIN {} ib ON ib.id = fd_b.identity_id \ WHERE e.file_uuid = $1 AND e.edge_type = 'CO_OCCURS_WITH' \ AND ia.name != ib.name AND ia.source = 'tmdb' AND ib.source = 'tmdb' \ GROUP BY ia.name, ib.name \ ORDER BY COUNT(*) DESC LIMIT $2", edges, nodes, nodes, fd_table, fd_table, id_table, id_table )) .bind(file_uuid).bind(limit) .fetch_all(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"interaction_network": rows}).to_string()) } "identity_traces" => { let name = identity_name.unwrap_or(""); let rows: Vec<(i32, i64, i32, i32)> = sqlx::query_as(&format!( "SELECT fd.trace_id, COUNT(*)::bigint, MIN(fd.frame_number)::int, MAX(fd.frame_number)::int \ FROM {} fd JOIN {} i ON i.id = fd.identity_id \ WHERE fd.file_uuid = $1 AND i.name ILIKE $2 \ GROUP BY fd.trace_id ORDER BY COUNT(*) DESC LIMIT $3", fd_table, id_table )) .bind(file_uuid).bind(name).bind(limit) .fetch_all(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"traces": rows}).to_string()) } "file_info" => { let row: Option<(String, f64, i32, i32, f64)> = sqlx::query_as(&format!( "SELECT file_name, duration, width, height, fps FROM {} WHERE file_uuid = $1", videos )) .bind(file_uuid) .fetch_optional(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"file_info": row.map(|(n, d, w, h, f)| serde_json::json!({"file_name": n, "duration_sec": d, "width": w, "height": h, "fps": f}))}).to_string()) } _ => Ok(serde_json::json!({"error": format!("Unknown query_type: {}", query_type)}).to_string()), } } async fn exec_smart_search(_pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { let query = args.get("query").and_then(|v| v.as_str()).unwrap_or(""); let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()); let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5); let chunk_table = schema::table_name("chunk"); let mut sql = format!( "SELECT chunk_id, text_content, start_frame, end_frame, chunk_type \ FROM {} WHERE text_content ILIKE $1", chunk_table ); if file_uuid.is_some() { sql.push_str(" AND file_uuid = $2"); } sql.push_str(&format!(" ORDER BY start_frame LIMIT {}", limit)); if let Some(fuid) = file_uuid { let like = format!("%{}%", query); let rows: Vec<(String, Option, i64, i64, String)> = sqlx::query_as(&sql) .bind(&like).bind(fuid) .fetch_all(_pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"results": rows}).to_string()) } else { let like = format!("%{}%", query); let rows: Vec<(String, Option, i64, i64, String)> = sqlx::query_as(&sql) .bind(&like) .fetch_all(_pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"results": rows}).to_string()) } } async fn exec_get_identity_detail(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { let name = args.get("name").and_then(|v| v.as_str()).unwrap_or(""); let id_table = schema::table_name("identities"); let row: Option<(String, String, Option, Option, Option)> = sqlx::query_as(&format!( "SELECT uuid::text, name, source, tmdb_id, metadata->>'tmdb_character' FROM {} WHERE name ILIKE $1 LIMIT 1", id_table )) .bind(name) .fetch_optional(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"identity": row.map(|(u, n, s, t, c)| serde_json::json!({"uuid": u, "name": n, "source": s, "tmdb_id": t, "character": c}))}).to_string()) } async fn exec_get_file_info(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or(""); let videos = schema::table_name("videos"); let row: Option<(String, f64, i32, i32, f64)> = sqlx::query_as(&format!( "SELECT file_name, duration, width, height, fps FROM {} WHERE file_uuid = $1", videos )) .bind(file_uuid) .fetch_optional(pool) .await.map_err(|e| e.to_string())?; Ok(serde_json::json!({"file_info": row.map(|(n, d, w, h, f)| serde_json::json!({"file_name": n, "duration_sec": d, "width": w, "height": h, "fps": f}))}).to_string()) } async fn exec_get_representative_frame(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or(""); match crate::core::processor::tkg::query_auto_representative_frame(pool, file_uuid).await { Ok(r) => Ok(serde_json::json!({ "frame_number": r.frame_number, "face_quality": r.face_quality, "main_identities": r.main_identities, "traces": r.traces, }).to_string()), Err(e) => Ok(serde_json::json!({"error": e.to_string()}).to_string()), } } // ── Tool Router ─────────────────────────────────────────────────── async fn execute_tool(pool: &sqlx::PgPool, tool_call: &ToolCall) -> (String, String, String) { let name = tool_call.function.name.clone(); let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap_or_default(); let result = match name.as_str() { "find_file" => exec_find_file(pool, &args).await, "list_files" => exec_list_files(pool, &args).await, "tkg_query" => exec_tkg_query(pool, &args).await, "smart_search" => exec_smart_search(pool, &args).await, "get_identity_detail" => exec_get_identity_detail(pool, &args).await, "get_file_info" => exec_get_file_info(pool, &args).await, "get_representative_frame" => exec_get_representative_frame(pool, &args).await, _ => Err(format!("Unknown tool: {}", name)), }; let content = match result { Ok(s) => s, Err(e) => serde_json::json!({"error": e}).to_string(), }; let tool_call_id = tool_call.id.clone().unwrap_or_default(); (tool_call_id, name, content) } // ── Tool Loop ───────────────────────────────────────────────────── const MAX_ROUNDS: u32 = 5; async fn run_tool_loop( pool: &sqlx::PgPool, system_prompt: &str, user_query: &str, history: Vec, ) -> (String, Vec) { let mut messages = function_calling::build_conversation(system_prompt, user_query, history); let mut sources = Vec::new(); for round in 0..MAX_ROUNDS { let tools = Some(make_tools(pool)); match function_calling::call_llm(messages.clone(), tools, 2048, 120).await { Ok(LlmResponse::Text(text)) => { return (text, sources); } Ok(LlmResponse::ToolCalls(calls)) => { // Push assistant message with tool_calls so Gemma4 remembers messages.push(ChatMessage { role: "assistant".to_string(), content: None, tool_calls: Some(calls.clone()), tool_call_id: None, name: None, }); for call in &calls { let (tool_call_id, name, content) = execute_tool(pool, call).await; sources.push(serde_json::json!({"tool": name, "result": content})); messages.push(function_calling::make_tool_result(&tool_call_id, &name, &content)); } } Err(e) => { return (format!("系統錯誤:{}", e), sources); } } } ("已達到最大查詢次數,請縮小問題範圍後重新詢問。".to_string(), sources) } // ── Handler ─────────────────────────────────────────────────────── async fn agent_search( State(state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { let (conv_id, history) = get_or_create_conv(req.conversation_id.as_deref()); let (answer, sources) = run_tool_loop( state.db.pool(), SYSTEM_PROMPT, &req.query, history, ) .await; // Save updated messages for conversation continuation let new_msgs = function_calling::build_conversation(SYSTEM_PROMPT, &req.query, vec![]); save_messages(&conv_id, &new_msgs); let needs_input = answer.contains('?') || answer.contains('?'); let suggestions = if needs_input { Some(vec!["演員名".to_string(), "電影片名".to_string(), "年份".to_string()]) } else { None }; Ok(Json(AgentSearchResponse { success: true, conversation_id: conv_id, answer, suggestions, sources: Some(sources), })) } // ── Routes ───────────────────────────────────────────────────────── pub fn agent_search_routes() -> Router { Router::new() .route("/api/v1/agents/search", post(agent_search)) }