fix: ASRX duplication, TKG edges, trace ingest, and add pipeline progress publishing
- ASRX handler no longer stores duplicate 'asr' pre_chunks
- Pre_chunks storage made idempotent (delete-before-insert)
- Rule 1 + trace_ingest changed to query 'asrx' not 'asr'
- Trace chunks removed (dynamic from TKG/Qdrant)
- TKG scroll_face_points fixed: trace_id >= 1 (not == 1)
- TKG AsrxSegmentEntry: start/end -> start_time/end_time (match ASRX JSON)
- Unregister error handling: log instead of silent discard
- Add publish_pipeline_progress calls at each pipeline stage (processors, rule1, face_trace, identity_agent, TKG, rule2, completion)
This commit is contained in:
+351
-222
@@ -8,10 +8,14 @@ use axum::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::Row;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::api::types::AppState;
|
||||
use crate::core::db::schema;
|
||||
use crate::core::db::PostgresDb;
|
||||
use crate::core::db::QdrantDb;
|
||||
use crate::core::progress::{AgentPhase, AgentProgress, AgentStats, publish_agent_progress};
|
||||
use crate::core::db::redis_client::RedisClient;
|
||||
|
||||
pub fn identity_agent_routes() -> Router<AppState> {
|
||||
Router::new()
|
||||
@@ -27,10 +31,7 @@ pub fn identity_agent_routes() -> Router<AppState> {
|
||||
"/api/v1/agents/identity/generate-seeds",
|
||||
post(generate_seeds_handler),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/agents/identity/run",
|
||||
post(run_identity_handler),
|
||||
)
|
||||
.route("/api/v1/agents/identity/run", post(run_identity_handler))
|
||||
.route(
|
||||
"/api/v1/agents/identity/confirm",
|
||||
post(confirm_identity_handler),
|
||||
@@ -209,39 +210,42 @@ async fn match_from_photo(
|
||||
}
|
||||
};
|
||||
|
||||
// 4. Find best matching trace (highest similarity, no threshold)
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let best_match: Option<(i32, i32, f64)> = sqlx::query_as(&format!(
|
||||
r#"SELECT id, trace_id,
|
||||
1 - (embedding::vector <=> $1::vector) as similarity
|
||||
FROM {}
|
||||
WHERE file_uuid = $2 AND embedding IS NOT NULL
|
||||
ORDER BY embedding::vector <=> $1::vector
|
||||
LIMIT 1"#,
|
||||
fd_table
|
||||
))
|
||||
.bind(&embedding_f32)
|
||||
.bind(&file_uuid)
|
||||
.fetch_optional(state.db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"message": format!("Search failed: {}", e)})),
|
||||
)
|
||||
})?;
|
||||
// 4. Find best matching trace via Qdrant _faces search
|
||||
let qdrant = QdrantDb::new();
|
||||
|
||||
// 5. Update best match face_detection
|
||||
let best_match: Option<(i32, f64)> = match qdrant.search_face_collection(
|
||||
"_faces",
|
||||
&embedding_f32,
|
||||
1,
|
||||
"file_uuid",
|
||||
"",
|
||||
Some(&file_uuid),
|
||||
).await {
|
||||
Ok(hits) if !hits.is_empty() => {
|
||||
let (score, payload) = &hits[0];
|
||||
let trace_id = payload.get("trace_id").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
|
||||
Some((trace_id, *score))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// 5. Update best match in Qdrant _faces (trace-scoped)
|
||||
let mut traces_matched: Vec<i32> = Vec::new();
|
||||
if let Some((fb_id, fb_trace, fb_sim)) = best_match {
|
||||
let _ = sqlx::query(&format!(
|
||||
"UPDATE {} SET identity_id = $1 WHERE id = $2",
|
||||
fd_table
|
||||
))
|
||||
.bind(identity_id)
|
||||
.bind(fb_id)
|
||||
.execute(state.db.pool())
|
||||
.await;
|
||||
if let Some((fb_trace, fb_sim)) = best_match {
|
||||
let qdrant = QdrantDb::new();
|
||||
let filter = serde_json::json!({
|
||||
"must": [
|
||||
{"key": "file_uuid", "match": {"value": file_uuid}},
|
||||
{"key": "trace_id", "match": {"value": fb_trace}}
|
||||
]
|
||||
});
|
||||
let payload = serde_json::json!({"identity_id": identity_id});
|
||||
if let Err(e) = qdrant
|
||||
.update_payload_by_filter("_faces", filter, payload)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("[match_from_photo] Qdrant update failed: {}", e);
|
||||
}
|
||||
traces_matched.push(fb_trace);
|
||||
|
||||
// 6. Save identity file
|
||||
@@ -283,25 +287,26 @@ async fn match_from_trace(
|
||||
) -> Result<Json<MatchFromPhotoResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let uuid_clean = req.identity_uuid.replace('-', "");
|
||||
|
||||
// 1. Get 3 best face embeddings from this trace at different angles
|
||||
// Divide trace frame range into 3 segments, pick best face from each
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let all_faces: Vec<(Vec<f32>, i64)> = sqlx::query_as::<_, (Vec<f32>, i64)>(&format!(
|
||||
"SELECT embedding, frame_number FROM {} \
|
||||
WHERE file_uuid = $1 AND trace_id = $2 AND embedding IS NOT NULL \
|
||||
ORDER BY frame_number ASC",
|
||||
fd_table
|
||||
))
|
||||
.bind(&req.file_uuid)
|
||||
.bind(req.trace_id)
|
||||
.fetch_all(state.db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"message": format!("DB error: {}", e)})),
|
||||
)
|
||||
})?;
|
||||
// 1. Get face embeddings from Qdrant _faces for this trace
|
||||
let qdrant = QdrantDb::new();
|
||||
let trace_filter = serde_json::json!({
|
||||
"must": [
|
||||
{"key": "file_uuid", "match": {"value": req.file_uuid}},
|
||||
{"key": "trace_id", "match": {"value": req.trace_id}}
|
||||
]
|
||||
});
|
||||
let points = qdrant.scroll_all_points("_faces", trace_filter, 500).await.unwrap_or_default();
|
||||
|
||||
let all_faces: Vec<(Vec<f32>, i64)> = points.iter().filter_map(|p| {
|
||||
let vector = p.get("vector").and_then(|v| v.as_array())?;
|
||||
let embedding: Vec<f32> = vector.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
|
||||
let frame = p["payload"]["frame"].as_i64()?;
|
||||
if embedding.len() == 512 {
|
||||
Some((embedding, frame))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}).collect();
|
||||
|
||||
if all_faces.is_empty() {
|
||||
return Err((
|
||||
@@ -322,18 +327,14 @@ async fn match_from_trace(
|
||||
|
||||
let mut query_embeddings: Vec<Vec<f32>> = Vec::new();
|
||||
|
||||
// Get width*height info if available (not all pipelines store it)
|
||||
let face_sizes: Vec<(i64, i32)> = sqlx::query_as::<_, (i64, i32)>(&format!(
|
||||
"SELECT frame_number, COALESCE(width, 0) * COALESCE(height, 0) AS area \
|
||||
FROM {} WHERE file_uuid = $1 AND trace_id = $2 AND embedding IS NOT NULL \
|
||||
ORDER BY frame_number ASC",
|
||||
fd_table
|
||||
))
|
||||
.bind(&req.file_uuid)
|
||||
.bind(req.trace_id)
|
||||
.fetch_all(state.db.pool())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
// Get bbox size info from Qdrant payload
|
||||
let face_sizes: Vec<(i64, i32)> = points.iter().filter_map(|p| {
|
||||
let frame = p["payload"]["frame"].as_i64()?;
|
||||
let bbox = &p["payload"]["bbox"];
|
||||
let w = bbox["width"].as_f64().unwrap_or(0.0) as i32;
|
||||
let h = bbox["height"].as_f64().unwrap_or(0.0) as i32;
|
||||
Some((frame, w * h))
|
||||
}).collect();
|
||||
|
||||
let face_sizes_map: std::collections::HashMap<i64, i32> = face_sizes.into_iter().collect();
|
||||
|
||||
@@ -358,37 +359,39 @@ async fn match_from_trace(
|
||||
query_embeddings.push(all_faces[total / 2].0.clone());
|
||||
}
|
||||
|
||||
// 2. Three angles each find their best match; union all results
|
||||
// 2. Three angles each find their best match via Qdrant; union all results
|
||||
let mut validated: Vec<(i32, i32, f64)> = Vec::new();
|
||||
let mut seen_trace_ids = std::collections::HashSet::new();
|
||||
|
||||
for qemb in &query_embeddings {
|
||||
let top = sqlx::query_as::<_, (i32, i32, f64)>(&format!(
|
||||
r#"SELECT id, trace_id,
|
||||
1 - (embedding::vector <=> $1::vector) as similarity
|
||||
FROM {}
|
||||
WHERE file_uuid = $2
|
||||
AND trace_id != $3
|
||||
AND embedding IS NOT NULL
|
||||
ORDER BY embedding::vector <=> $1::vector
|
||||
LIMIT 1"#,
|
||||
fd_table
|
||||
))
|
||||
.bind(qemb)
|
||||
.bind(&req.file_uuid)
|
||||
.bind(req.trace_id)
|
||||
.fetch_optional(state.db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"message": format!("Search failed: {}", e)})),
|
||||
)
|
||||
})?;
|
||||
let filter = serde_json::json!({
|
||||
"must": [
|
||||
{"key": "file_uuid", "match": {"value": req.file_uuid}}
|
||||
],
|
||||
"must_not": [
|
||||
{"key": "trace_id", "match": {"value": req.trace_id}}
|
||||
]
|
||||
});
|
||||
|
||||
if let Some((cface_id, c_trace_id, c_sim)) = top {
|
||||
if seen_trace_ids.insert(c_trace_id) {
|
||||
validated.push((cface_id, c_trace_id, c_sim));
|
||||
let hits = match qdrant.search_face_collection(
|
||||
"_faces",
|
||||
qemb,
|
||||
1,
|
||||
"trace_id",
|
||||
&req.trace_id.to_string(),
|
||||
Some(&req.file_uuid),
|
||||
).await {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
tracing::warn!("[match_from_trace] Qdrant search failed: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some((score, payload)) = hits.first() {
|
||||
let trace_id = payload.get("trace_id").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
|
||||
if seen_trace_ids.insert(trace_id) {
|
||||
validated.push((0, trace_id, *score));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -421,41 +424,49 @@ async fn match_from_trace(
|
||||
}
|
||||
};
|
||||
|
||||
// 4. Update matched face_detections
|
||||
// 4. Update matched traces in Qdrant _faces
|
||||
let qdrant = QdrantDb::new();
|
||||
let mut traces_matched: Vec<i32> = Vec::new();
|
||||
for (id, trace_id, _similarity) in &validated {
|
||||
if let Err(e) = sqlx::query(&format!(
|
||||
"UPDATE {} SET identity_id = $1 WHERE id = $2",
|
||||
fd_table
|
||||
))
|
||||
.bind(identity_id)
|
||||
.bind(id)
|
||||
.execute(state.db.pool())
|
||||
.await
|
||||
for (_id, trace_id, _similarity) in &validated {
|
||||
let filter = serde_json::json!({
|
||||
"must": [
|
||||
{"key": "file_uuid", "match": {"value": req.file_uuid}},
|
||||
{"key": "trace_id", "match": {"value": trace_id}}
|
||||
]
|
||||
});
|
||||
let payload = serde_json::json!({"identity_id": identity_id});
|
||||
if let Err(e) = qdrant
|
||||
.update_payload_by_filter("_faces", filter, payload)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"[match-from-trace] Failed to update face_detection {}: {}",
|
||||
id,
|
||||
"[match-from-trace] Qdrant update failed for trace {}: {}",
|
||||
trace_id,
|
||||
e
|
||||
);
|
||||
} else {
|
||||
if !traces_matched.contains(trace_id) {
|
||||
traces_matched.push(*trace_id);
|
||||
}
|
||||
} else if !traces_matched.contains(trace_id) {
|
||||
traces_matched.push(*trace_id);
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Also bind the source trace itself
|
||||
let _ = sqlx::query(&format!(
|
||||
"UPDATE {} SET identity_id = $1 WHERE file_uuid = $2 AND trace_id = $3",
|
||||
fd_table
|
||||
))
|
||||
.bind(identity_id)
|
||||
.bind(&req.file_uuid)
|
||||
.bind(req.trace_id)
|
||||
.execute(state.db.pool())
|
||||
.await;
|
||||
|
||||
let filter = serde_json::json!({
|
||||
"must": [
|
||||
{"key": "file_uuid", "match": {"value": req.file_uuid}},
|
||||
{"key": "trace_id", "match": {"value": req.trace_id}}
|
||||
]
|
||||
});
|
||||
let payload = serde_json::json!({"identity_id": identity_id});
|
||||
if let Err(e) = qdrant
|
||||
.update_payload_by_filter("_faces", filter, payload)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"[match-from-trace] Qdrant update failed for source trace {}: {}",
|
||||
req.trace_id,
|
||||
e
|
||||
);
|
||||
}
|
||||
if !traces_matched.contains(&req.trace_id) {
|
||||
traces_matched.push(req.trace_id);
|
||||
}
|
||||
@@ -667,33 +678,34 @@ fn average_embeddings<'a>(embeddings: impl Iterator<Item = &'a Vec<f32>>) -> Vec
|
||||
async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
|
||||
use crate::core::processor::executor::PythonExecutor;
|
||||
use std::time::Duration;
|
||||
|
||||
|
||||
let executor = PythonExecutor::new()?;
|
||||
|
||||
|
||||
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
|
||||
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
|
||||
|
||||
|
||||
let output_path = std::path::PathBuf::from(&output_dir)
|
||||
.join(file_uuid)
|
||||
.join(format!("{}.identity_match_round1.json", file_uuid));
|
||||
|
||||
|
||||
std::fs::create_dir_all(output_path.parent().unwrap()).ok();
|
||||
|
||||
|
||||
let scripts_dir = executor.script_dir();
|
||||
let python_path = executor.python_path();
|
||||
let script_path = scripts_dir.join("identity_matcher.py");
|
||||
|
||||
let qdrant_url = std::env::var("QDRANT_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||
let qdrant_api_key = std::env::var("QDRANT_API_KEY")
|
||||
.unwrap_or_else(|_| "Test3200Test3200Test3200".to_string());
|
||||
|
||||
let qdrant_url =
|
||||
std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||
let qdrant_api_key =
|
||||
std::env::var("QDRANT_API_KEY").unwrap_or_else(|_| "Test3200Test3200Test3200".to_string());
|
||||
let db_url = std::env::var("DATABASE_URL")
|
||||
.unwrap_or_else(|_| "postgresql://accusys@localhost:5432/momentry".to_string());
|
||||
|
||||
|
||||
let db_schema = std::env::var("DATABASE_SCHEMA").unwrap_or_else(|_| "public".to_string());
|
||||
let mut cmd = tokio::process::Command::new(python_path);
|
||||
cmd.env("MOMENTRY_OUTPUT_DIR", &output_dir);
|
||||
cmd.env("DATABASE_SCHEMA", "public");
|
||||
cmd.env("MOMENTRY_DB_SCHEMA", "public");
|
||||
cmd.env("DATABASE_SCHEMA", &db_schema);
|
||||
cmd.env("MOMENTRY_DB_SCHEMA", &db_schema);
|
||||
cmd.env("DATABASE_URL", &db_url);
|
||||
cmd.env("QDRANT_URL", &qdrant_url);
|
||||
cmd.env("QDRANT_API_KEY", &qdrant_api_key);
|
||||
@@ -702,42 +714,50 @@ async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::
|
||||
cmd.arg("--round").arg("1");
|
||||
cmd.arg("--mark-tkg");
|
||||
cmd.arg("--output").arg(&output_path);
|
||||
|
||||
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
|
||||
tracing::info!("[FaceMatch] Starting identity_matcher for {}", file_uuid);
|
||||
|
||||
|
||||
let output = cmd.output().await?;
|
||||
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
|
||||
|
||||
if !output.status.success() {
|
||||
tracing::error!("[FaceMatch] identity_matcher failed with exit code: {:?}", output.status.code());
|
||||
tracing::error!(
|
||||
"[FaceMatch] identity_matcher failed with exit code: {:?}",
|
||||
output.status.code()
|
||||
);
|
||||
tracing::error!("[FaceMatch] stderr: {}", stderr);
|
||||
tracing::error!("[FaceMatch] stdout: {}", stdout);
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
|
||||
tracing::info!("[FaceMatch] stdout: {}", stdout);
|
||||
|
||||
|
||||
if !output_path.exists() {
|
||||
tracing::info!("[FaceMatch] No matches found for {}", file_uuid);
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
|
||||
let content = std::fs::read_to_string(&output_path)?;
|
||||
let result: serde_json::Value = serde_json::from_str(&content)?;
|
||||
|
||||
|
||||
let matched = result.get("matched").and_then(|v| v.as_i64()).unwrap_or(0) as usize;
|
||||
let tkg_updated = result.get("tkg_nodes_updated").and_then(|v| v.as_i64()).unwrap_or(0) as usize;
|
||||
|
||||
let tkg_updated = result
|
||||
.get("tkg_nodes_updated")
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0) as usize;
|
||||
|
||||
tracing::info!(
|
||||
"[FaceMatch] Round 1 for {}: {} matches, {} TKG nodes updated",
|
||||
file_uuid, matched, tkg_updated
|
||||
file_uuid,
|
||||
matched,
|
||||
tkg_updated
|
||||
);
|
||||
|
||||
|
||||
Ok(matched)
|
||||
}
|
||||
|
||||
@@ -755,17 +775,33 @@ async fn match_faces_iterative_pg(pool: &sqlx::PgPool, file_uuid: &str) -> anyho
|
||||
/// segments (speaker_id, start_time, end_time), computes overlap,
|
||||
/// and stores bindings in identity_bindings table.
|
||||
pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
|
||||
// Load face traces with identity_id and frame numbers
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let traces = sqlx::query_as::<_, (i32, Vec<i32>)>(&format!(
|
||||
"SELECT trace_id, array_agg(frame_number ORDER BY frame_number) \
|
||||
FROM {} WHERE file_uuid=$1 AND trace_id IS NOT NULL AND identity_id IS NOT NULL \
|
||||
GROUP BY trace_id",
|
||||
fd_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
use crate::core::db::qdrant_db::QdrantDb;
|
||||
use serde_json::json;
|
||||
|
||||
// Load face traces with identity_id from Qdrant _faces
|
||||
let qdrant = QdrantDb::new();
|
||||
let trace_filter = json!({
|
||||
"must": [
|
||||
{"key": "file_uuid", "match": {"value": file_uuid}},
|
||||
{"key": "identity_id", "exists": true},
|
||||
{"key": "trace_id", "match": {"value": 1}}
|
||||
]
|
||||
});
|
||||
let points = qdrant.scroll_all_points("_faces", trace_filter, 500).await.unwrap_or_default();
|
||||
|
||||
// Group by trace_id, collect frames
|
||||
let mut traces: HashMap<i32, Vec<i64>> = HashMap::new();
|
||||
for point in &points {
|
||||
let payload = &point["payload"];
|
||||
let trace_id = payload["trace_id"].as_i64().unwrap_or(0) as i32;
|
||||
let frame = payload["frame"].as_i64().unwrap_or(0);
|
||||
traces.entry(trace_id).or_default().push(frame);
|
||||
}
|
||||
|
||||
// Sort frames per trace
|
||||
for frames in traces.values_mut() {
|
||||
frames.sort();
|
||||
}
|
||||
|
||||
if traces.is_empty() {
|
||||
tracing::info!("[SpeakerBind] No face traces with identities");
|
||||
@@ -818,8 +854,23 @@ pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Resu
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
// Get fps for frame-to-time conversion
|
||||
let fps: f64 = 25.0; // default, could also read from DB
|
||||
// Compute fps from video table
|
||||
let fps: f64 = sqlx::query_scalar::<_, f64>(
|
||||
"SELECT COALESCE(fps, 25.0) FROM videos WHERE file_uuid=$1"
|
||||
)
|
||||
.bind(file_uuid)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.unwrap_or(25.0);
|
||||
|
||||
tracing::info!(
|
||||
"[SpeakerBind] Using fps={:.3} for {} ({} traces)",
|
||||
fps,
|
||||
file_uuid,
|
||||
traces.len()
|
||||
);
|
||||
|
||||
// For each trace, compute overlap with each speaker
|
||||
let mut bindings = 0usize;
|
||||
@@ -828,13 +879,15 @@ pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Resu
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get identity_id for this trace
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let identity_id: Option<i32> = sqlx::query_scalar(
|
||||
&format!("SELECT identity_id FROM {} WHERE file_uuid=$1 AND trace_id=$2 AND identity_id IS NOT NULL LIMIT 1", fd_table)
|
||||
)
|
||||
.bind(file_uuid).bind(trace_id)
|
||||
.fetch_optional(pool).await?.flatten();
|
||||
// Get identity_id for this trace from Qdrant payload
|
||||
let identity_id: Option<i32> = points.iter()
|
||||
.find(|p| {
|
||||
p["payload"]["trace_id"].as_i64() == Some(*trace_id as i64)
|
||||
&& p["payload"]["identity_id"].as_i64().is_some()
|
||||
&& p["payload"]["identity_id"].as_i64().unwrap() > 0
|
||||
})
|
||||
.and_then(|p| p["payload"]["identity_id"].as_i64())
|
||||
.map(|id| id as i32);
|
||||
|
||||
if identity_id.is_none() {
|
||||
continue;
|
||||
@@ -873,18 +926,20 @@ pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Resu
|
||||
});
|
||||
|
||||
let ib_table = schema::table_name("identity_bindings");
|
||||
let _ = sqlx::query(
|
||||
&format!("INSERT INTO {} (identity_id, identity_type, identity_value, file_uuid, confidence, metadata) \
|
||||
VALUES ($1, 'speaker', $2, $3, $4, $5::jsonb) \
|
||||
ON CONFLICT (identity_id, identity_type, identity_value, file_uuid) \
|
||||
if let Err(e) = sqlx::query(
|
||||
&format!("INSERT INTO {} (identity_id, identity_type, identity_value, confidence, metadata) \
|
||||
VALUES ($1, 'speaker', $2, $3, $4::jsonb) \
|
||||
ON CONFLICT (identity_id, identity_type, identity_value) \
|
||||
DO UPDATE SET confidence = EXCLUDED.confidence, metadata = EXCLUDED.metadata", ib_table)
|
||||
)
|
||||
.bind(identity_id)
|
||||
.bind(&best_speaker)
|
||||
.bind(file_uuid)
|
||||
.bind(overlap_ratio)
|
||||
.bind(&metadata)
|
||||
.execute(pool).await;
|
||||
.execute(pool).await
|
||||
{
|
||||
tracing::error!("[SpeakerBind] INSERT failed for trace_id={}, identity_id={}, speaker={}: {}", trace_id, identity_id, best_speaker, e);
|
||||
}
|
||||
|
||||
// Also update speaker_detections with the identity_id
|
||||
let sd_table = schema::table_name("speaker_detections");
|
||||
@@ -915,16 +970,40 @@ pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Resu
|
||||
/// Pipeline-triggered entry point: runs the full identity agent for a file.
|
||||
/// Reads face_clustered.json + asrx.json, extracts persons/speakers, creates identities,
|
||||
/// runs iterative face matching, and binds speakers.
|
||||
pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Result<()> {
|
||||
pub async fn run_identity_agent(
|
||||
db: &PostgresDb,
|
||||
file_uuid: &str,
|
||||
redis: Option<std::sync::Arc<RedisClient>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
|
||||
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
|
||||
|
||||
let pool = db.pool();
|
||||
|
||||
// Step 1: 先跑 face matching(不需 face_clustered.json)
|
||||
let mut progress = AgentProgress::new(file_uuid);
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
|
||||
// Step 1: Face matching (iterative TMDb matching)
|
||||
progress.update_phase(AgentPhase::TmdbMatching, 0.3, "Running face matching...");
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
|
||||
let matched = match_faces_iterative(pool, file_uuid).await.unwrap_or(0);
|
||||
progress.stats.tmdb_matches = matched as i64;
|
||||
progress.update_phase(AgentPhase::TmdbMatching, 1.0, &format!("Face matching: {} matches", matched));
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
|
||||
// Step 2: Load face_clustered.json and create identities
|
||||
progress.update_phase(AgentPhase::FaceClustering, 0.5, "Loading face clusters...");
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
|
||||
// Step 2: 試著載入 face_clustered.json 建立新 identities
|
||||
let video_dir = PathBuf::from(&output_dir).join(file_uuid);
|
||||
let face_clustered_path = video_dir.join(format!("{}.face_clustered.json", file_uuid));
|
||||
let face_clustered_path = if face_clustered_path.exists() {
|
||||
@@ -947,6 +1026,8 @@ pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Res
|
||||
let speakers = extract_speakers_from_asrx_data(&asrx_data);
|
||||
let identities = analyze_person_speaker_overlap(&persons, &speakers);
|
||||
|
||||
progress.stats.clusters = identities.len() as i64;
|
||||
|
||||
let _ = identities.len();
|
||||
if !identities.is_empty() {
|
||||
let metadata = serde_json::json!({
|
||||
@@ -969,6 +1050,13 @@ pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Res
|
||||
.execute(pool)
|
||||
.await;
|
||||
}
|
||||
progress.stats.identities_created = identities.len() as i64;
|
||||
progress.update_phase(AgentPhase::IdentityCreation, 1.0, &format!(
|
||||
"Created {} identities from clusters", identities.len()
|
||||
));
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
tracing::info!(
|
||||
"[IdentityAgent] Analyzed {} face clusters from face_clustered for {}",
|
||||
identities.len(),
|
||||
@@ -979,9 +1067,29 @@ pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Res
|
||||
"[IdentityAgent] face_clustered.json not found for {}, skipping identity creation",
|
||||
file_uuid
|
||||
);
|
||||
progress.update_phase(AgentPhase::IdentityCreation, 0.0, "No face_clustered.json");
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Speaker binding
|
||||
progress.update_phase(AgentPhase::SpeakerBinding, 0.5, "Binding speakers...");
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
|
||||
let bound = bind_speakers(pool, file_uuid).await.unwrap_or(0);
|
||||
progress.stats.speaker_bindings = bound as i64;
|
||||
progress.update_phase(AgentPhase::SpeakerBinding, 1.0, &format!("Speaker binding: {} bound", bound));
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
|
||||
progress.mark_completed();
|
||||
if let Some(r) = redis.as_ref() {
|
||||
publish_agent_progress(&r, file_uuid, &progress).await;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[IdentityAgent] Done for {}: {} face matches, {} speaker bindings",
|
||||
@@ -999,14 +1107,12 @@ async fn generate_seeds_handler(
|
||||
let db = &state.db;
|
||||
let pool = db.pool();
|
||||
|
||||
let count = generate_seed_embeddings(db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"success": false, "message": format!("{}", e)})),
|
||||
)
|
||||
})?;
|
||||
let count = generate_seed_embeddings(db).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"success": false, "message": format!("{}", e)})),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Auto-trigger identity agent for all ready files
|
||||
if count > 0 {
|
||||
@@ -1019,13 +1125,13 @@ async fn generate_seeds_handler(
|
||||
);
|
||||
for file_uuid in &ready_files {
|
||||
let db = state.db.clone();
|
||||
let redis = crate::core::db::RedisClient::new().ok().map(Arc::new);
|
||||
let fid = file_uuid.clone();
|
||||
tokio::spawn(async move {
|
||||
match run_identity_agent(&db, &fid).await {
|
||||
Ok(_) => tracing::info!(
|
||||
"[GenerateSeeds] Identity agent completed for {}",
|
||||
fid
|
||||
),
|
||||
match run_identity_agent(&db, &fid, redis).await {
|
||||
Ok(_) => {
|
||||
tracing::info!("[GenerateSeeds] Identity agent completed for {}", fid)
|
||||
}
|
||||
Err(e) => tracing::warn!(
|
||||
"[GenerateSeeds] Identity agent failed for {}: {}",
|
||||
fid,
|
||||
@@ -1044,16 +1150,28 @@ async fn generate_seeds_handler(
|
||||
})))
|
||||
}
|
||||
|
||||
/// Find videos that are ready for identity processing (have face embeddings).
|
||||
/// Find videos that are ready for identity processing (have face embeddings in Qdrant).
|
||||
async fn find_ready_files(pool: &sqlx::PgPool) -> anyhow::Result<Vec<String>> {
|
||||
let fd_table = crate::core::db::schema::table_name("face_detections");
|
||||
let rows: Vec<(String,)> = sqlx::query_as(&format!(
|
||||
"SELECT DISTINCT file_uuid FROM {} WHERE embedding IS NOT NULL AND identity_id IS NULL",
|
||||
fd_table
|
||||
))
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
Ok(rows.into_iter().map(|r| r.0).collect())
|
||||
use crate::core::db::qdrant_db::QdrantDb;
|
||||
use serde_json::json;
|
||||
|
||||
let qdrant = QdrantDb::new();
|
||||
// Find files with faces that don't have identity_id set
|
||||
let filter = json!({
|
||||
"must": [
|
||||
{"key": "identity_id", "match": {"value": null}}
|
||||
]
|
||||
});
|
||||
let points = qdrant.scroll_all_points("_faces", filter, 1000).await.unwrap_or_default();
|
||||
|
||||
let mut file_uuids: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||
for point in &points {
|
||||
if let Some(fu) = point["payload"]["file_uuid"].as_str() {
|
||||
file_uuids.insert(fu.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(file_uuids.into_iter().collect())
|
||||
}
|
||||
|
||||
/// API handler: POST /api/v1/agents/identity/run
|
||||
@@ -1071,7 +1189,8 @@ async fn run_identity_handler(
|
||||
)
|
||||
})?;
|
||||
|
||||
match run_identity_agent(&state.db, file_uuid).await {
|
||||
let redis = crate::core::db::RedisClient::new().ok().map(Arc::new);
|
||||
match run_identity_agent(&state.db, file_uuid, redis).await {
|
||||
Ok(()) => Ok(Json(serde_json::json!({
|
||||
"success": true,
|
||||
"message": format!("Identity agent completed for {}", file_uuid),
|
||||
@@ -1109,29 +1228,28 @@ async fn confirm_identity_handler(
|
||||
Json(req): Json<ConfirmIdentityRequest>,
|
||||
) -> Result<Json<ConfirmIdentityResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
use crate::core::processor::executor::PythonExecutor;
|
||||
|
||||
|
||||
let executor = PythonExecutor::new().map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"success": false, "message": format!("PythonExecutor error: {}", e)})),
|
||||
)
|
||||
})?;
|
||||
|
||||
|
||||
let scripts_dir = executor.script_dir();
|
||||
let python_path = executor.python_path();
|
||||
let script_path = scripts_dir.join("confirm_identity.py");
|
||||
|
||||
let qdrant_url = std::env::var("QDRANT_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||
let qdrant_api_key = std::env::var("QDRANT_API_KEY")
|
||||
.unwrap_or_else(|_| "Test3200Test3200Test3200".to_string());
|
||||
|
||||
let qdrant_url =
|
||||
std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||
let qdrant_api_key =
|
||||
std::env::var("QDRANT_API_KEY").unwrap_or_else(|_| "Test3200Test3200Test3200".to_string());
|
||||
let db_url = std::env::var("DATABASE_URL")
|
||||
.unwrap_or_else(|_| "postgresql://accusys@localhost:5432/momentry".to_string());
|
||||
let db_schema = std::env::var("DATABASE_SCHEMA")
|
||||
.unwrap_or_else(|_| "dev".to_string());
|
||||
|
||||
let db_schema = std::env::var("DATABASE_SCHEMA").unwrap_or_else(|_| "dev".to_string());
|
||||
|
||||
let propagate = req.propagate.unwrap_or(true);
|
||||
|
||||
|
||||
let mut cmd = tokio::process::Command::new(python_path);
|
||||
cmd.env("DATABASE_URL", &db_url);
|
||||
cmd.env("DATABASE_SCHEMA", &db_schema);
|
||||
@@ -1144,31 +1262,39 @@ async fn confirm_identity_handler(
|
||||
cmd.arg("--identity-id").arg(req.identity_id.to_string());
|
||||
cmd.arg("--identity-uuid").arg(&req.identity_uuid);
|
||||
cmd.arg("--name").arg(&req.name);
|
||||
|
||||
|
||||
if !propagate {
|
||||
cmd.arg("--no-propagate");
|
||||
}
|
||||
|
||||
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
|
||||
tracing::info!(
|
||||
"[ConfirmIdentity] Starting for {} trace {} -> {} ({})",
|
||||
req.file_uuid, req.trace_id, req.identity_uuid, req.name
|
||||
req.file_uuid,
|
||||
req.trace_id,
|
||||
req.identity_uuid,
|
||||
req.name
|
||||
);
|
||||
|
||||
|
||||
let output = cmd.output().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"success": false, "message": format!("Command failed: {}", e)})),
|
||||
Json(
|
||||
serde_json::json!({"success": false, "message": format!("Command failed: {}", e)}),
|
||||
),
|
||||
)
|
||||
})?;
|
||||
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
|
||||
|
||||
if !output.status.success() {
|
||||
tracing::error!("[ConfirmIdentity] Script failed with exit code: {:?}", output.status.code());
|
||||
tracing::error!(
|
||||
"[ConfirmIdentity] Script failed with exit code: {:?}",
|
||||
output.status.code()
|
||||
);
|
||||
tracing::error!("[ConfirmIdentity] stderr: {}", stderr);
|
||||
tracing::error!("[ConfirmIdentity] stdout: {}", stdout);
|
||||
return Err((
|
||||
@@ -1180,9 +1306,9 @@ async fn confirm_identity_handler(
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
tracing::info!("[ConfirmIdentity] stdout: {}", stdout);
|
||||
|
||||
|
||||
let json_start = stdout.find('{');
|
||||
if json_start.is_none() {
|
||||
return Err((
|
||||
@@ -1195,7 +1321,7 @@ async fn confirm_identity_handler(
|
||||
));
|
||||
}
|
||||
let json_str = &stdout[json_start.unwrap()..];
|
||||
|
||||
|
||||
let result: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
@@ -1207,14 +1333,17 @@ async fn confirm_identity_handler(
|
||||
})),
|
||||
)
|
||||
})?;
|
||||
|
||||
|
||||
Ok(Json(ConfirmIdentityResponse {
|
||||
success: result.get("status").and_then(|v| v.as_str()) == Some("success"),
|
||||
file_uuid: req.file_uuid,
|
||||
trace_id: req.trace_id,
|
||||
identity_uuid: req.identity_uuid,
|
||||
name: req.name,
|
||||
steps: result.get("steps").cloned().unwrap_or(serde_json::json!({})),
|
||||
steps: result
|
||||
.get("steps")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({})),
|
||||
propagation: result.get("propagation").cloned(),
|
||||
}))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user