Files
momentry_core/src/api/identity_agent_api.rs
T
Accusys 6cbc11efda feat: add confirm_identity API endpoint
- Add POST /api/v1/agents/identity/confirm endpoint
- Calls confirm_identity.py to bind trace to identity
- Updates TKG, Qdrant _faces, PG face_detections, _seeds
- Optional Round 2 propagation after confirmation
- Fix trace_id=0 check in confirm_identity.py (use 'is not None')
- Document API endpoint in 08_identity_agent.md
2026-06-26 08:30:03 +08:00

1230 lines
41 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
extract::{Multipart, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use std::path::PathBuf;
use crate::api::types::AppState;
use crate::core::db::schema;
use crate::core::db::PostgresDb;
/// Builds the router for the identity-agent HTTP API.
///
/// Every endpoint is a `POST` under `/api/v1/agents/identity/`.
pub fn identity_agent_routes() -> Router<AppState> {
    Router::new()
        .route("/api/v1/agents/identity/match-from-photo", post(match_from_photo))
        .route("/api/v1/agents/identity/match-from-trace", post(match_from_trace))
        .route("/api/v1/agents/identity/generate-seeds", post(generate_seeds_handler))
        .route("/api/v1/agents/identity/run", post(run_identity_handler))
        .route("/api/v1/agents/identity/confirm", post(confirm_identity_handler))
}
/// One candidate identity produced by `analyze_person_speaker_overlap`.
#[derive(Serialize, Debug)]
pub struct IdentityResult {
    /// Identifier of the candidate (the id of the person that seeded it).
    pub identity_id: String,
    /// Person / face-cluster ids grouped into this candidate.
    pub person_ids: Vec<String>,
    /// ASRX speaker ids whose segments overlap the seed person.
    pub speaker_ids: Vec<String>,
    /// Heuristic confidence score.
    pub confidence: f64,
    /// Signals the confidence was derived from.
    pub evidence: IdentityEvidence,
    /// Human-readable explanation of the grouping.
    pub reasoning: String,
}

/// Signals backing an [`IdentityResult`].
#[derive(Serialize, Debug)]
pub struct IdentityEvidence {
    /// Face-embedding similarity, when one was computed.
    pub face_similarity: Option<f64>,
    /// Fraction of speakers that overlap the person.
    pub speaker_overlap: f64,
    /// Temporal overlap score.
    pub time_overlap: f64,
    /// Frame-count based ratio.
    pub frame_ratio: f64,
}

/// Response body shared by the match-from-photo and match-from-trace endpoints.
#[derive(Serialize, Debug)]
struct MatchFromPhotoResponse {
    /// Whether the request was processed.
    success: bool,
    /// Identity UUID with dashes removed.
    identity_uuid: String,
    /// File the search ran against.
    file_uuid: String,
    /// Number of face detections bound to the identity.
    matches: usize,
    /// Distinct trace ids that received the identity.
    traces_matched: Vec<i32>,
    /// Human-readable summary.
    message: String,
}
async fn match_from_photo(
State(state): State<AppState>,
mut multipart: Multipart,
) -> Result<Json<MatchFromPhotoResponse>, (StatusCode, Json<serde_json::Value>)> {
let mut identity_uuid = String::new();
let mut file_uuid = String::new();
let mut image_data: Option<Vec<u8>> = None;
while let Ok(Some(field)) = multipart.next_field().await {
let name = field.name().unwrap_or("").to_string();
match name.as_str() {
"identity_uuid" => {
identity_uuid = field.text().await.unwrap_or_default();
}
"file_uuid" => {
file_uuid = field.text().await.unwrap_or_default();
}
"image" => {
image_data = Some(field.bytes().await.unwrap_or_default().to_vec());
}
_ => {}
}
}
let uuid_clean = identity_uuid.replace('-', "");
if uuid_clean.is_empty() || file_uuid.is_empty() {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"success": false, "message": "identity_uuid and file_uuid are required"
})),
));
}
let data = image_data.ok_or_else(|| {
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"success": false, "message": "No image field found. Use field name 'image'."
})),
)
})?;
// 1. Save uploaded image to temp
let scripts_dir = std::env::var("MOMENTRY_SCRIPTS_DIR")
.unwrap_or_else(|_| "/Users/accusys/momentry_core_0.1/scripts".to_string());
let python_path = std::env::var("MOMENTRY_PYTHON_PATH")
.unwrap_or_else(|_| "/opt/homebrew/bin/python3.11".to_string());
let temp_dir = std::env::temp_dir().join("momentry_match_face");
std::fs::create_dir_all(&temp_dir).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": format!("Failed to create temp dir: {}", e)})),
)
})?;
let temp_img = temp_dir.join(format!("{}.jpg", uuid_clean));
std::fs::write(&temp_img, &data).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": format!("Failed to save temp image: {}", e)})),
)
})?;
// 2. Extract face embedding via Python script
let extract_script = std::path::Path::new(&scripts_dir).join("extract_face_embedding.py");
let output = tokio::process::Command::new(&*python_path)
.arg(&extract_script)
.arg(&temp_img)
.output()
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": format!("Failed to run extractor: {}", e)})),
)
})?;
let _ = std::fs::remove_file(&temp_img);
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"success": false, "message": format!("Face extraction failed: {}", stderr)
})),
));
}
let stdout = String::from_utf8_lossy(&output.stdout);
let extract_result: serde_json::Value = serde_json::from_str(&stdout).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": "Failed to parse extractor output"})),
)
})?;
let embedding: Vec<f64> = serde_json::from_value(
extract_result
.get("embedding")
.ok_or_else(|| {
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"message": "No embedding in extractor output"})),
)
})?
.clone(),
)
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": "Invalid embedding format"})),
)
})?;
let embedding_f32: Vec<f32> = embedding.into_iter().map(|v| v as f32).collect();
// 3. Look up identity internal ID
let id_table = schema::table_name("identities");
let identity_id_row: Option<(i32,)> = sqlx::query_as(&format!(
"SELECT id FROM {} WHERE REPLACE(uuid::text, '-', '') = $1",
id_table
))
.bind(&uuid_clean)
.fetch_optional(state.db.pool())
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": format!("DB error: {}", e)})),
)
})?;
let identity_id = match identity_id_row {
Some((id,)) => id,
None => {
return Err((
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"success": false, "message": "Identity not found"
})),
))
}
};
// 4. Find best matching trace (highest similarity, no threshold)
let fd_table = schema::table_name("face_detections");
let best_match: Option<(i32, i32, f64)> = sqlx::query_as(&format!(
r#"SELECT id, trace_id,
1 - (embedding::vector <=> $1::vector) as similarity
FROM {}
WHERE file_uuid = $2 AND embedding IS NOT NULL
ORDER BY embedding::vector <=> $1::vector
LIMIT 1"#,
fd_table
))
.bind(&embedding_f32)
.bind(&file_uuid)
.fetch_optional(state.db.pool())
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": format!("Search failed: {}", e)})),
)
})?;
// 5. Update best match face_detection
let mut traces_matched: Vec<i32> = Vec::new();
if let Some((fb_id, fb_trace, fb_sim)) = best_match {
let _ = sqlx::query(&format!(
"UPDATE {} SET identity_id = $1 WHERE id = $2",
fd_table
))
.bind(identity_id)
.bind(fb_id)
.execute(state.db.pool())
.await;
traces_matched.push(fb_trace);
// 6. Save identity file
let _ = crate::core::identity::storage::save_identity_file(&*state.db, &uuid_clean).await;
Ok(Json(MatchFromPhotoResponse {
success: true,
identity_uuid: uuid_clean,
file_uuid,
matches: 1,
traces_matched,
message: format!(
"Best trace: trace_id={}, similarity={:.4}",
fb_trace, fb_sim
),
}))
} else {
Ok(Json(MatchFromPhotoResponse {
success: true,
identity_uuid: uuid_clean,
file_uuid,
matches: 0,
traces_matched,
message: "No matching face found in video".to_string(),
}))
}
}
/// JSON body for `POST /api/v1/agents/identity/match-from-trace`.
#[derive(Deserialize, Debug)]
struct MatchFromTraceRequest {
    /// File containing the source trace.
    file_uuid: String,
    /// Trace already known to show the identity.
    trace_id: i32,
    /// Identity to bind (dashes optional).
    identity_uuid: String,
}
/// `POST /api/v1/agents/identity/match-from-trace`
///
/// Uses a trace already identified as `identity_uuid` as the query. The handler:
/// 1. picks up to three sample embeddings from the trace, one per third of its
///    frames (the largest face in each third when width/height are stored);
/// 2. for each sample, finds the single nearest face detection belonging to a
///    *different* trace of the same file;
/// 3. binds those detections, plus every detection of the source trace, to the
///    identity.
///
/// `matches` in the response counts successful updates: one per candidate
/// face, plus one if the source trace was bound.
///
/// # Errors
/// `404` when the identity or the trace's embeddings are missing; `500` on
/// database errors while reading.
async fn match_from_trace(
    State(state): State<AppState>,
    Json(req): Json<MatchFromTraceRequest>,
) -> Result<Json<MatchFromPhotoResponse>, (StatusCode, Json<serde_json::Value>)> {
    /// Builds the `500` error body used for server-side failures.
    fn internal_error(message: String) -> (StatusCode, Json<serde_json::Value>) {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "message": message })),
        )
    }

    let uuid_clean = req.identity_uuid.replace('-', "");
    let pool = state.db.pool();
    let fd_table = schema::table_name("face_detections");

    // 1. Resolve the identity first, so an unknown identity fails before any search.
    let id_table = schema::table_name("identities");
    let identity_id_row: Option<(i32,)> = sqlx::query_as(&format!(
        "SELECT id FROM {} WHERE REPLACE(uuid::text, '-', '') = $1",
        id_table
    ))
    .bind(&uuid_clean)
    .fetch_optional(pool)
    .await
    .map_err(|e| internal_error(format!("DB error: {}", e)))?;
    let identity_id = match identity_id_row {
        Some((id,)) => id,
        None => {
            return Err((
                StatusCode::NOT_FOUND,
                Json(serde_json::json!({
                    "success": false, "message": "Identity not found"
                })),
            ));
        }
    };

    // 2. All embeddings of the source trace, in frame order. frame_number is
    //    cast to BIGINT so decoding into i64 works whether the column is INT4 or INT8.
    let all_faces: Vec<(Vec<f32>, i64)> = sqlx::query_as::<_, (Vec<f32>, i64)>(&format!(
        "SELECT embedding, frame_number::bigint FROM {} \
         WHERE file_uuid = $1 AND trace_id = $2 AND embedding IS NOT NULL \
         ORDER BY frame_number ASC",
        fd_table
    ))
    .bind(&req.file_uuid)
    .bind(req.trace_id)
    .fetch_all(pool)
    .await
    .map_err(|e| internal_error(format!("DB error: {}", e)))?;
    if all_faces.is_empty() {
        return Err((
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "success": false, "message": "No embedding found for this trace"
            })),
        ));
    }

    // 3. Optional face area per frame. Not all pipelines store width/height, so
    //    a failing query just means "no size information" (all areas 0).
    let face_sizes: std::collections::HashMap<i64, i64> = sqlx::query_as::<_, (i64, i64)>(
        &format!(
            "SELECT frame_number::bigint, \
             (COALESCE(width, 0) * COALESCE(height, 0))::bigint AS area \
             FROM {} WHERE file_uuid = $1 AND trace_id = $2 AND embedding IS NOT NULL",
            fd_table
        ),
    )
    .bind(&req.file_uuid)
    .bind(req.trace_id)
    .fetch_all(pool)
    .await
    .unwrap_or_default()
    .into_iter()
    .collect();

    // 4. One sample per third of the trace: the largest face in that segment.
    let total = all_faces.len();
    let segments = [
        (0, total / 3),
        (total / 3, total * 2 / 3),
        (total * 2 / 3, total),
    ];
    let mut query_embeddings: Vec<&Vec<f32>> = Vec::new();
    for (start, end) in segments {
        let seg_start = start.min(total - 1);
        let seg_end = end.min(total);
        if seg_start >= seg_end {
            continue;
        }
        let largest = all_faces[seg_start..seg_end]
            .iter()
            .max_by_key(|face| face_sizes.get(&face.1).copied().unwrap_or(0));
        if let Some((embedding, _)) = largest {
            query_embeddings.push(embedding);
        }
    }
    if query_embeddings.is_empty() {
        query_embeddings.push(&all_faces[total / 2].0);
    }

    // 5. Each sample finds its nearest face in another trace; keep one hit per trace.
    let mut candidates: Vec<(i32, i32, f64)> = Vec::new();
    let mut seen_trace_ids = std::collections::HashSet::new();
    for qemb in query_embeddings.iter().copied() {
        let top = sqlx::query_as::<_, (i32, i32, f64)>(&format!(
            r#"SELECT id, trace_id,
                      1 - (embedding::vector <=> $1::vector) AS similarity
               FROM {}
               WHERE file_uuid = $2
                 AND trace_id != $3
                 AND embedding IS NOT NULL
               ORDER BY embedding::vector <=> $1::vector
               LIMIT 1"#,
            fd_table
        ))
        .bind(qemb)
        .bind(&req.file_uuid)
        .bind(req.trace_id)
        .fetch_optional(pool)
        .await
        .map_err(|e| internal_error(format!("Search failed: {}", e)))?;
        if let Some((face_id, trace_id, similarity)) = top {
            if seen_trace_ids.insert(trace_id) {
                candidates.push((face_id, trace_id, similarity));
            }
        }
    }

    // 6. Bind candidates. Only successful updates are counted and reported.
    let update_sql = format!("UPDATE {} SET identity_id = $1 WHERE id = $2", fd_table);
    let mut traces_matched: Vec<i32> = Vec::new();
    let mut faces_bound = 0usize;
    for &(face_id, trace_id, _similarity) in &candidates {
        match sqlx::query(&update_sql)
            .bind(identity_id)
            .bind(face_id)
            .execute(pool)
            .await
        {
            Ok(_) => {
                faces_bound += 1;
                if !traces_matched.contains(&trace_id) {
                    traces_matched.push(trace_id);
                }
            }
            Err(e) => tracing::warn!(
                "[match-from-trace] Failed to update face_detection {}: {}",
                face_id,
                e
            ),
        }
    }

    // 7. Bind every detection of the source trace itself.
    match sqlx::query(&format!(
        "UPDATE {} SET identity_id = $1 WHERE file_uuid = $2 AND trace_id = $3",
        fd_table
    ))
    .bind(identity_id)
    .bind(&req.file_uuid)
    .bind(req.trace_id)
    .execute(pool)
    .await
    {
        Ok(_) => {
            faces_bound += 1;
            if !traces_matched.contains(&req.trace_id) {
                traces_matched.push(req.trace_id);
            }
        }
        Err(e) => tracing::warn!(
            "[match-from-trace] Failed to bind source trace {}: {}",
            req.trace_id,
            e
        ),
    }

    // 8. Best-effort refresh of the identity file; its result is intentionally ignored.
    let _ = crate::core::identity::storage::save_identity_file(&*state.db, &uuid_clean).await;

    let unique_traces = traces_matched.len();
    Ok(Json(MatchFromPhotoResponse {
        success: true,
        identity_uuid: uuid_clean,
        file_uuid: req.file_uuid,
        matches: faces_bound,
        traces_matched,
        message: format!(
            "Matched {} faces ({} unique traces)",
            faces_bound, unique_traces
        ),
    }))
}
/// Groups the `frames` array of a `face_clustered.json` document by `person_id`.
///
/// Each element of `face_data["frames"]` contributes its `frame` number to the
/// person named by its `person_id`. Elements missing either key, or whose frame
/// does not fit in `i32`, are skipped.
///
/// Persons are returned sorted by `person_id`, so the result order is
/// deterministic. A `HashMap` would make it random, and callers that pick the
/// first derived identity would then behave differently between runs.
fn extract_persons_from_face_data(face_data: &serde_json::Value) -> Vec<PersonData> {
    let frames = match face_data.get("frames").and_then(|f| f.as_array()) {
        Some(frames) => frames,
        None => return Vec::new(),
    };

    let mut frames_by_person: std::collections::BTreeMap<String, Vec<i32>> =
        std::collections::BTreeMap::new();
    for frame in frames {
        let frame_num = frame
            .get("frame")
            .and_then(|f| f.as_i64())
            .and_then(|n| i32::try_from(n).ok());
        let person_id = frame.get("person_id").and_then(|p| p.as_str());
        if let (Some(frame_num), Some(person_id)) = (frame_num, person_id) {
            frames_by_person
                .entry(person_id.to_string())
                .or_default()
                .push(frame_num);
        }
    }

    frames_by_person
        .into_iter()
        .map(|(person_id, frames)| PersonData {
            person_id,
            frames,
            avg_embedding: None,
        })
        .collect()
}
/// Turns the `segments` array of an ASRX document into speaker entries.
///
/// Each segment carrying numeric `start_time` / `end_time` and a string
/// `speaker_id` becomes its own [`SpeakerData`] with exactly one segment, so a
/// speaker with several segments appears several times. Incomplete segments,
/// or a missing document, yield nothing.
fn extract_speakers_from_asrx_data(asrx_data: &Option<serde_json::Value>) -> Vec<SpeakerData> {
    asrx_data
        .as_ref()
        .and_then(|doc| doc.get("segments"))
        .and_then(|segs| segs.as_array())
        .map(|segs| {
            segs.iter()
                .filter_map(|seg| {
                    let start = seg.get("start_time")?.as_f64()?;
                    let end = seg.get("end_time")?.as_f64()?;
                    let who = seg.get("speaker_id")?.as_str()?;
                    Some(SpeakerData {
                        speaker_id: who.to_string(),
                        segments: vec![(start, end)],
                    })
                })
                .collect()
        })
        .unwrap_or_default()
}
/// Groups face-cluster persons into candidate identities and attaches the ASRX
/// speakers that overlap each seed person in time.
///
/// Face frame numbers are converted to seconds at 25 fps before being compared
/// with ASRX segment times. This is the same default frame rate `bind_speakers`
/// uses for the same conversion.
fn analyze_person_speaker_overlap(
    persons: &[PersonData],
    speakers: &[SpeakerData],
) -> Vec<IdentityResult> {
    const ASSUMED_FPS: f64 = 25.0;
    analyze_person_speaker_overlap_with_fps(persons, speakers, ASSUMED_FPS)
}

/// [`analyze_person_speaker_overlap`] with an explicit frame rate (`fps > 0`).
///
/// For every person not yet absorbed into an earlier group:
/// * other unvisited persons sharing at least one frame number are merged in;
/// * speakers with any segment intersecting the person's
///   `[first_frame, last_frame] / fps` interval are attached;
/// * `speaker_overlap` = attached speakers / distinct speakers in `speakers`.
fn analyze_person_speaker_overlap_with_fps(
    persons: &[PersonData],
    speakers: &[SpeakerData],
    fps: f64,
) -> Vec<IdentityResult> {
    use std::collections::HashSet;

    // `speakers` holds one entry per ASRX segment, so count distinct ids for the
    // ratio denominator rather than segments.
    let distinct_speakers = speakers
        .iter()
        .map(|s| s.speaker_id.as_str())
        .collect::<HashSet<_>>()
        .len();

    let mut identities: Vec<IdentityResult> = Vec::new();
    let mut visited: HashSet<&str> = HashSet::new();
    for person in persons {
        if !visited.insert(person.person_id.as_str()) {
            continue;
        }

        // Set lookup keeps the pairwise co-occurrence test linear in frame count.
        let own_frames: HashSet<i32> = person.frames.iter().copied().collect();
        let mut matched_persons = vec![person.person_id.clone()];
        for other in persons {
            if visited.contains(other.person_id.as_str()) {
                continue;
            }
            // NOTE(review): two clusters seen in the *same* frame are usually
            // different people; merging them here may be unintended. Behavior kept.
            if other.frames.iter().any(|f| own_frames.contains(f)) {
                matched_persons.push(other.person_id.clone());
                visited.insert(other.person_id.as_str());
            }
        }

        // Person's visible interval in seconds (0..0 when it has no frames).
        let first_s = f64::from(person.frames.iter().min().copied().unwrap_or(0)) / fps;
        let last_s = f64::from(person.frames.iter().max().copied().unwrap_or(0)) / fps;
        let mut matched_speakers: Vec<String> = Vec::new();
        for speaker in speakers {
            let overlaps = speaker
                .segments
                .iter()
                .any(|&(start, end)| start <= last_s && end >= first_s);
            if overlaps && !matched_speakers.contains(&speaker.speaker_id) {
                matched_speakers.push(speaker.speaker_id.clone());
            }
        }

        let speaker_overlap = if matched_speakers.is_empty() {
            0.0
        } else {
            matched_speakers.len() as f64 / distinct_speakers.max(1) as f64
        };
        let frame_count = person.frames.len() as f64;
        let reasoning = format!(
            "Matched {} persons with {} speakers, overlap={:.2}",
            matched_persons.len(),
            matched_speakers.len(),
            speaker_overlap
        );
        identities.push(IdentityResult {
            identity_id: person.person_id.clone(),
            person_ids: matched_persons,
            speaker_ids: matched_speakers,
            confidence: 0.5 + (speaker_overlap * 0.3),
            evidence: IdentityEvidence {
                face_similarity: None,
                speaker_overlap,
                time_overlap: 1.0,
                frame_ratio: frame_count / 100.0,
            },
            reasoning,
        });
    }
    identities
}
/// Frames in which one face-cluster person appears.
#[derive(Debug)]
struct PersonData {
    /// Person id taken from `face_clustered.json`.
    person_id: String,
    /// Frame numbers where the person was seen.
    frames: Vec<i32>,
    /// Average face embedding. NOTE(review): always `None` when built in this
    /// file and never read here.
    avg_embedding: Option<Vec<f64>>,
}

/// Speech segments attributed to one ASRX speaker.
#[derive(Debug)]
struct SpeakerData {
    /// Speaker id from the ASRX output.
    speaker_id: String,
    /// `(start_time, end_time)` pairs as stored in the ASRX output.
    segments: Vec<(f64, f64)>,
}
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` instead of a meaningless or NaN value when the lengths differ,
/// the inputs are empty, or either vector has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = a.len() == b.len() && !a.is_empty();
    if !comparable {
        return 0.0;
    }
    let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (l2(a), l2(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}
/// Element-wise mean of the 512-dimensional vectors yielded by `embeddings`.
///
/// Vectors whose length is not exactly 512 are ignored. When none qualify, a
/// 512-element zero vector is returned.
fn average_embeddings<'a>(embeddings: impl Iterator<Item = &'a Vec<f32>>) -> Vec<f32> {
    const DIM: usize = 512;
    let mut usable = embeddings.filter(|e| e.len() == DIM);
    let mut total = match usable.next() {
        Some(first) => first.clone(),
        None => return vec![0.0f32; DIM],
    };
    let mut n: usize = 1;
    for e in usable {
        for (acc, x) in total.iter_mut().zip(e.iter()) {
            *acc += x;
        }
        n += 1;
    }
    let divisor = n as f32;
    total.iter_mut().for_each(|v| *v /= divisor);
    total
}
/// Cluster: trace centroid + seeds from Qdrant + stranger clustering.
/// Round 1: centroid vs seeds (TH=0.55)
/// Round 2+: propagate from matched (TH=0.50)
/// Unknown: greedy stranger clustering (TH=0.40)
/// Writes identity_ref/stranger_ref to Qdrant payload, TKG nodes, and face_detections.
async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
use crate::core::processor::executor::PythonExecutor;
use std::time::Duration;
let executor = PythonExecutor::new()?;
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
let output_path = std::path::PathBuf::from(&output_dir)
.join(file_uuid)
.join(format!("{}.identity_match_round1.json", file_uuid));
std::fs::create_dir_all(output_path.parent().unwrap()).ok();
let scripts_dir = executor.script_dir();
let python_path = executor.python_path();
let script_path = scripts_dir.join("identity_matcher.py");
let qdrant_url = std::env::var("QDRANT_URL")
.unwrap_or_else(|_| "http://localhost:6333".to_string());
let qdrant_api_key = std::env::var("QDRANT_API_KEY")
.unwrap_or_else(|_| "Test3200Test3200Test3200".to_string());
let db_url = std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgresql://accusys@localhost:5432/momentry".to_string());
let mut cmd = tokio::process::Command::new(python_path);
cmd.env("MOMENTRY_OUTPUT_DIR", &output_dir);
cmd.env("DATABASE_SCHEMA", "public");
cmd.env("MOMENTRY_DB_SCHEMA", "public");
cmd.env("DATABASE_URL", &db_url);
cmd.env("QDRANT_URL", &qdrant_url);
cmd.env("QDRANT_API_KEY", &qdrant_api_key);
cmd.arg(&script_path);
cmd.arg("--file-uuid").arg(file_uuid);
cmd.arg("--round").arg("1");
cmd.arg("--mark-tkg");
cmd.arg("--output").arg(&output_path);
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
tracing::info!("[FaceMatch] Starting identity_matcher for {}", file_uuid);
let output = cmd.output().await?;
let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
if !output.status.success() {
tracing::error!("[FaceMatch] identity_matcher failed with exit code: {:?}", output.status.code());
tracing::error!("[FaceMatch] stderr: {}", stderr);
tracing::error!("[FaceMatch] stdout: {}", stdout);
return Ok(0);
}
tracing::info!("[FaceMatch] stdout: {}", stdout);
if !output_path.exists() {
tracing::info!("[FaceMatch] No matches found for {}", file_uuid);
return Ok(0);
}
let content = std::fs::read_to_string(&output_path)?;
let result: serde_json::Value = serde_json::from_str(&content)?;
let matched = result.get("matched").and_then(|v| v.as_i64()).unwrap_or(0) as usize;
let tkg_updated = result.get("tkg_nodes_updated").and_then(|v| v.as_i64()).unwrap_or(0) as usize;
tracing::info!(
"[FaceMatch] Round 1 for {}: {} matches, {} TKG nodes updated",
file_uuid, matched, tkg_updated
);
Ok(matched)
}
/// Legacy PostgreSQL-only face matcher, kept as a fallback entry point.
///
/// Currently disabled: it only logs a warning and reports zero matches.
async fn match_faces_iterative_pg(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
    let _ = pool; // unused while PG matching is disabled
    tracing::warn!(
        "[FaceMatch-PG] PostgreSQL matching disabled - embedding column removed for {}",
        file_uuid
    );
    Ok(0)
}
/// Binds ASRX speakers to identified face traces by temporal overlap.
///
/// For each trace in `face_detections` that has an identity:
/// * frame numbers are converted to seconds at 25 fps;
/// * the speaker whose segments cover the most frames wins (ties go to the
///   lexicographically smallest speaker id);
/// * if that speaker covers more than 30% of the trace's frames, an
///   `identity_bindings` row is upserted (keyed on identity_id, identity_type,
///   identity_value, file_uuid), and still-unassigned `speaker_detections` rows
///   for that speaker receive the identity.
///
/// The ASRX file is looked up in `<output>/<uuid>/<uuid>.asrx.json` first and
/// then in `<output>/<uuid>.asrx.json`. Returns the number of bindings stored.
///
/// # Errors
/// Only the initial trace query propagates errors. Missing or unparsable ASRX
/// data yields `Ok(0)`, and per-binding write failures are logged and skipped.
pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
    use std::collections::BTreeMap;

    // NOTE(review): fps is hard-coded; the real frame rate could be read from the DB.
    const FPS: f64 = 25.0;
    // A speaker must cover more than this fraction of a trace's frames.
    const MIN_OVERLAP_RATIO: f64 = 0.3;

    // One row per identified trace: (trace_id, identity_id, sorted frames).
    // MIN() picks a deterministic identity when a trace carries several, and
    // frame_number is cast to INT4 so decoding into Vec<i32> works for either
    // column width.
    let fd_table = schema::table_name("face_detections");
    let traces = sqlx::query_as::<_, (i32, i32, Vec<i32>)>(&format!(
        "SELECT trace_id, MIN(identity_id), \
         array_agg(frame_number::int4 ORDER BY frame_number) \
         FROM {} WHERE file_uuid=$1 AND trace_id IS NOT NULL AND identity_id IS NOT NULL \
         GROUP BY trace_id",
        fd_table
    ))
    .bind(file_uuid)
    .fetch_all(pool)
    .await?;
    if traces.is_empty() {
        tracing::info!("[SpeakerBind] No face traces with identities");
        return Ok(0);
    }

    // Locate the ASRX output next to the other per-file artifacts, falling back
    // to the flat output directory.
    let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
        .unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
    let asrx_name = format!("{}.asrx.json", file_uuid);
    let nested_path = std::path::Path::new(&output_dir).join(file_uuid).join(&asrx_name);
    let asrx_path = if nested_path.exists() {
        nested_path
    } else {
        std::path::Path::new(&output_dir).join(&asrx_name)
    };
    let asrx_data: serde_json::Value = match std::fs::read_to_string(&asrx_path) {
        Ok(s) => match serde_json::from_str(&s) {
            Ok(value) => value,
            Err(e) => {
                tracing::warn!("[SpeakerBind] Failed to parse {}: {}", asrx_path.display(), e);
                return Ok(0);
            }
        },
        Err(_) => {
            tracing::info!("[SpeakerBind] No ASRX file found");
            return Ok(0);
        }
    };

    // speaker_id -> [(start_time, end_time)]. Ordered map for deterministic ties.
    // Both `speaker_id`/`start_time`/`end_time` and `speaker`/`start`/`end` keys are accepted.
    let mut speakers: BTreeMap<String, Vec<(f64, f64)>> = BTreeMap::new();
    if let Some(segments) = asrx_data.get("segments").and_then(|s| s.as_array()) {
        for seg in segments {
            let sid = seg
                .get("speaker_id")
                .and_then(|s| s.as_str())
                .or_else(|| seg.get("speaker").and_then(|s| s.as_str()));
            if let Some(sid) = sid {
                let start = seg
                    .get("start_time")
                    .or_else(|| seg.get("start"))
                    .and_then(|v| v.as_f64())
                    .unwrap_or(0.0);
                let end = seg
                    .get("end_time")
                    .or_else(|| seg.get("end"))
                    .and_then(|v| v.as_f64())
                    .unwrap_or(0.0);
                speakers.entry(sid.to_string()).or_default().push((start, end));
            }
        }
    }
    if speakers.is_empty() {
        tracing::info!("[SpeakerBind] No speakers found in ASRX data");
        return Ok(0);
    }

    let ib_table = schema::table_name("identity_bindings");
    let sd_table = schema::table_name("speaker_detections");
    let mut bindings = 0usize;
    for (trace_id, identity_id, frames) in &traces {
        if frames.is_empty() {
            continue;
        }

        // Speaker covering the most frames; strict `>` keeps the first (smallest id) on ties.
        let mut best: Option<(&str, usize)> = None;
        for (speaker_id, segments) in &speakers {
            let overlap = frames
                .iter()
                .filter(|&&frame| {
                    let t = f64::from(frame) / FPS;
                    segments.iter().any(|&(start, end)| t >= start && t <= end)
                })
                .count();
            if overlap > best.map_or(0, |(_, o)| o) {
                best = Some((speaker_id.as_str(), overlap));
            }
        }
        let (best_speaker, best_overlap) = match best {
            Some(found) => found,
            None => continue,
        };

        let overlap_ratio = best_overlap as f64 / frames.len() as f64;
        if overlap_ratio <= MIN_OVERLAP_RATIO {
            continue;
        }

        let metadata = serde_json::json!({
            "trace_id": trace_id,
            "overlap_frames": best_overlap,
            "total_frames": frames.len(),
            "overlap_ratio": overlap_ratio,
        });
        let inserted = sqlx::query(
            &format!("INSERT INTO {} (identity_id, identity_type, identity_value, file_uuid, confidence, metadata) \
                      VALUES ($1, 'speaker', $2, $3, $4, $5::jsonb) \
                      ON CONFLICT (identity_id, identity_type, identity_value, file_uuid) \
                      DO UPDATE SET confidence = EXCLUDED.confidence, metadata = EXCLUDED.metadata", ib_table)
        )
        .bind(*identity_id)
        .bind(best_speaker)
        .bind(file_uuid)
        .bind(overlap_ratio)
        .bind(&metadata)
        .execute(pool)
        .await;
        if let Err(e) = inserted {
            tracing::warn!(
                "[SpeakerBind] Failed to store binding for trace {}: {}",
                trace_id,
                e
            );
            continue;
        }

        // Also give still-unassigned speaker_detections rows this identity.
        if let Err(e) = sqlx::query(&format!(
            "UPDATE {} SET identity_id = $1, confidence = $2 \
             WHERE file_uuid = $3 AND speaker_id = $4 AND identity_id IS NULL",
            sd_table
        ))
        .bind(*identity_id)
        .bind(overlap_ratio)
        .bind(file_uuid)
        .bind(best_speaker)
        .execute(pool)
        .await
        {
            tracing::warn!(
                "[SpeakerBind] Failed to update speaker_detections for {}: {}",
                best_speaker,
                e
            );
        }

        bindings += 1;
    }

    tracing::info!(
        "[SpeakerBind] Created {}/{} speaker bindings",
        bindings,
        traces.len()
    );
    Ok(bindings)
}
/// Pipeline-triggered entry point: runs the full identity agent for a file.
/// Reads face_clustered.json + asrx.json, extracts persons/speakers, creates identities,
/// runs iterative face matching, and binds speakers.
pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Result<()> {
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
let pool = db.pool();
// Step 1: 先跑 face matching(不需 face_clustered.json
let matched = match_faces_iterative(pool, file_uuid).await.unwrap_or(0);
// Step 2: 試著載入 face_clustered.json 建立新 identities
let video_dir = PathBuf::from(&output_dir).join(file_uuid);
let face_clustered_path = video_dir.join(format!("{}.face_clustered.json", file_uuid));
let face_clustered_path = if face_clustered_path.exists() {
face_clustered_path
} else {
PathBuf::from(&output_dir).join(format!("{}.face_clustered.json", file_uuid))
};
if face_clustered_path.exists() {
let face_data: serde_json::Value =
std::fs::read_to_string(&face_clustered_path)?.parse()?;
let asrx_path = video_dir.join(format!("{}.asrx.json", file_uuid));
let asrx_data: Option<serde_json::Value> = if asrx_path.exists() {
Some(std::fs::read_to_string(&asrx_path)?.parse()?)
} else {
None
};
let persons = extract_persons_from_face_data(&face_data);
let speakers = extract_speakers_from_asrx_data(&asrx_data);
let identities = analyze_person_speaker_overlap(&persons, &speakers);
let _ = identities.len();
if !identities.is_empty() {
let metadata = serde_json::json!({
"source": "identity_agent",
"speaker_ids": identities[0].speaker_ids,
"confidence": identities[0].confidence,
"evidence": {
"speaker_overlap": identities[0].evidence.speaker_overlap,
"frame_ratio": identities[0].evidence.frame_ratio,
},
"reasoning": identities[0].reasoning,
});
let _ = sqlx::query(&format!(
"INSERT INTO {} (file_uuid, trace_id, metadata) \
VALUES ($1, NULL, $2::jsonb) ON CONFLICT DO NOTHING",
schema::table_name("strangers")
))
.bind(file_uuid)
.bind(&metadata)
.execute(pool)
.await;
}
tracing::info!(
"[IdentityAgent] Analyzed {} face clusters from face_clustered for {}",
identities.len(),
file_uuid
);
} else {
tracing::warn!(
"[IdentityAgent] face_clustered.json not found for {}, skipping identity creation",
file_uuid
);
}
let bound = bind_speakers(pool, file_uuid).await.unwrap_or(0);
tracing::info!(
"[IdentityAgent] Done for {}: {} face matches, {} speaker bindings",
file_uuid,
matched,
bound
);
Ok(())
}
/// `POST /api/v1/agents/identity/generate-seeds`
///
/// Generates seed embeddings. If any were produced, it spawns a background run
/// of the identity agent for every file returned by `find_ready_files`. The
/// response does not wait for those runs.
///
/// # Errors
/// `500` when seed generation itself fails.
async fn generate_seeds_handler(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let count = match generate_seed_embeddings(&state.db).await {
        Ok(n) => n,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"success": false, "message": format!("{}", e)})),
            ));
        }
    };

    if count > 0 {
        let ready_files = find_ready_files(state.db.pool()).await.unwrap_or_default();
        if !ready_files.is_empty() {
            tracing::info!(
                "[GenerateSeeds] Auto-triggering identity agent for {} files: {:?}",
                ready_files.len(),
                ready_files
            );
            for fid in ready_files.iter().cloned() {
                let db = state.db.clone();
                // Fire-and-forget: the outcome is only logged.
                tokio::spawn(async move {
                    if let Err(e) = run_identity_agent(&db, &fid).await {
                        tracing::warn!(
                            "[GenerateSeeds] Identity agent failed for {}: {}",
                            fid,
                            e
                        );
                    } else {
                        tracing::info!(
                            "[GenerateSeeds] Identity agent completed for {}",
                            fid
                        );
                    }
                });
            }
        }
    }

    Ok(Json(serde_json::json!({
        "success": true,
        "message": format!("Generated {} seed embeddings", count),
        "count": count
    })))
}
/// Lists the distinct `file_uuid`s that have face detections with an embedding
/// but no identity yet.
async fn find_ready_files(pool: &sqlx::PgPool) -> anyhow::Result<Vec<String>> {
    let sql = format!(
        "SELECT DISTINCT file_uuid FROM {} WHERE embedding IS NOT NULL AND identity_id IS NULL",
        crate::core::db::schema::table_name("face_detections")
    );
    let files = sqlx::query_scalar::<_, String>(&sql).fetch_all(pool).await?;
    Ok(files)
}
/// `POST /api/v1/agents/identity/run`
///
/// Body: `{"file_uuid": "..."}`. Runs the identity agent synchronously. Agent
/// failures are reported as `200` with `"success": false`; only a missing
/// `file_uuid` yields `400`.
async fn run_identity_handler(
    State(state): State<AppState>,
    axum::Json(body): axum::Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let file_uuid = match body.get("file_uuid").and_then(serde_json::Value::as_str) {
        Some(uuid) => uuid,
        None => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({"success": false, "message": "file_uuid required"})),
            ));
        }
    };

    let payload = match run_identity_agent(&state.db, file_uuid).await {
        Ok(()) => serde_json::json!({
            "success": true,
            "message": format!("Identity agent completed for {}", file_uuid),
        }),
        Err(e) => serde_json::json!({
            "success": false,
            "message": format!("Identity agent failed: {}", e),
        }),
    };
    Ok(Json(payload))
}
/// JSON body for `POST /api/v1/agents/identity/confirm`.
#[derive(Deserialize, Debug)]
struct ConfirmIdentityRequest {
    /// File containing the trace.
    file_uuid: String,
    /// Trace being confirmed.
    trace_id: i32,
    /// Internal integer id of the identity.
    identity_id: i32,
    /// Public UUID of the identity.
    identity_uuid: String,
    /// Display name forwarded to the script.
    name: String,
    /// Run propagation after confirming; treated as `true` when omitted.
    propagate: Option<bool>,
}

/// Response of the confirm endpoint.
#[derive(Serialize, Debug)]
struct ConfirmIdentityResponse {
    /// `true` when the script reported `"status": "success"`.
    success: bool,
    file_uuid: String,
    trace_id: i32,
    identity_uuid: String,
    name: String,
    /// The script's `steps` object (empty object when absent).
    steps: serde_json::Value,
    /// The script's `propagation` object, if any.
    propagation: Option<serde_json::Value>,
}
/// Returns the first complete JSON value starting at some `{` in `text`.
///
/// Each `{` is tried in order, and exactly one JSON value is parsed from there.
/// Anything after that value (trailing log lines) is ignored, and log lines
/// containing a stray `{` are skipped.
///
/// Returns `Err(None)` when `text` contains no `{`, or `Err(Some(last_error))`
/// when no candidate parsed.
fn first_json_object(text: &str) -> Result<serde_json::Value, Option<serde_json::Error>> {
    let mut last_err = None;
    for (pos, _) in text.match_indices('{') {
        let mut values =
            serde_json::Deserializer::from_str(&text[pos..]).into_iter::<serde_json::Value>();
        match values.next() {
            Some(Ok(value)) => return Ok(value),
            Some(Err(e)) => last_err = Some(e),
            None => {}
        }
    }
    Err(last_err)
}

/// `POST /api/v1/agents/identity/confirm`
///
/// Runs `confirm_identity.py` to bind `trace_id` in `file_uuid` to the given
/// identity. Propagation runs unless `propagate` is `false`, in which case
/// `--no-propagate` is passed. The script's JSON result is read from stdout.
///
/// # Errors
/// `500` when the executor cannot be created, the script cannot be spawned or
/// exits non-zero, or no JSON result can be found in its stdout.
async fn confirm_identity_handler(
    State(_state): State<AppState>,
    Json(req): Json<ConfirmIdentityRequest>,
) -> Result<Json<ConfirmIdentityResponse>, (StatusCode, Json<serde_json::Value>)> {
    use crate::core::processor::executor::PythonExecutor;

    let executor = PythonExecutor::new().map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"success": false, "message": format!("PythonExecutor error: {}", e)})),
        )
    })?;
    let scripts_dir = executor.script_dir();
    let python_path = executor.python_path();
    let script_path = scripts_dir.join("confirm_identity.py");
    let qdrant_url = std::env::var("QDRANT_URL")
        .unwrap_or_else(|_| "http://localhost:6333".to_string());
    // NOTE(review): hard-coded fallback credential; consider requiring the env var.
    let qdrant_api_key = std::env::var("QDRANT_API_KEY")
        .unwrap_or_else(|_| "Test3200Test3200Test3200".to_string());
    let db_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgresql://accusys@localhost:5432/momentry".to_string());
    // NOTE(review): defaults to "dev", while match_faces_iterative pins "public".
    let db_schema = std::env::var("DATABASE_SCHEMA")
        .unwrap_or_else(|_| "dev".to_string());
    let propagate = req.propagate.unwrap_or(true);

    let mut cmd = tokio::process::Command::new(python_path);
    cmd.env("DATABASE_URL", &db_url);
    cmd.env("DATABASE_SCHEMA", &db_schema);
    cmd.env("MOMENTRY_DB_SCHEMA", &db_schema);
    cmd.env("QDRANT_URL", &qdrant_url);
    cmd.env("QDRANT_API_KEY", &qdrant_api_key);
    cmd.arg(&script_path);
    cmd.arg("--file-uuid").arg(&req.file_uuid);
    cmd.arg("--trace-id").arg(req.trace_id.to_string());
    cmd.arg("--identity-id").arg(req.identity_id.to_string());
    cmd.arg("--identity-uuid").arg(&req.identity_uuid);
    cmd.arg("--name").arg(&req.name);
    if !propagate {
        cmd.arg("--no-propagate");
    }
    cmd.stdout(std::process::Stdio::piped());
    cmd.stderr(std::process::Stdio::piped());

    tracing::info!(
        "[ConfirmIdentity] Starting for {} trace {} -> {} ({})",
        req.file_uuid, req.trace_id, req.identity_uuid, req.name
    );
    let output = cmd.output().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"success": false, "message": format!("Command failed: {}", e)})),
        )
    })?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    let stderr = String::from_utf8_lossy(&output.stderr);
    if !output.status.success() {
        tracing::error!("[ConfirmIdentity] Script failed with exit code: {:?}", output.status.code());
        tracing::error!("[ConfirmIdentity] stderr: {}", stderr);
        tracing::error!("[ConfirmIdentity] stdout: {}", stdout);
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({
                "success": false,
                "message": format!("Script failed: {}", stderr),
                "stdout": stdout.to_string(),
            })),
        ));
    }
    tracing::info!("[ConfirmIdentity] stdout: {}", stdout);

    // The script may log before (and after) its JSON result, so scan for the
    // first parseable object instead of assuming stdout is pure JSON.
    let result = match first_json_object(&stdout) {
        Ok(value) => value,
        Err(None) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "success": false,
                    "message": "No JSON output found",
                    "stdout": stdout.to_string(),
                })),
            ));
        }
        Err(Some(e)) => {
            let json_str = stdout.find('{').map_or("", |i| &stdout[i..]);
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "success": false,
                    "message": format!("Failed to parse output: {}", e),
                    "stdout": stdout.to_string(),
                    "json_str": json_str.to_string(),
                })),
            ));
        }
    };

    Ok(Json(ConfirmIdentityResponse {
        success: result.get("status").and_then(|v| v.as_str()) == Some("success"),
        file_uuid: req.file_uuid,
        trace_id: req.trace_id,
        identity_uuid: req.identity_uuid,
        name: req.name,
        steps: result.get("steps").cloned().unwrap_or(serde_json::json!({})),
        propagation: result.get("propagation").cloned(),
    }))
}
/// Intended to read TMDb identities with profile photos, extract their face
/// embeddings and store them in Qdrant as seeds.
///
/// Currently disabled: it only logs a warning and reports `0` seeds generated.
pub async fn generate_seed_embeddings(db: &PostgresDb) -> anyhow::Result<usize> {
    let _ = db; // unused until seed generation is reimplemented
    tracing::warn!(
        "[GenerateSeeds] Seed embedding generation disabled - FaceEmbeddingDb removed. \
         TODO: Reimplement with _faces collection"
    );
    Ok(0)
}