Files
momentry_core/src/api/agent_search.rs

532 lines
24 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
extract::State,
http::StatusCode,
response::Json,
routing::post,
Router,
};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::Instant;
use crate::api::types::AppState;
use crate::core::db::schema;
use crate::core::llm::function_calling::{self, ChatMessage, LlmResponse, ToolCall, ToolDef};
// ── Conversation Manager ─────────────────────────────────────────
/// In-memory state for a single chat conversation held in `CONVERSATIONS`.
struct Conversation {
/// Messages stored for this conversation; handed back as history on lookup.
messages: Vec<ChatMessage>,
/// Instant at which this entry was inserted into the map.
created_at: Instant,
/// Instant of the most recent lookup or save; the cleanup thread evicts
/// entries based on this value.
last_active: Instant,
}
/// Process-wide conversation store, keyed by conversation id.
///
/// The first access initialises the map and also starts a detached background
/// thread that wakes once a minute and drops every conversation that has been
/// idle for 30 minutes or more.
static CONVERSATIONS: Lazy<Mutex<HashMap<String, Conversation>>> = Lazy::new(|| {
    std::thread::spawn(|| {
        let sweep_every = std::time::Duration::from_secs(60);
        let max_idle = std::time::Duration::from_secs(1800);
        loop {
            // Sleeping first means the static is fully initialised before
            // this thread touches it for the first time.
            std::thread::sleep(sweep_every);
            let mut store = CONVERSATIONS.lock().unwrap();
            let swept_at = Instant::now();
            store.retain(|_, conv| swept_at.duration_since(conv.last_active) < max_idle);
        }
    });
    Mutex::new(HashMap::new())
});
/// Looks up a conversation by id, or registers a fresh one.
///
/// When `conv_id` names a stored conversation, its `last_active` stamp is
/// refreshed and a copy of its messages is returned together with the same id.
/// In every other case (no id supplied, or id not found) a new conversation is
/// inserted under a newly generated 16-character id and an empty history is
/// returned.
///
/// # Panics
/// Panics if the `CONVERSATIONS` mutex is poisoned.
fn get_or_create_conv(conv_id: Option<&str>) -> (String, Vec<ChatMessage>) {
    let mut store = CONVERSATIONS.lock().unwrap();

    let existing = match conv_id {
        Some(cid) => store.get_mut(cid).map(|conv| (cid, conv)),
        None => None,
    };
    if let Some((cid, conv)) = existing {
        conv.last_active = Instant::now();
        return (cid.to_owned(), conv.messages.clone());
    }

    // First 16 characters of a v4 UUID with the hyphens removed.
    let fresh_id: String = uuid::Uuid::new_v4()
        .to_string()
        .chars()
        .filter(|ch| *ch != '-')
        .take(16)
        .collect();
    let entry = Conversation {
        messages: Vec::new(),
        created_at: Instant::now(),
        last_active: Instant::now(),
    };
    store.insert(fresh_id.clone(), entry);
    (fresh_id, Vec::new())
}
/// Replaces the stored messages of conversation `conv_id` with a copy of
/// `messages` and refreshes its `last_active` stamp.
///
/// Does nothing when no conversation with that id is stored.
///
/// # Panics
/// Panics if the `CONVERSATIONS` mutex is poisoned.
fn save_messages(conv_id: &str, messages: &[ChatMessage]) {
    let mut store = CONVERSATIONS.lock().unwrap();
    match store.get_mut(conv_id) {
        Some(entry) => {
            entry.messages = messages.to_vec();
            entry.last_active = Instant::now();
        }
        None => {}
    }
}
// ── Request / Response ───────────────────────────────────────────
/// JSON request body accepted by the agent search endpoint.
#[derive(Debug, Deserialize)]
pub struct AgentSearchRequest {
/// The user's question, forwarded to the tool loop as the user message.
pub query: String,
/// Id of a conversation to continue; when absent or unknown a new
/// conversation is created.
pub conversation_id: Option<String>,
/// Optional file UUID. NOTE(review): the handler in this file does not read
/// this field — confirm whether it is meant to scope the search.
pub file_uuid: Option<String>,
}
/// JSON response body produced by the agent search endpoint.
#[derive(Debug, Serialize)]
pub struct AgentSearchResponse {
/// Set to `true` by the handler in this file.
pub success: bool,
/// Id under which this conversation is stored; send it back to continue.
pub conversation_id: String,
/// Final text returned by the tool loop.
pub answer: String,
/// Follow-up hints for the user; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub suggestions: Option<Vec<String>>,
/// One entry per executed tool call (tool name and raw result); omitted
/// from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub sources: Option<Vec<serde_json::Value>>,
}
// ── Tool Definitions ──────────────────────────────────────────────
// System prompt (Traditional Chinese) passed to the LLM by the handler. It
// covers tool-selection rules, when to ask the user a clarifying question,
// and answer formatting.
// NOTE(review): several clauses contain a closing parenthesis with no opening
// one and run together without separators; fullwidth punctuation may have been
// lost from this literal — confirm against the canonical text before editing.
const SYSTEM_PROMPT: &str = r#"你是 Momentry 影片分析助手。回答用戶關於影片內容的問題。
## 工具使用規則
1. 先確認用戶在問哪部影片 — 使用 find_file 或 list_files
2. 人物問題優先使用 tkg_query
3. 語意/內容問題使用 smart_search 或 universal_search
4. 可以同時呼叫多個工具
## 引導規則
- 如果用戶沒說片名 → 用 find_file 搜尋,如果名稱不明確就反問
- 反問時提供 suggestions例如演員名、年代
- **如果影片的 has_data 為 false代表尚未完成處理不要推薦用戶使用。引導用戶選擇 has_data=true 的影片**
- 不要輸出 JSON用自然語言回答
- 引用資料時附上具體數字frame 編號、時間秒數)
## 回答規則
- 回答要簡潔但完整
- 如果找到影片,附上 file_uuid用戶之後可能需要
- 對於人物問題,說出角色名和演員名"#;
/// Builds the list of tool definitions offered to the LLM on every round.
///
/// Each definition is created through `function_calling::make_tool` from a
/// name, a description, a JSON object describing the parameters, and the list
/// of required parameter names. The order of the returned vector is
/// `find_file`, `list_files`, `tkg_query`, `smart_search`,
/// `get_identity_detail`, `get_file_info`, `get_representative_frame`.
///
/// `pool` is accepted but not used by this function.
fn make_tools(pool: &sqlx::PgPool) -> Vec<ToolDef> {
    // Keyword lookup over registered videos.
    let find_file = function_calling::make_tool(
        "find_file",
        "透過關鍵字搜尋影片(片名、演員、年份)。回傳符合的影片列表。",
        serde_json::json!({
            "query": {"type": "string", "description": "搜尋關鍵字(片名、演員名、年份)"}
        }),
        vec!["query"],
    );

    // Most recently registered videos.
    let list_files = function_calling::make_tool(
        "list_files",
        "列出近期註冊的影片。",
        serde_json::json!({
            "limit": {"type": "integer", "description": "回傳筆數上限", "default": 10}
        }),
        vec![],
    );

    // Identity / interaction queries, selected by `query_type`.
    let tkg_query = function_calling::make_tool(
        "tkg_query",
        "查詢影片的人物互動、配對、同框資料。query_type 包括top_identities人物排名、first_cooccurrence第一次同框、identity_details人物詳細、mutual_gaze互看、interaction_network互動網絡、identity_traces出場片段、file_info影片資訊",
        serde_json::json!({
            "file_uuid": {"type": "string", "description": "影片 UUID"},
            "query_type": {
                "type": "string",
                "enum": ["top_identities", "first_cooccurrence", "identity_details", "mutual_gaze", "interaction_network", "identity_traces", "file_info"],
                "description": "查詢類型"
            },
            "identity_name": {"type": "string", "description": "人物名稱(配合 identity_details / identity_traces"},
            "identity_b": {"type": "string", "description": "第二人物名稱(配合 first_cooccurrence / mutual_gaze"},
            "limit": {"type": "integer", "default": 5}
        }),
        vec!["file_uuid", "query_type"],
    );

    // Text search over chunk contents.
    let smart_search = function_calling::make_tool(
        "smart_search",
        "語意搜尋 chunk 文字內容。適合需要理解意圖的查詢。",
        serde_json::json!({
            "file_uuid": {"type": "string", "description": "限制搜尋範圍(可選)"},
            "query": {"type": "string", "description": "搜尋關鍵字"},
            "limit": {"type": "integer", "default": 5}
        }),
        vec!["query"],
    );

    // Single identity lookup by name.
    let get_identity_detail = function_calling::make_tool(
        "get_identity_detail",
        "查詢單一身份的詳細資料名字、角色、TMDb 資訊)。",
        serde_json::json!({
            "name": {"type": "string", "description": "人物名稱"}
        }),
        vec!["name"],
    );

    // Basic video metadata.
    let get_file_info = function_calling::make_tool(
        "get_file_info",
        "查詢影片基本資訊(片名、長度、解析度)。",
        serde_json::json!({
            "file_uuid": {"type": "string", "description": "影片 UUID"}
        }),
        vec!["file_uuid"],
    );

    // Representative frame of a video.
    let get_representative_frame = function_calling::make_tool(
        "get_representative_frame",
        "查詢影片最具代表性的 frame 資訊frame 編號、時間、人物)。",
        serde_json::json!({
            "file_uuid": {"type": "string", "description": "影片 UUID"}
        }),
        vec!["file_uuid"],
    );

    vec![
        find_file,
        list_files,
        tkg_query,
        smart_search,
        get_identity_detail,
        get_file_info,
        get_representative_frame,
    ]
}
// ── Tool Executors ───────────────────────────────────────────────
/// Executes the `find_file` tool: case-insensitive substring match of the
/// `query` argument against video file names.
///
/// Returns a JSON string. With no matches it is `{"found": false, "message": …}`;
/// otherwise `{"found": true, "files": [...]}` with up to 10 entries, newest
/// first, each carrying `file_uuid`, `file_name` and `has_data` (whether any
/// face-detection row exists for the file).
///
/// A missing or non-string `query` is treated as the empty string.
///
/// # Errors
/// Returns the database error rendered as a string.
async fn exec_find_file(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
    let keyword = args
        .get("query")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default();
    let videos = schema::table_name("videos");
    let fd_table = schema::table_name("face_detections");
    let pattern = format!("%{}%", keyword);

    let sql = format!(
        "SELECT v.file_uuid::text, v.file_name, \
         (SELECT COUNT(*) FROM {} fd WHERE fd.file_uuid = v.file_uuid) > 0 AS has_data \
         FROM {} v WHERE v.file_name ILIKE $1 \
         ORDER BY v.created_at DESC LIMIT 10",
        fd_table, videos
    );
    let matches: Vec<(String, String, bool)> = sqlx::query_as(&sql)
        .bind(&pattern)
        .fetch_all(pool)
        .await
        .map_err(|err| err.to_string())?;

    if matches.is_empty() {
        let miss = serde_json::json!({"found": false, "message": "No files match the query. Try different keywords."});
        return Ok(miss.to_string());
    }

    let mut files: Vec<serde_json::Value> = Vec::with_capacity(matches.len());
    for (uuid, file_name, has_data) in matches {
        files.push(serde_json::json!({"file_uuid": uuid, "file_name": file_name, "has_data": has_data}));
    }
    Ok(serde_json::json!({"found": true, "files": files}).to_string())
}
/// Executes the `list_files` tool: lists the most recently created videos.
///
/// Reads an integer `limit` argument (default 10) and returns the JSON string
/// `{"files": [...]}`, newest first, where every entry carries `file_uuid`,
/// `file_name` and `has_data` (whether any face-detection row exists for the
/// file).
///
/// # Errors
/// Returns the database error rendered as a string.
async fn exec_list_files(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
    let max_rows = args
        .get("limit")
        .and_then(serde_json::Value::as_i64)
        .unwrap_or(10);
    let videos = schema::table_name("videos");
    let fd_table = schema::table_name("face_detections");

    let sql = format!(
        "SELECT v.file_uuid::text, v.file_name, \
         (SELECT COUNT(*) FROM {} fd WHERE fd.file_uuid = v.file_uuid) > 0 AS has_data \
         FROM {} v ORDER BY v.created_at DESC LIMIT $1",
        fd_table, videos
    );
    let listed: Vec<(String, String, bool)> = sqlx::query_as(&sql)
        .bind(max_rows)
        .fetch_all(pool)
        .await
        .map_err(|err| err.to_string())?;

    let mut files: Vec<serde_json::Value> = Vec::with_capacity(listed.len());
    for (uuid, file_name, has_data) in listed {
        files.push(serde_json::json!({"file_uuid": uuid, "file_name": file_name, "has_data": has_data}));
    }
    Ok(serde_json::json!({"files": files}).to_string())
}
/// Executes the `tkg_query` tool, dispatching on the `query_type` argument.
///
/// Arguments read from `args`:
/// - `file_uuid` (string, defaults to `""`): video the query is scoped to.
/// - `query_type` (string): one of `top_identities`, `first_cooccurrence`,
///   `identity_details`, `mutual_gaze`, `interaction_network`,
///   `identity_traces`, `file_info`.
/// - `identity_name`, `identity_b` (optional strings): identity names matched
///   with `ILIKE`.
/// - `limit` (integer, default 5): row cap for the list-style queries.
///
/// Returns a JSON string whose top-level key depends on the query type. An
/// unrecognised `query_type` yields `{"error": "Unknown query_type: …"}` as a
/// successful result so the message reaches the model.
///
/// # Errors
/// Returns the database error rendered as a string.
async fn exec_tkg_query(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
    let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or("");
    let query_type = args.get("query_type").and_then(|v| v.as_str()).unwrap_or("");
    let identity_name = args.get("identity_name").and_then(|v| v.as_str());
    let identity_b = args.get("identity_b").and_then(|v| v.as_str());
    let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5);
    let id_table = schema::table_name("identities");
    let fd_table = schema::table_name("face_detections");
    let videos = schema::table_name("videos");
    let nodes = schema::table_name("tkg_nodes");
    let edges = schema::table_name("tkg_edges");

    match query_type {
        // TMDb-sourced identities ranked by number of detected faces.
        "top_identities" => {
            let rows: Vec<(String, String, i64)> = sqlx::query_as(&format!(
                "SELECT i.uuid::text, i.name, COUNT(fd.id)::bigint AS face_count \
                 FROM {} fd JOIN {} i ON i.id = fd.identity_id \
                 WHERE fd.file_uuid = $1 AND fd.identity_id IS NOT NULL AND i.source = 'tmdb' \
                 GROUP BY i.uuid, i.name ORDER BY face_count DESC LIMIT $2",
                fd_table, id_table
            ))
            .bind(file_uuid)
            .bind(limit)
            .fetch_all(pool)
            .await
            .map_err(|e| e.to_string())?;
            Ok(serde_json::json!({"identities": rows}).to_string())
        }
        // Earliest frame in which both named identities are detected.
        "first_cooccurrence" => {
            let name_a = identity_name.unwrap_or("");
            let name_b = identity_b.unwrap_or("");
            // The aggregate always produces exactly one row; both columns are
            // NULL when the two identities never share a frame, so they are
            // decoded as `Option` rather than failing the whole query.
            let row: Option<(Option<i64>, Option<f64>)> = sqlx::query_as(&format!(
                "SELECT MIN(fd_a.frame_number)::bigint, \
                 ROUND(MIN(fd_a.frame_number)::numeric / COALESCE(NULLIF(MAX(v.fps)::numeric, 0), 25.0), 2)::float8 \
                 FROM {} fd_a JOIN {} fd_b ON fd_b.file_uuid = fd_a.file_uuid AND fd_a.frame_number = fd_b.frame_number \
                 JOIN {} v ON v.file_uuid = $1 \
                 WHERE fd_a.file_uuid = $1 \
                 AND fd_a.identity_id = (SELECT id FROM {} WHERE name ILIKE $2 LIMIT 1) \
                 AND fd_b.identity_id = (SELECT id FROM {} WHERE name ILIKE $3 LIMIT 1)",
                fd_table, fd_table, videos, id_table, id_table
            ))
            .bind(file_uuid)
            .bind(name_a)
            .bind(name_b)
            .fetch_optional(pool)
            .await
            .map_err(|e| e.to_string())?;
            // Notes on the SQL above:
            // - `fd_b` is restricted to the same file as `fd_a`; without that,
            //   equal frame numbers from other videos would count as a match.
            // - The real fps is used for the timestamp; 25 is only a fallback
            //   for a missing or zero fps.
            let first = match row {
                Some((Some(frame), Some(secs))) => {
                    Some(serde_json::json!({"frame": frame, "timestamp_secs": secs}))
                }
                _ => None,
            };
            Ok(serde_json::json!({"first_cooccurrence": first}).to_string())
        }
        // One identity matched by name, with its face count in this file.
        "identity_details" => {
            let name = identity_name.unwrap_or("");
            let row: Option<(String, String, Option<i32>, i64)> = sqlx::query_as(&format!(
                "SELECT i.uuid::text, i.name, i.tmdb_id, \
                 (SELECT COUNT(*) FROM {} fd WHERE fd.identity_id = i.id AND fd.file_uuid = $1)::bigint \
                 FROM {} i WHERE i.name ILIKE $2 LIMIT 1",
                fd_table, id_table
            ))
            .bind(file_uuid)
            .bind(name)
            .fetch_optional(pool)
            .await
            .map_err(|e| e.to_string())?;
            Ok(serde_json::json!({"identity": row.map(|(u, n, tid, fc)| serde_json::json!({"uuid": u, "name": n, "tmdb_id": tid, "face_count": fc}))}).to_string())
        }
        // First edge flagged `mutual_gaze` between traces of the two identities.
        "mutual_gaze" => {
            let name_a = identity_name.unwrap_or("");
            let name_b = identity_b.unwrap_or("");
            let row: Option<(i64, i64, f64, f64)> = sqlx::query_as(&format!(
                "SELECT (e.properties->>'first_frame')::bigint, \
                 (e.properties->>'gaze_frame_count')::int::bigint, \
                 (e.properties->>'yaw_a_avg')::float8, \
                 (e.properties->>'yaw_b_avg')::float8 \
                 FROM {} e \
                 JOIN {} a ON a.id = e.source_node_id \
                 JOIN {} b ON b.id = e.target_node_id \
                 JOIN {} fd_a ON fd_a.file_uuid = $1 AND fd_a.trace_id = REPLACE(a.external_id, 'trace_', '')::int \
                 JOIN {} fd_b ON fd_b.file_uuid = $1 AND fd_b.trace_id = REPLACE(b.external_id, 'trace_', '')::int \
                 JOIN {} ia ON ia.id = fd_a.identity_id \
                 JOIN {} ib ON ib.id = fd_b.identity_id \
                 WHERE e.file_uuid = $1 AND ia.name ILIKE $2 AND ib.name ILIKE $3 \
                 AND e.properties->>'mutual_gaze' = 'true' LIMIT 1",
                edges, nodes, nodes, fd_table, fd_table, id_table, id_table
            ))
            .bind(file_uuid)
            .bind(name_a)
            .bind(name_b)
            .fetch_optional(pool)
            .await
            .map_err(|e| e.to_string())?;
            Ok(serde_json::json!({"mutual_gaze": row.map(|(f, gc, ya, yb)| serde_json::json!({"first_frame": f, "gaze_frame_count": gc, "yaw_a": ya, "yaw_b": yb}))}).to_string())
        }
        // Pairs of distinct TMDb identities ranked by joined co-occurrence rows.
        "interaction_network" => {
            let rows: Vec<(String, String, i64)> = sqlx::query_as(&format!(
                "SELECT ia.name, ib.name, COUNT(*)::bigint \
                 FROM {} e \
                 JOIN {} a ON a.id = e.source_node_id \
                 JOIN {} b ON b.id = e.target_node_id \
                 JOIN {} fd_a ON fd_a.trace_id = REPLACE(a.external_id, 'trace_', '')::int AND fd_a.file_uuid = $1 \
                 JOIN {} fd_b ON fd_b.trace_id = REPLACE(b.external_id, 'trace_', '')::int AND fd_b.file_uuid = $1 \
                 JOIN {} ia ON ia.id = fd_a.identity_id \
                 JOIN {} ib ON ib.id = fd_b.identity_id \
                 WHERE e.file_uuid = $1 AND e.edge_type = 'CO_OCCURS_WITH' \
                 AND ia.name != ib.name AND ia.source = 'tmdb' AND ib.source = 'tmdb' \
                 GROUP BY ia.name, ib.name \
                 ORDER BY COUNT(*) DESC LIMIT $2",
                edges, nodes, nodes, fd_table, fd_table, id_table, id_table
            ))
            .bind(file_uuid)
            .bind(limit)
            .fetch_all(pool)
            .await
            .map_err(|e| e.to_string())?;
            Ok(serde_json::json!({"interaction_network": rows}).to_string())
        }
        // Traces of one identity: (trace id, detections, first frame, last frame).
        "identity_traces" => {
            let name = identity_name.unwrap_or("");
            let rows: Vec<(i32, i64, i32, i32)> = sqlx::query_as(&format!(
                "SELECT fd.trace_id, COUNT(*)::bigint, MIN(fd.frame_number)::int, MAX(fd.frame_number)::int \
                 FROM {} fd JOIN {} i ON i.id = fd.identity_id \
                 WHERE fd.file_uuid = $1 AND i.name ILIKE $2 \
                 GROUP BY fd.trace_id ORDER BY COUNT(*) DESC LIMIT $3",
                fd_table, id_table
            ))
            .bind(file_uuid)
            .bind(name)
            .bind(limit)
            .fetch_all(pool)
            .await
            .map_err(|e| e.to_string())?;
            Ok(serde_json::json!({"traces": rows}).to_string())
        }
        // Basic metadata of the video row.
        "file_info" => {
            let row: Option<(String, f64, i32, i32, f64)> = sqlx::query_as(&format!(
                "SELECT file_name, duration, width, height, fps FROM {} WHERE file_uuid = $1",
                videos
            ))
            .bind(file_uuid)
            .fetch_optional(pool)
            .await
            .map_err(|e| e.to_string())?;
            Ok(serde_json::json!({"file_info": row.map(|(n, d, w, h, f)| serde_json::json!({"file_name": n, "duration_sec": d, "width": w, "height": h, "fps": f}))}).to_string())
        }
        _ => Ok(serde_json::json!({"error": format!("Unknown query_type: {}", query_type)}).to_string()),
    }
}
/// Executes the `smart_search` tool: case-insensitive substring search over
/// chunk text.
///
/// Arguments read from `args`:
/// - `query` (string, defaults to `""`): text to look for.
/// - `file_uuid` (optional string): restricts the search to one file.
/// - `limit` (integer, default 5): maximum number of rows.
///
/// Returns the JSON string `{"results": [...]}` where each row is
/// `[chunk_id, text_content, start_frame, end_frame, chunk_type]`, ordered by
/// `start_frame`.
///
/// # Errors
/// Returns the database error rendered as a string.
async fn exec_smart_search(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
    let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
    let file_uuid = args.get("file_uuid").and_then(|v| v.as_str());
    let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5);
    let chunk_table = schema::table_name("chunk");
    let like = format!("%{}%", query);

    // `$1` = pattern and `$2` = limit are always bound. `$3` is only present
    // when the search is scoped to a file, which keeps the bind order fixed.
    let mut sql = format!(
        "SELECT chunk_id, text_content, start_frame, end_frame, chunk_type \
         FROM {} WHERE text_content ILIKE $1",
        chunk_table
    );
    if file_uuid.is_some() {
        sql.push_str(" AND file_uuid = $3");
    }
    // The limit is bound as a parameter, like the other executors do,
    // instead of being formatted into the SQL text.
    sql.push_str(" ORDER BY start_frame LIMIT $2");

    let mut statement = sqlx::query_as::<_, (String, Option<String>, i64, i64, String)>(&sql)
        .bind(&like)
        .bind(limit);
    if let Some(fuid) = file_uuid {
        statement = statement.bind(fuid);
    }

    let rows = statement
        .fetch_all(pool)
        .await
        .map_err(|e| e.to_string())?;
    Ok(serde_json::json!({"results": rows}).to_string())
}
/// Executes the `get_identity_detail` tool: looks up one identity whose name
/// matches the `name` argument (`ILIKE`, first row only).
///
/// Returns the JSON string `{"identity": {...}}` with `uuid`, `name`,
/// `source`, `tmdb_id` and `character`, or `{"identity": null}` when nothing
/// matches. A missing or non-string `name` is treated as the empty string.
///
/// # Errors
/// Returns the database error rendered as a string.
async fn exec_get_identity_detail(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
    let wanted = args
        .get("name")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default();
    let sql = format!(
        "SELECT uuid::text, name, source, tmdb_id, metadata->>'tmdb_character' FROM {} WHERE name ILIKE $1 LIMIT 1",
        schema::table_name("identities")
    );

    let found: Option<(String, String, Option<String>, Option<i32>, Option<String>)> =
        sqlx::query_as(&sql)
            .bind(wanted)
            .fetch_optional(pool)
            .await
            .map_err(|err| err.to_string())?;

    let identity = found.map(|(uuid, name, source, tmdb_id, character)| {
        serde_json::json!({"uuid": uuid, "name": name, "source": source, "tmdb_id": tmdb_id, "character": character})
    });
    Ok(serde_json::json!({"identity": identity}).to_string())
}
/// Executes the `get_file_info` tool: fetches basic metadata of the video
/// identified by the `file_uuid` argument.
///
/// Returns the JSON string `{"file_info": {...}}` with `file_name`,
/// `duration_sec`, `width`, `height` and `fps`, or `{"file_info": null}` when
/// no such video exists. A missing or non-string `file_uuid` is treated as the
/// empty string.
///
/// # Errors
/// Returns the database error rendered as a string.
async fn exec_get_file_info(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
    let file_uuid = args
        .get("file_uuid")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default();
    let sql = format!(
        "SELECT file_name, duration, width, height, fps FROM {} WHERE file_uuid = $1",
        schema::table_name("videos")
    );

    let found: Option<(String, f64, i32, i32, f64)> = sqlx::query_as(&sql)
        .bind(file_uuid)
        .fetch_optional(pool)
        .await
        .map_err(|err| err.to_string())?;

    let info = found.map(|(file_name, duration, width, height, fps)| {
        serde_json::json!({"file_name": file_name, "duration_sec": duration, "width": width, "height": height, "fps": fps})
    });
    Ok(serde_json::json!({"file_info": info}).to_string())
}
async fn exec_get_representative_frame(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or("");
match crate::core::processor::tkg::query_auto_representative_frame(pool, file_uuid).await {
Ok(r) => Ok(serde_json::json!({
"frame_number": r.frame_number,
"face_quality": r.face_quality,
"main_identities": r.main_identities,
"traces": r.traces,
}).to_string()),
Err(e) => Ok(serde_json::json!({"error": e.to_string()}).to_string()),
}
}
// ── Tool Router ───────────────────────────────────────────────────
async fn execute_tool(pool: &sqlx::PgPool, tool_call: &ToolCall) -> (String, String, String) {
let name = tool_call.function.name.clone();
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap_or_default();
let result = match name.as_str() {
"find_file" => exec_find_file(pool, &args).await,
"list_files" => exec_list_files(pool, &args).await,
"tkg_query" => exec_tkg_query(pool, &args).await,
"smart_search" => exec_smart_search(pool, &args).await,
"get_identity_detail" => exec_get_identity_detail(pool, &args).await,
"get_file_info" => exec_get_file_info(pool, &args).await,
"get_representative_frame" => exec_get_representative_frame(pool, &args).await,
_ => Err(format!("Unknown tool: {}", name)),
};
let content = match result {
Ok(s) => s,
Err(e) => serde_json::json!({"error": e}).to_string(),
};
let tool_call_id = tool_call.id.clone().unwrap_or_default();
(tool_call_id, name, content)
}
// ── Tool Loop ─────────────────────────────────────────────────────
/// Upper bound on LLM round-trips per request.
const MAX_ROUNDS: u32 = 5;

/// Drives the LLM/tool conversation for one user query.
///
/// The conversation is built from `system_prompt`, `user_query` and `history`.
/// Each round calls the LLM with the full tool list:
/// - a text reply ends the loop and is returned as the answer;
/// - a tool-call reply is recorded as an assistant message, every requested
///   tool is executed in order, and each result is appended both to the
///   conversation and to the returned sources list;
/// - an LLM error ends the loop with an error message as the answer.
///
/// If no text reply arrives within `MAX_ROUNDS` rounds, a fixed
/// "maximum number of queries reached" message is returned.
///
/// Returns `(answer, sources)`, where each source is
/// `{"tool": <name>, "result": <raw tool output>}`.
async fn run_tool_loop(
    pool: &sqlx::PgPool,
    system_prompt: &str,
    user_query: &str,
    history: Vec<ChatMessage>,
) -> (String, Vec<serde_json::Value>) {
    let mut transcript = function_calling::build_conversation(system_prompt, user_query, history);
    let mut collected = Vec::new();
    let mut rounds_left = MAX_ROUNDS;

    while rounds_left > 0 {
        rounds_left -= 1;

        let offered = Some(make_tools(pool));
        let reply = function_calling::call_llm(transcript.clone(), offered, 2048, 120).await;
        let requested = match reply {
            Err(e) => return (format!("系統錯誤:{}", e), collected),
            Ok(LlmResponse::Text(text)) => return (text, collected),
            Ok(LlmResponse::ToolCalls(calls)) => calls,
        };

        // Record the assistant's tool-call message before the tool results so
        // the model sees its own request in the next round.
        transcript.push(ChatMessage {
            role: "assistant".to_string(),
            content: None,
            tool_calls: Some(requested.clone()),
            tool_call_id: None,
            name: None,
        });

        for call in requested.iter() {
            let (call_id, tool_name, output) = execute_tool(pool, call).await;
            collected.push(serde_json::json!({"tool": tool_name, "result": output}));
            transcript.push(function_calling::make_tool_result(&call_id, &tool_name, &output));
        }
    }

    (
        "已達到最大查詢次數,請縮小問題範圍後重新詢問。".to_string(),
        collected,
    )
}
// ── Handler ───────────────────────────────────────────────────────
async fn agent_search(
State(state): State<AppState>,
Json(req): Json<AgentSearchRequest>,
) -> Result<Json<AgentSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let (conv_id, history) = get_or_create_conv(req.conversation_id.as_deref());
let (answer, sources) = run_tool_loop(
state.db.pool(),
SYSTEM_PROMPT,
&req.query,
history,
)
.await;
// Save updated messages for conversation continuation
let new_msgs = function_calling::build_conversation(SYSTEM_PROMPT, &req.query, vec![]);
save_messages(&conv_id, &new_msgs);
let needs_input = answer.contains('') || answer.contains('?');
let suggestions = if needs_input {
Some(vec!["演員名".to_string(), "電影片名".to_string(), "年份".to_string()])
} else {
None
};
Ok(Json(AgentSearchResponse {
success: true,
conversation_id: conv_id,
answer,
suggestions,
sources: Some(sources),
}))
}
// ── Routes ─────────────────────────────────────────────────────────
/// Builds the router for this module: `POST /api/v1/agents/search` is served
/// by `agent_search`.
pub fn agent_search_routes() -> Router<AppState> {
    let search_handler = post(agent_search);
    Router::new().route("/api/v1/agents/search", search_handler)
}