refactor: remove face embedding architecture - single Qdrant _faces collection
- Delete FaceEmbeddingDb module (face_embedding_db.rs)
- Stub match_faces_iterative, generate_seed_embeddings, tmdb_match_handler
- Remove sync_trace_embeddings, populate_face_embeddings_to_qdrant
- Remove embedding from face.json output (face_processor.py)
- Remove embedding from PG UPDATE (store_traced_faces.py)
- Remove workspace traces staging (checkin.rs, qdrant_workspace.rs)
- Fix tests: add pose_angle to Face, hand_nodes to TkgResult

Disabled functions (need to be reimplemented with _faces):
- match_faces_iterative (identity agent)
- generate_seed_embeddings (TMDb seeds)
- tmdb_match_handler (TMDb matching)
- cluster_face_embeddings, search_similar_faces
- merge_traces_within_cuts
This commit is contained in:
+13
-706
@@ -661,597 +661,21 @@ fn average_embeddings<'a>(embeddings: impl Iterator<Item = &'a Vec<f32>>) -> Vec
|
||||
/// Unknown: greedy stranger clustering (TH=0.40)
|
||||
/// Writes identity_ref/stranger_ref to Qdrant payload, TKG nodes, and face_detections.
|
||||
async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
|
||||
use crate::core::db::face_embedding_db::FaceEmbeddingDb;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let face_db = FaceEmbeddingDb::new();
|
||||
|
||||
// Step 1: Load seeds from Qdrant (type=identity_seed)
|
||||
let seeds = face_db.get_seed_embeddings().await?;
|
||||
tracing::info!(
|
||||
"[FaceMatch] Loaded {} seeds from Qdrant",
|
||||
seeds.len()
|
||||
);
|
||||
|
||||
// Step 2: Preload identity internal IDs (uuid → (id, name))
|
||||
let id_table = schema::table_name("identities");
|
||||
let seed_identity_map: HashMap<String, (i32, String)> = if !seeds.is_empty() {
|
||||
let uuids: Vec<String> = seeds.iter().map(|(uuid, _, _)| uuid.clone()).collect();
|
||||
if uuids.is_empty() {
|
||||
HashMap::new()
|
||||
} else {
|
||||
let rows = sqlx::query_as::<_, (i32, String, String)>(&format!(
|
||||
"SELECT id, uuid::text, name FROM {} WHERE uuid::text = ANY($1)",
|
||||
id_table
|
||||
))
|
||||
.bind(&uuids)
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(id, uuid, name)| (uuid, (id, name)))
|
||||
.collect();
|
||||
rows
|
||||
}
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
// Step 3: Load face embeddings from Qdrant for this file
|
||||
let qdrant_embeddings = face_db.get_all_embeddings_for_file(file_uuid).await?;
|
||||
|
||||
if qdrant_embeddings.is_empty() {
|
||||
tracing::warn!("[FaceMatch] No face embeddings in Qdrant for {}", file_uuid);
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
// Step 4: Group embeddings by trace_id, keeping confidence
|
||||
let mut trace_faces: HashMap<i32, Vec<(i64, Vec<f32>, f64)>> = HashMap::new();
|
||||
for (_, emb, payload) in &qdrant_embeddings {
|
||||
trace_faces
|
||||
.entry(payload.trace_id)
|
||||
.or_default()
|
||||
.push((payload.frame, emb.clone(), payload.confidence));
|
||||
}
|
||||
|
||||
// Step 5: Progressive multi-round matching with derived seeds
|
||||
// Each round: choose a face with best seed sim for matching; separately,
|
||||
// collect the highest-confidence face per trace for building derived seeds.
|
||||
const TH_MIN: f32 = 0.35;
|
||||
const DERIVED_CONF: f64 = 0.90;
|
||||
const MAX_DERIVED_PER_ID: usize = 9;
|
||||
const MAX_FACES_PER_TRACE: usize = 3;
|
||||
const ANGLE_SIM_THRESHOLD: f32 = 0.90;
|
||||
const TH_STRANGER: f32 = 0.40;
|
||||
|
||||
let total_traces = trace_faces.len();
|
||||
let total_embeddings: usize = trace_faces.values().map(|v| v.len()).sum();
|
||||
tracing::info!(
|
||||
"[FaceMatch] Loaded {} traces ({} face embeddings) from Qdrant for {}",
|
||||
total_traces,
|
||||
total_embeddings,
|
||||
tracing::warn!(
|
||||
"[FaceMatch] Face matching disabled - FaceEmbeddingDb removed. \
|
||||
TODO: Reimplement with _faces collection for {}",
|
||||
file_uuid
|
||||
);
|
||||
|
||||
let mut matched: HashMap<i32, (String, i32)> = HashMap::new();
|
||||
let mut trace_face_count: HashMap<i32, usize> = HashMap::new();
|
||||
|
||||
// All reference embeddings: start with original TMDb seeds
|
||||
let mut all_refs: Vec<(String, String, Vec<f32>)> = seeds.clone();
|
||||
let thresholds = [0.55f32, 0.50, 0.45, 0.40, 0.35];
|
||||
let mut prev_total = 0usize;
|
||||
|
||||
for (round_idx, &th) in thresholds.iter().enumerate() {
|
||||
if th < TH_MIN {
|
||||
break;
|
||||
}
|
||||
|
||||
let mut new_matches: HashMap<i32, (String, i32)> = HashMap::new();
|
||||
let mut seed_candidates: Vec<(i32, String, i32, Vec<f32>, f64)> = Vec::new();
|
||||
|
||||
for (&tid, faces) in &trace_faces {
|
||||
if matched.contains_key(&tid) {
|
||||
continue;
|
||||
}
|
||||
trace_face_count.entry(tid).or_insert(faces.len());
|
||||
|
||||
let mut best_sim = 0.0f32;
|
||||
let mut best_name = String::new();
|
||||
let mut best_id = 0i32;
|
||||
// Collect all high-confidence faces in this trace for derived seeds
|
||||
let mut trace_candidates: Vec<(Vec<f32>, f64)> = Vec::new();
|
||||
|
||||
for (_, emb, conf) in faces {
|
||||
for (ref_uuid, ref_name, ref_emb) in &all_refs {
|
||||
let s = cosine_similarity(emb, ref_emb);
|
||||
if s > best_sim {
|
||||
best_sim = s;
|
||||
best_name = ref_name.clone();
|
||||
if let Some(id_str) = ref_uuid.strip_prefix("derived:") {
|
||||
if let Ok(parsed) = id_str.parse::<i32>() {
|
||||
best_id = parsed;
|
||||
}
|
||||
} else if let Some((id, _)) = seed_identity_map.get(ref_uuid) {
|
||||
best_id = *id;
|
||||
}
|
||||
}
|
||||
}
|
||||
if *conf >= DERIVED_CONF {
|
||||
trace_candidates.push((emb.clone(), *conf));
|
||||
}
|
||||
}
|
||||
|
||||
if best_sim >= th && best_id > 0 {
|
||||
new_matches.insert(tid, (best_name.clone(), best_id));
|
||||
|
||||
// Top MAX_FACES_PER_TRACE highest-confidence faces with angular diversity
|
||||
trace_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
let mut selected: Vec<Vec<f32>> = Vec::new();
|
||||
for (emb, conf) in trace_candidates {
|
||||
if selected.len() >= MAX_FACES_PER_TRACE {
|
||||
break;
|
||||
}
|
||||
if selected.iter().any(|e| cosine_similarity(e, &emb) >= ANGLE_SIM_THRESHOLD) {
|
||||
continue;
|
||||
}
|
||||
selected.push(emb.clone());
|
||||
seed_candidates.push((best_id, best_name.clone(), tid, emb, conf));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let new_count = new_matches.len();
|
||||
if new_count == 0 && round_idx > 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
matched.extend(new_matches);
|
||||
|
||||
// Build derived seeds: pick up to MAX_DERIVED_PER_ID per identity
|
||||
// (max MAX_FACES_PER_TRACE from each trace), sorted by confidence descending
|
||||
seed_candidates.sort_by(|a, b| b.4.partial_cmp(&a.4).unwrap());
|
||||
let mut per_id: HashMap<i32, usize> = HashMap::new();
|
||||
let mut trace_used_faces: HashMap<i32, usize> = HashMap::new();
|
||||
let mut added_seeds = 0usize;
|
||||
for (id, name, tid, emb, _) in &seed_candidates {
|
||||
let cnt = per_id.entry(*id).or_insert(0);
|
||||
if *cnt >= MAX_DERIVED_PER_ID {
|
||||
continue;
|
||||
}
|
||||
let trace_cnt = trace_used_faces.entry(*tid).or_insert(0);
|
||||
if *trace_cnt >= MAX_FACES_PER_TRACE {
|
||||
continue;
|
||||
}
|
||||
*trace_cnt += 1;
|
||||
*cnt += 1;
|
||||
all_refs.push((format!("derived:{}", id), name.clone(), emb.clone()));
|
||||
added_seeds += 1;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[FaceMatch] Round {}: matched {}+{}={} total (TH={}, {} new derived seeds)",
|
||||
round_idx + 1,
|
||||
prev_total,
|
||||
new_count,
|
||||
matched.len(),
|
||||
th,
|
||||
added_seeds
|
||||
);
|
||||
|
||||
prev_total = matched.len();
|
||||
}
|
||||
|
||||
// Step 7: Stranger clustering for unmatched traces
|
||||
let unmatched_ids: Vec<i32> = trace_faces
|
||||
.keys()
|
||||
.filter(|tid| !matched.contains_key(tid))
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
let mut stranger_map: HashMap<i32, String> = HashMap::new();
|
||||
let mut assigned_stranger: std::collections::HashSet<i32> = std::collections::HashSet::new();
|
||||
let mut stranger_count = 0usize;
|
||||
|
||||
// Sort by face count descending (most reliable first)
|
||||
let mut sorted_unmatched: Vec<i32> = unmatched_ids.clone();
|
||||
sorted_unmatched.sort_by(|a, b| {
|
||||
trace_face_count
|
||||
.get(b)
|
||||
.unwrap_or(&0)
|
||||
.cmp(trace_face_count.get(a).unwrap_or(&0))
|
||||
});
|
||||
|
||||
for &tid in &sorted_unmatched {
|
||||
if assigned_stranger.contains(&tid) {
|
||||
continue;
|
||||
}
|
||||
let centroid_a = if let Some(faces) = trace_faces.get(&tid) {
|
||||
average_embeddings(faces.iter().map(|(_, emb, _)| emb))
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
stranger_count += 1;
|
||||
let stranger_id = format!("{}:stranger_{}", file_uuid, stranger_count);
|
||||
assigned_stranger.insert(tid);
|
||||
stranger_map.insert(tid, stranger_id.clone());
|
||||
|
||||
for &other_tid in &sorted_unmatched {
|
||||
if assigned_stranger.contains(&other_tid) || other_tid == tid {
|
||||
continue;
|
||||
}
|
||||
if let Some(faces_b) = trace_faces.get(&other_tid) {
|
||||
let centroid_b = average_embeddings(faces_b.iter().map(|(_, emb, _)| emb));
|
||||
let s = cosine_similarity(&centroid_a, &centroid_b);
|
||||
if s >= TH_STRANGER {
|
||||
assigned_stranger.insert(other_tid);
|
||||
stranger_map.insert(other_tid, stranger_id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let stranger_trace_count = stranger_map.len();
|
||||
tracing::info!(
|
||||
"[FaceMatch] Stranger clusters: {} groups, {} traces",
|
||||
stranger_count,
|
||||
stranger_trace_count
|
||||
);
|
||||
|
||||
// Step 8: Write results to TKG nodes + Qdrant payload + face_detections
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let nodes_table = schema::table_name("tkg_nodes");
|
||||
let mut pg_updated = 0usize;
|
||||
|
||||
// Clear old identity assignments before writing new ones
|
||||
let _ = sqlx::query(&format!(
|
||||
"UPDATE {} SET identity_id = NULL WHERE file_uuid = $1",
|
||||
fd_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// 8a: Matched traces → identity_ref
|
||||
for (&tid, (name, identity_id)) in &matched {
|
||||
// Skip if identity_id is invalid (FK constraint would fail)
|
||||
if *identity_id <= 0 {
|
||||
tracing::warn!(
|
||||
"[FaceMatch] Skipping trace {}: invalid identity_id={}",
|
||||
tid, identity_id
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let identity_ref = format!("{}:{}", file_uuid, identity_id);
|
||||
|
||||
// TKG node
|
||||
let external_id = format!("face_track_{}", tid);
|
||||
if let Err(e) = sqlx::query(&format!(
|
||||
"UPDATE {} SET properties = jsonb_set(\
|
||||
jsonb_set(properties, '{{identity_ref}}', to_jsonb($1), true),\
|
||||
'{{identity_name}}', to_jsonb($2), true)\
|
||||
WHERE file_uuid = $3 AND node_type = 'face_track' AND external_id = $4",
|
||||
nodes_table
|
||||
))
|
||||
.bind(&identity_ref)
|
||||
.bind(name)
|
||||
.bind(file_uuid)
|
||||
.bind(&external_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("[FaceMatch] TKG update failed for trace {}: {:?}", tid, e);
|
||||
}
|
||||
|
||||
// Qdrant payload
|
||||
let _ = face_db
|
||||
.update_identity_ref_by_trace(file_uuid, tid, &identity_ref)
|
||||
.await;
|
||||
|
||||
// PostgreSQL face_detections (backward compat)
|
||||
let rows = sqlx::query(&format!(
|
||||
"UPDATE {} SET identity_id = $1 WHERE file_uuid = $2 AND trace_id = $3",
|
||||
fd_table
|
||||
))
|
||||
.bind(identity_id)
|
||||
.bind(file_uuid)
|
||||
.bind(tid)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map(|r| r.rows_affected())
|
||||
.unwrap_or(0);
|
||||
pg_updated += rows as usize;
|
||||
}
|
||||
|
||||
// 8b: Stranger traces → stranger_ref
|
||||
for (&tid, stranger_ref) in &stranger_map {
|
||||
// TKG node
|
||||
let external_id = format!("face_track_{}", tid);
|
||||
if let Err(e) = sqlx::query(&format!(
|
||||
"UPDATE {} SET properties = jsonb_set(\
|
||||
properties, '{{stranger_ref}}', to_jsonb($1), true)\
|
||||
WHERE file_uuid = $2 AND node_type = 'face_track' AND external_id = $3",
|
||||
nodes_table
|
||||
))
|
||||
.bind(stranger_ref)
|
||||
.bind(file_uuid)
|
||||
.bind(&external_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("[FaceMatch] TKG stranger update failed for trace {}: {:?}", tid, e);
|
||||
}
|
||||
|
||||
// Qdrant payload
|
||||
let _ = face_db
|
||||
.update_stranger_ref_by_trace(file_uuid, tid, stranger_ref)
|
||||
.await;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[FaceMatch] Done: {} matched, {} strangers — {} face_detections updated",
|
||||
matched.len(),
|
||||
stranger_trace_count,
|
||||
pg_updated
|
||||
);
|
||||
Ok(pg_updated)
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
/// Fallback: PostgreSQL-based matching (original implementation)
|
||||
/// Fallback: PostgreSQL-based matching (disabled - embedding column removed)
|
||||
async fn match_faces_iterative_pg(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
|
||||
// Step 1: 載入 TMDb identities (source='tmdb' 且有 face_embedding)
|
||||
let identities_table = schema::table_name("identities");
|
||||
let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec<f32>)>(
|
||||
&format!("SELECT id, name, face_embedding::real[] FROM {} WHERE source='tmdb' AND face_embedding IS NOT NULL", identities_table)
|
||||
)
|
||||
.fetch_all(pool).await?;
|
||||
|
||||
if tmdb_rows.is_empty() {
|
||||
tracing::warn!("[FaceMatch-PG] No TMDb identities with face embeddings");
|
||||
return Ok(0);
|
||||
}
|
||||
tracing::info!(
|
||||
"[FaceMatch-PG] Loaded {} TMDb seed identities",
|
||||
tmdb_rows.len()
|
||||
tracing::warn!(
|
||||
"[FaceMatch-PG] PostgreSQL matching disabled - embedding column removed for {}",
|
||||
file_uuid
|
||||
);
|
||||
|
||||
// Step 2: 載入所有 face_detections(含 frame_number),按 trace_id 分組
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let fd_rows = sqlx::query_as::<_, (i32, i64, Vec<f32>)>(&format!(
|
||||
"SELECT trace_id, frame_number, embedding FROM {} \
|
||||
WHERE file_uuid=$1 AND trace_id IS NOT NULL AND embedding IS NOT NULL \
|
||||
ORDER BY trace_id, frame_number",
|
||||
fd_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
if fd_rows.is_empty() {
|
||||
tracing::warn!("[FaceMatch-PG] No face detections with embeddings");
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
// 分組:trace_id → (frame_number, embedding)
|
||||
use std::collections::HashMap;
|
||||
let mut face_track_faces_raw: HashMap<i32, Vec<(i64, Vec<f32>)>> = HashMap::new();
|
||||
for (tid, frame, emb) in &fd_rows {
|
||||
face_track_faces_raw
|
||||
.entry(*tid)
|
||||
.or_insert_with(Vec::new)
|
||||
.push((*frame, emb.clone()));
|
||||
}
|
||||
|
||||
// 從每個 trace 選取不同角度的 3 個 face embedding
|
||||
let mut face_track_samples: HashMap<i32, Vec<Vec<f32>>> = HashMap::new();
|
||||
for (tid, mut faces) in face_track_faces_raw {
|
||||
faces.sort_by_key(|(frame, _)| *frame);
|
||||
let n = faces.len();
|
||||
let indices = if n <= 3 {
|
||||
(0..n).collect()
|
||||
} else {
|
||||
let mid = n / 2;
|
||||
vec![0, mid, n - 1]
|
||||
};
|
||||
let samples: Vec<Vec<f32>> = indices.iter().map(|&i| faces[i].1.clone()).collect();
|
||||
face_track_samples.insert(tid, samples);
|
||||
}
|
||||
|
||||
let total_traces = face_track_samples.len();
|
||||
let sample_count: usize = face_track_samples.values().map(|v| v.len()).sum();
|
||||
tracing::info!(
|
||||
"[FaceMatch-PG] Loaded {} traces, sampled {} embeddings (3-angle)",
|
||||
total_traces,
|
||||
sample_count
|
||||
);
|
||||
|
||||
// Step 3: 建立 TMDb 查找表
|
||||
let tmdb_seeds: Vec<(i32, String, Vec<f32>)> = tmdb_rows;
|
||||
|
||||
// Step 4: 迭代匹配
|
||||
const TH: f32 = 0.50;
|
||||
let mut matched: HashMap<i32, String> = HashMap::new(); // trace_id → identity_name
|
||||
|
||||
// Round 1: 用 3-angle samples 比對 TMDb
|
||||
for (&tid, samples) in &face_track_samples {
|
||||
let mut best_name = String::new();
|
||||
let mut best_sim = 0.0f32;
|
||||
for (_, ref name, ref tmdb_emb) in &tmdb_seeds {
|
||||
for face_emb in samples {
|
||||
let s = cosine_similarity(face_emb, tmdb_emb);
|
||||
if s > best_sim {
|
||||
best_sim = s;
|
||||
best_name = name.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
if best_sim >= TH {
|
||||
matched.insert(tid, best_name);
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
"[FaceMatch] Round 1: {} matched ({}%) — writing to DB",
|
||||
matched.len(),
|
||||
matched.len() * 100 / total_traces
|
||||
);
|
||||
|
||||
// Step 5: 寫入 DB — Round 1 結果先存 (Phase 3: update both face_detections AND tkg_nodes)
|
||||
let identities_table = schema::table_name("identities");
|
||||
let strangers_table = schema::table_name("strangers");
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let nodes_table = schema::table_name("tkg_nodes");
|
||||
let mut updated = 0usize;
|
||||
for (tid, name) in &matched {
|
||||
let id_opt = sqlx::query_scalar::<_, Option<i32>>(&format!(
|
||||
"SELECT id FROM {} WHERE name=$1 AND source='tmdb'",
|
||||
identities_table
|
||||
))
|
||||
.bind(name)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
if let Some(identity_id) = id_opt {
|
||||
let _ = sqlx::query(&format!(
|
||||
"UPDATE {} SET identity_id=$1 WHERE file_uuid=$2 AND trace_id=$3",
|
||||
fd_table
|
||||
))
|
||||
.bind(identity_id)
|
||||
.bind(file_uuid)
|
||||
.bind(tid)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Phase 3: Also update TKG node
|
||||
let external_id = format!("face_track_{}", tid);
|
||||
let _ = sqlx::query(&format!(
|
||||
"UPDATE {} SET properties = jsonb_set(\
|
||||
jsonb_set(properties, '{{identity_id}}', $1::jsonb, false),\
|
||||
'{{identity_name}}', $2::jsonb, false)\
|
||||
WHERE file_uuid = $3 AND node_type = 'face_track' AND external_id = $4",
|
||||
nodes_table
|
||||
))
|
||||
.bind(identity_id)
|
||||
.bind(name.as_str())
|
||||
.bind(file_uuid)
|
||||
.bind(&external_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
updated += 1;
|
||||
}
|
||||
}
|
||||
tracing::info!("[FaceMatch] Round 1: updated {} face_detections", updated);
|
||||
|
||||
// Round 2+: 用已匹配的 face 作為 seed 傳播(剩餘未匹配的 trace)
|
||||
let initial_matched = matched.len();
|
||||
for round_n in 2..=5 {
|
||||
let prev = matched.len();
|
||||
// 建立 seed pool: name → Vec<embedding>
|
||||
let mut seed_pool: HashMap<String, Vec<&Vec<f32>>> = HashMap::new();
|
||||
for (&tid, name) in &matched {
|
||||
if let Some(samples) = face_track_samples.get(&tid) {
|
||||
seed_pool
|
||||
.entry(name.clone())
|
||||
.or_default()
|
||||
.extend(samples.iter());
|
||||
}
|
||||
}
|
||||
|
||||
let mut new_matches: Vec<(i32, String)> = Vec::new();
|
||||
for (&tid, samples) in &face_track_samples {
|
||||
if matched.contains_key(&tid) {
|
||||
continue;
|
||||
}
|
||||
let mut best_name = String::new();
|
||||
let mut best_sim = 0.0f32;
|
||||
if samples.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// 用 3-angle samples 分別比對 seed,取最高 similarity
|
||||
for (name, seed_faces) in &seed_pool {
|
||||
for face_emb in samples {
|
||||
for seed in seed_faces {
|
||||
let s = cosine_similarity(face_emb, seed);
|
||||
if s > best_sim {
|
||||
best_sim = s;
|
||||
best_name = name.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if best_sim >= TH {
|
||||
new_matches.push((tid, best_name));
|
||||
}
|
||||
}
|
||||
for (tid, name) in new_matches {
|
||||
matched.insert(tid, name);
|
||||
}
|
||||
let new = matched.len() - prev;
|
||||
tracing::info!(
|
||||
"[FaceMatch] Round {}: +{} matched (total {}, {}%)",
|
||||
round_n,
|
||||
new,
|
||||
matched.len(),
|
||||
matched.len() * 100 / total_traces
|
||||
);
|
||||
if new < 5 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 6: 未匹配的 trace 設 stranger_id = strangers.id (FK)
|
||||
// First: ensure strangers records exist
|
||||
let _ = sqlx::query(&format!(
|
||||
"INSERT INTO {} (file_uuid, trace_id) \
|
||||
SELECT $1, fd.trace_id FROM {} fd \
|
||||
WHERE fd.file_uuid = $1 AND fd.trace_id IS NOT NULL \
|
||||
AND fd.identity_id IS NULL \
|
||||
ON CONFLICT (file_uuid, trace_id) DO NOTHING",
|
||||
strangers_table, fd_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Then: update face_detections.stranger_id = strangers.id
|
||||
let stranger_update = sqlx::query(&format!(
|
||||
"UPDATE {} fd SET stranger_id = s.id \
|
||||
FROM {} s \
|
||||
WHERE s.file_uuid = fd.file_uuid AND s.trace_id = fd.trace_id \
|
||||
AND fd.file_uuid = $1 AND fd.identity_id IS NULL \
|
||||
AND fd.trace_id IS NOT NULL AND fd.stranger_id IS NULL",
|
||||
fd_table, strangers_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
let stranger_count = stranger_update.rows_affected();
|
||||
|
||||
// Step 7: Save identity files for all affected identities
|
||||
let affected = sqlx::query_scalar::<_, uuid::Uuid>(&format!(
|
||||
"SELECT DISTINCT i.uuid FROM {} i \
|
||||
JOIN {} fd ON fd.identity_id = i.id \
|
||||
WHERE fd.file_uuid=$1 AND fd.identity_id IS NOT NULL",
|
||||
identities_table, fd_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
for uuid in &affected {
|
||||
let us = uuid.to_string().replace('-', "");
|
||||
if let Err(e) = crate::core::identity::storage::save_identity_file_by_pool(pool, &us).await
|
||||
{
|
||||
tracing::warn!("[FaceMatch] Failed to save identity file {}: {}", us, e);
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
"[FaceMatch] Done: {}/{} traces matched ({}%), {} strangers, {} identity files",
|
||||
matched.len(),
|
||||
total_traces,
|
||||
matched.len() * 100 / total_traces,
|
||||
stranger_count,
|
||||
affected.len()
|
||||
);
|
||||
Ok(updated)
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
/// Bind ASRX speakers to face traces based on temporal overlap.
|
||||
@@ -1589,126 +1013,9 @@ async fn run_identity_handler(
|
||||
|
||||
/// Read all TMDb identities with profile photos, extract face embeddings, store in Qdrant as seeds.
|
||||
pub async fn generate_seed_embeddings(db: &PostgresDb) -> anyhow::Result<usize> {
|
||||
use crate::core::db::face_embedding_db::FaceEmbeddingDb;
|
||||
use std::path::Path;
|
||||
|
||||
let pool = db.pool();
|
||||
let id_table = schema::table_name("identities");
|
||||
|
||||
let rows = sqlx::query_as::<_, (i32, String, String, i32, String)>(&format!(
|
||||
"SELECT id, name, uuid::text, tmdb_id, tmdb_profile FROM {} \
|
||||
WHERE source='tmdb' AND tmdb_profile IS NOT NULL",
|
||||
id_table
|
||||
))
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
if rows.is_empty() {
|
||||
tracing::warn!("[GenerateSeeds] No TMDb identities with profile photos");
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let scripts_dir = std::env::var("MOMENTRY_SCRIPTS_DIR")
|
||||
.unwrap_or_else(|_| "/Users/accusys/momentry_core_0.1/scripts".to_string());
|
||||
let python_path = std::env::var("MOMENTRY_PYTHON_PATH")
|
||||
.unwrap_or_else(|_| "/opt/homebrew/bin/python3.11".to_string());
|
||||
|
||||
let extract_script = Path::new(&scripts_dir).join("extract_face_embedding.py");
|
||||
let face_db = FaceEmbeddingDb::new();
|
||||
|
||||
let mut success = 0usize;
|
||||
for (id, name, uuid, tmdb_id, profile_url) in &rows {
|
||||
tracing::info!("[GenerateSeeds] Processing {} ({})", name, uuid);
|
||||
|
||||
// Download profile image
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new());
|
||||
let resp = client.get(profile_url).send().await;
|
||||
let image_bytes = match resp {
|
||||
Ok(r) if r.status().is_success() => r.bytes().await.unwrap_or_default(),
|
||||
_ => {
|
||||
tracing::warn!("[GenerateSeeds] Failed to download: {} from {}", name, profile_url);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if image_bytes.is_empty() {
|
||||
tracing::warn!("[GenerateSeeds] Empty image for {}", name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Save to temp file
|
||||
let temp_dir = std::env::temp_dir().join("momentry_seed_faces");
|
||||
std::fs::create_dir_all(&temp_dir)?;
|
||||
let temp_img = temp_dir.join(format!("{}.jpg", uuid));
|
||||
std::fs::write(&temp_img, &image_bytes)?;
|
||||
|
||||
// Extract embedding with timeout
|
||||
use tokio::time::timeout;
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(180),
|
||||
tokio::process::Command::new(&python_path)
|
||||
.arg(&extract_script)
|
||||
.arg(&temp_img)
|
||||
.output(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Extract embedding timed out for {}", name))??;
|
||||
|
||||
let _ = std::fs::remove_file(&temp_img);
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
tracing::warn!(
|
||||
"[GenerateSeeds] Extraction failed for {}: {}",
|
||||
name,
|
||||
stderr.trim()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let extract_result: serde_json::Value = match serde_json::from_str(&stdout) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!("[GenerateSeeds] Parse error for {}: {}", name, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let embedding: Vec<f64> = match serde_json::from_value(
|
||||
extract_result.get("embedding").ok_or_else(|| anyhow::anyhow!("No embedding"))?.clone(),
|
||||
) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!("[GenerateSeeds] Embedding format error for {}: {}", name, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let embedding_f32: Vec<f32> = embedding.into_iter().map(|v| v as f32).collect();
|
||||
|
||||
// Store in Qdrant
|
||||
match face_db
|
||||
.upsert_seed_embedding(uuid, name, *tmdb_id, &embedding_f32)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
success += 1;
|
||||
tracing::info!("[GenerateSeeds] Stored seed for {}", name);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[GenerateSeeds] Qdrant error for {}: {}", name, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[GenerateSeeds] Done: {}/{} seeds generated",
|
||||
success,
|
||||
rows.len()
|
||||
tracing::warn!(
|
||||
"[GenerateSeeds] Seed embedding generation disabled - FaceEmbeddingDb removed. \
|
||||
TODO: Reimplement with _faces collection"
|
||||
);
|
||||
Ok(success)
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user