730 lines
28 KiB
Python
Executable File
730 lines
28 KiB
Python
Executable File
"""
|
||
SelfASRXFixed - 7 步 Hybrid Speaker Diarization Pipeline
|
||
|
||
Pipeline:
|
||
1. whisper.transcribe(full_audio) → rough segments + text + language
|
||
2. VAD scan each rough segment → refined segments
|
||
3. whisper per refined segment → {text, language, lang_prob}
|
||
4. ECAPA-TDNN per refined segment → 192-dim embeddings
|
||
5. AgglomerativeClustering → speaker_labels
|
||
6. Store all embeddings in Qdrant (payload: file_uuid, speaker_id, text, ...)
|
||
7. High-quality embeddings → gender classify + store reference in Qdrant
|
||
"""
|
||
|
||
import sys
|
||
import json
|
||
import time
|
||
import os
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from urllib.request import Request, urlopen
|
||
from urllib.error import URLError
|
||
|
||
|
||
def _load_audio(path):
    """Read an audio file and return ``(samples, sample_rate)``.

    Multi-channel input is down-mixed to mono by averaging along axis 1.
    """
    import soundfile as sf

    samples, rate = sf.read(path)
    if samples.ndim > 1:
        # Average the channels into a single mono track.
        samples = samples.mean(axis=1)
    return samples, rate


def _load_whisper_model(size="small"):
    """Load and return the project-local whisper model of the given ``size``."""
    from whisper_local import load_model as load_whisper

    model = load_whisper(size)
    return model


def _load_vad():
    """Load the VAD model via the project-local ``vad`` module.

    Returns whatever ``vad.load_vad_model()`` returns; the pipeline unpacks
    it as ``(model, utils)``.
    """
    from vad import load_vad_model as build_vad

    vad_bundle = build_vad()
    return vad_bundle


def _load_speaker_encoder():
    """Load the speaker-embedding encoder from the project-local module."""
    from speaker_encoder import load_speaker_encoder as build_encoder

    encoder = build_encoder()
    return encoder


def _load_gender_classifier():
|
||
try:
|
||
from speechbrain.inference.classifiers import EncoderClassifier
|
||
classifier = EncoderClassifier.from_hparams(
|
||
source="speechbrain/gender-recognition-ecapa",
|
||
run_opts={"device": "cpu"},
|
||
)
|
||
print("[Gender] Classifier loaded: speechbrain/gender-recognition-ecapa")
|
||
return classifier
|
||
except Exception as e:
|
||
print(f"[Gender] Classifier not available: {e}")
|
||
return None
|
||
|
||
|
||
def _ensure_speaker_collection(qdrant_url, api_key, collection):
|
||
"""確認 Qdrant speaker collection 存在,不存在則建立 (dim=192, cosine)"""
|
||
try:
|
||
url = f"{qdrant_url}/collections/{collection}"
|
||
req = Request(url, method="GET",
|
||
headers={"api-key": api_key} if api_key else {})
|
||
try:
|
||
urlopen(req)
|
||
return True
|
||
except URLError as e:
|
||
if getattr(e, "code", None) == 404:
|
||
body = json.dumps({
|
||
"vectors": {
|
||
"size": 192,
|
||
"distance": "Cosine"
|
||
}
|
||
}).encode()
|
||
req = Request(url, data=body, method="PUT",
|
||
headers={"Content-Type": "application/json",
|
||
**({"api-key": api_key} if api_key else {})})
|
||
urlopen(req)
|
||
print(f"[Qdrant] Created collection: {collection} (dim=192)")
|
||
return True
|
||
raise
|
||
except Exception as e:
|
||
print(f"[Qdrant] Cannot access Qdrant: {e}")
|
||
return False
|
||
|
||
|
||
def _qdrant_upsert(qdrant_url, api_key, collection, points):
|
||
"""批量寫入 Qdrant points"""
|
||
try:
|
||
url = f"{qdrant_url}/collections/{collection}/points?wait=true"
|
||
body = json.dumps({"points": points}).encode()
|
||
headers = {"Content-Type": "application/json"}
|
||
if api_key:
|
||
headers["api-key"] = api_key
|
||
req = Request(url, data=body, headers=headers, method="PUT")
|
||
urlopen(req)
|
||
return True
|
||
except Exception as e:
|
||
print(f"[Qdrant] Upsert failed: {e}")
|
||
return False
|
||
|
||
|
||
def _hash_point_id(file_uuid, label):
|
||
"""產生一致的 point ID"""
|
||
s = f"{file_uuid}_{label}"
|
||
return hash(s) & 0x7FFFFFFFFFFFFFFF
|
||
|
||
|
||
def _save_checkpoint(path: str, data: dict):
|
||
"""原子寫入 checkpoint(先 .tmp 再 rename)"""
|
||
tmp = path + ".tmp"
|
||
Path(tmp).parent.mkdir(parents=True, exist_ok=True)
|
||
with open(tmp, "w", encoding="utf-8") as f:
|
||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||
os.replace(tmp, path)
|
||
|
||
|
||
def compute_embedding_quality(embeddings, labels):
    """Cosine similarity of each embedding to its own cluster's centroid.

    The centroid of a cluster is the mean of its (raw) embeddings, L2-normalized.
    Zero-length embeddings or centroids yield a similarity of 0.

    Args:
        embeddings: Array-like of shape (n, dim).
        labels: Array-like of n cluster labels (a list works too; it is
            converted so ``labels == label`` is element-wise).

    Returns:
        float ndarray of shape (n,).
    """
    embeddings = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    qualities = np.zeros(len(labels), dtype=float)
    if len(labels) == 0:
        return qualities

    # Unit-normalize rows once; zero rows stay zero (cosine defined as 0).
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = np.divide(embeddings, norms, out=np.zeros_like(embeddings),
                     where=norms > 0)

    for label in np.unique(labels):
        mask = labels == label
        centroid = embeddings[mask].mean(axis=0)
        norm = np.linalg.norm(centroid)
        if norm > 0:
            centroid = centroid / norm
        # One matrix-vector product per cluster instead of one call per row.
        qualities[mask] = unit[mask] @ centroid
    return qualities


class SelfASRXFixed:
    """7-step hybrid speaker diarization pipeline.

    ``process`` runs all seven steps on an audio file. ``resume_from_checkpoint``
    reloads the Step 1-3 results from a checkpoint written by ``process`` and
    runs only Steps 4-7. Both share the Step 4-7 implementation in
    ``_run_speaker_steps``.
    """

    # Step 3 writes a partial checkpoint every this many transcribed segments.
    CHECKPOINT_INTERVAL = 50
    # Only checkpoints carrying this version are reused.
    CHECKPOINT_VERSION = 2
    # Frames per second used to derive start_frame / end_frame from seconds.
    FRAME_RATE = 30

    def __init__(self):
        """Load all models and read the Qdrant settings from the environment.

        Environment variables: QDRANT_URL (default http://localhost:6333),
        QDRANT_API_KEY (default empty), DATABASE_SCHEMA (default "public") and
        QDRANT_SPEAKER_COLLECTION (default f"momentry_{schema}_speaker").
        """
        print("[SelfASRX] Initializing models...")

        print("[SelfASRX] Loading whisper model...")
        self.whisper = _load_whisper_model("small")

        print("[SelfASRX] Loading VAD model (Silero)...")
        self.vad_model, self.vad_utils = _load_vad()

        print("[SelfASRX] Loading speaker encoder (ECAPA-TDNN)...")
        self.speaker_encoder = _load_speaker_encoder()

        print("[SelfASRX] Loading gender classifier...")
        self.gender_classifier = _load_gender_classifier()

        # Qdrant settings
        self.qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333")
        self.qdrant_api_key = os.environ.get("QDRANT_API_KEY", "")
        schema = os.environ.get("DATABASE_SCHEMA", "public")
        self.qdrant_collection = os.environ.get(
            "QDRANT_SPEAKER_COLLECTION",
            f"momentry_{schema}_speaker"
        )
        # Cached outcome of _ensure_qdrant(); only success is cached.
        self._qdrant_ok = False

        print("[SelfASRX] Models loaded successfully")

    def process(self, audio_path, output_path=None, file_uuid=None,
                max_speakers=10, quality_threshold=0.85,
                checkpoint_path=None):
        """Run the full 7-step speaker diarization pipeline.

        Args:
            audio_path: Audio file path (expected WAV 16 kHz mono).
            output_path: Optional path; the result dict is written there as JSON.
            file_uuid: File UUID; when set, Steps 6-7 store embeddings and
                speaker references in Qdrant.
            max_speakers: Upper bound on the number of speakers for clustering.
            quality_threshold: Minimum centroid cosine similarity (0-1) for an
                embedding to count as high quality in Step 7.
            checkpoint_path: Optional Step 3 checkpoint path. An unfinished
                checkpoint whose refined segmentation matches this run is
                resumed. A partial checkpoint is written every
                CHECKPOINT_INTERVAL segments, and a complete one after Step 3.

        Returns:
            dict with language, segments, n_speakers, speaker_stats,
            total_duration, n_segments, processing_time, realtime_factor and
            (when any were found) references. When whisper or VAD yields no
            speech, returns ``{"error": ..., "segments": []}`` instead.
        """
        start_time = time.time()
        print(f"\n[SelfASRX] Processing: {audio_path}")
        print("=" * 60)

        # Load audio
        wav, sample_rate = _load_audio(audio_path)
        total_duration = len(wav) / sample_rate
        print(f" Audio: {total_duration:.2f}s, {sample_rate}Hz")

        # ── Step 1: rough whisper pass over the whole file ──
        print("\n[Step 1] Initial whisper transcription...")
        t1 = time.time()
        seg_gen, info = self.whisper.transcribe(audio_path)
        rough_segments = [
            {"start": seg.start, "end": seg.end, "text": seg.text}
            for seg in seg_gen
        ]
        language = info.language if info else None
        print(f" Rough segments: {len(rough_segments)}")
        print(f" Language: {language}")
        print(f" Step 1 time: {time.time() - t1:.2f}s")

        if not rough_segments:
            print("[SelfASRX] No speech detected by whisper!")
            return {"error": "No speech detected", "segments": []}

        # ── Step 2: VAD refinement of each rough segment ──
        print("\n[Step 2] VAD scan for refined segmentation...")
        t2 = time.time()
        refined_segments = []
        for seg in rough_segments:
            s, e = seg["start"], seg["end"]
            sub = self._vad_scan_segment(wav, sample_rate, s, e)
            # Keep the rough span when VAD finds nothing inside it.
            refined_segments.extend(sub if sub else [(s, e)])
        print(f" Refined segments: {len(refined_segments)}")
        print(f" Step 2 time: {time.time() - t2:.2f}s")

        if not refined_segments:
            return {"error": "No segments after VAD scan", "segments": []}

        # ── Step 3: whisper per refined segment (checkpointed) ──
        print("\n[Step 3] Per-segment transcription...")
        t3 = time.time()
        segment_texts = self._load_partial_step3(checkpoint_path,
                                                 refined_segments)
        resume_from = len(segment_texts)
        if resume_from:
            print(f"[Step 3] Resuming from #{resume_from}/{len(refined_segments)}")

        checkpoint_meta = {
            "language": language,
            "total_duration": total_duration,
            "refined_segments": refined_segments,
            "file_uuid": file_uuid,
            "max_speakers": max_speakers,
            "quality_threshold": quality_threshold,
        }
        for i in range(resume_from, len(refined_segments)):
            start_sec, end_sec = refined_segments[i]
            segment_texts.append(
                self._transcribe_segment(wav, sample_rate, start_sec, end_sec))
            if checkpoint_path and (i + 1) % self.CHECKPOINT_INTERVAL == 0:
                _save_checkpoint(checkpoint_path, self._step3_checkpoint_payload(
                    completed=False, progress=i + 1,
                    segment_texts=segment_texts, **checkpoint_meta))
                print(f"[Checkpoint] Step 3: {i+1}/{len(refined_segments)}")

        print(f" Step 3 time: {time.time() - t3:.2f}s")

        # ── Save final checkpoint after Step 3 ──
        if checkpoint_path:
            _save_checkpoint(checkpoint_path, self._step3_checkpoint_payload(
                completed=True, segment_texts=segment_texts, **checkpoint_meta))
            print(f"[Checkpoint] Step 3 complete, saved to {checkpoint_path}")

        return self._run_speaker_steps(
            wav, sample_rate, refined_segments, segment_texts, language,
            total_duration, file_uuid, max_speakers, quality_threshold,
            output_path, start_time)

    def resume_from_checkpoint(self, checkpoint_path, audio_path,
                               output_path=None):
        """Load Step 1-3 results from a completed checkpoint and run Steps 4-7.

        file_uuid, max_speakers and quality_threshold come from the checkpoint.
        Returns the same result dict as ``process``. Returns
        ``{"error": ..., "segments": []}`` when Step 3 was not completed, or
        when the checkpoint's segment and text counts disagree.
        """
        print(f"\n[SelfASRX] Resuming from checkpoint: {checkpoint_path}")
        print("=" * 60)

        with open(checkpoint_path, "r", encoding="utf-8") as f:
            cp = json.load(f)

        if not cp.get("step3_completed"):
            error_msg = f"Checkpoint step3 not completed (progress: {cp.get('step3_progress', '?')})"
            print(f"[SelfASRX] {error_msg}")
            return {"error": error_msg, "segments": []}

        refined_segments = [tuple(s) for s in cp["refined_segments"]]
        segment_texts = cp["segment_texts"]
        if len(segment_texts) != len(refined_segments):
            # A truncated/corrupted checkpoint would otherwise raise IndexError
            # or silently drop segments later.
            error_msg = (f"Checkpoint inconsistent: {len(refined_segments)} "
                         f"segments but {len(segment_texts)} transcriptions")
            print(f"[SelfASRX] {error_msg}")
            return {"error": error_msg, "segments": []}

        wav, sample_rate = _load_audio(audio_path)
        language = cp.get("language", "")
        total_duration = cp.get("total_duration", 0)
        file_uuid = cp.get("file_uuid")
        max_speakers = cp.get("max_speakers", 10)
        quality_threshold = cp.get("quality_threshold", 0.85)

        print(f" Loaded checkpoint: {len(refined_segments)} segments, "
              f"language={language}, duration={total_duration:.2f}s")

        start_time = time.time()
        return self._run_speaker_steps(
            wav, sample_rate, refined_segments, segment_texts, language,
            total_duration, file_uuid, max_speakers, quality_threshold,
            output_path, start_time)

    # ── Internal helpers ──

    def _run_speaker_steps(self, wav, sample_rate, refined_segments,
                           segment_texts, language, total_duration, file_uuid,
                           max_speakers, quality_threshold, output_path,
                           start_time):
        """Run Steps 4-7 on transcribed refined segments and build the result.

        ``start_time`` is the ``time.time()`` value that processing_time and
        realtime_factor are measured from.
        """
        # ── Step 4: ECAPA-TDNN per refined segment ──
        print("\n[Step 4] Speaker embedding extraction...")
        t4 = time.time()
        embeddings = self._extract_embeddings(wav, sample_rate, refined_segments)
        print(f" Embeddings: {embeddings.shape}")
        print(f" Step 4 time: {time.time() - t4:.2f}s")

        # ── Step 5: AgglomerativeClustering ──
        print("\n[Step 5] Speaker clustering...")
        t5 = time.time()
        from speaker_cluster_fixed import robust_speaker_clustering
        speaker_labels, estimated_n_speakers = robust_speaker_clustering(
            embeddings, n_speakers=None, max_speakers=max_speakers
        )
        # Downstream masks use `labels == label`, which needs an ndarray.
        speaker_labels = np.asarray(speaker_labels)
        print(f" Speakers: {estimated_n_speakers}")
        print(f" Step 5 time: {time.time() - t5:.2f}s")

        qualities = compute_embedding_quality(embeddings, speaker_labels)
        segments = self._build_segments(refined_segments, segment_texts,
                                        speaker_labels, qualities)
        result = {
            "language": language or "",
            "segments": segments,
            "n_speakers": int(estimated_n_speakers),
            "speaker_stats": self._compute_speaker_stats(segments),
            "total_duration": total_duration,
            "n_segments": len(segments),
        }

        if file_uuid:
            # ── Step 6: Store embeddings in Qdrant ──
            print("\n[Step 6] Storing embeddings in Qdrant...")
            t6 = time.time()
            self._store_speaker_embeddings(segments, embeddings, speaker_labels,
                                           file_uuid)
            print(f" Step 6 time: {time.time() - t6:.2f}s")

            # ── Step 7: High-quality classification ──
            print("\n[Step 7] Classifying high-quality embeddings...")
            t7 = time.time()
            references = self._classify_high_quality_speakers(
                segments, embeddings, speaker_labels, file_uuid,
                wav, sample_rate, quality_threshold, qualities=qualities
            )
            if references:
                result["references"] = references
            print(f" Step 7 time: {time.time() - t7:.2f}s")

        total_time = time.time() - start_time
        result["processing_time"] = round(total_time, 2)
        if total_duration > 0 and total_time > 0:
            result["realtime_factor"] = round(total_duration / total_time, 2)

        # Save output
        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=2, ensure_ascii=False)
            print(f"\n[SelfASRX] Saved to: {output_path}")

        print(f"\n[SelfASRX] Done! {len(segments)} segments, "
              f"{estimated_n_speakers} speakers, "
              f"{total_time:.2f}s")
        return result

    def _load_partial_step3(self, checkpoint_path, refined_segments):
        """Return Step 3 texts saved by an interrupted run, or [] to start fresh.

        Only an unfinished, version-matching checkpoint qualifies. Its
        refined_segments must equal the current segmentation, so texts from a
        stale checkpoint (another file or segmentation) are never mixed in.
        """
        if not checkpoint_path or not os.path.exists(checkpoint_path):
            return []
        try:
            # _save_checkpoint writes UTF-8 with ensure_ascii=False, so the
            # encoding must be explicit (the platform default may differ).
            with open(checkpoint_path, "r", encoding="utf-8") as f:
                cp = json.load(f)
        except (OSError, ValueError) as e:
            print(f"[Checkpoint] Ignoring unreadable checkpoint {checkpoint_path}: {e}")
            return []
        if (not isinstance(cp, dict)
                or cp.get("checkpoint_version") != self.CHECKPOINT_VERSION
                or cp.get("step3_completed")):
            return []
        saved = cp.get("segment_texts") or []
        if not saved:
            return []
        if (len(saved) > len(refined_segments)
                or not self._segments_match(cp.get("refined_segments"),
                                            refined_segments)):
            print("[Checkpoint] Segmentation differs from checkpoint; restarting Step 3")
            return []
        return list(saved)

    @staticmethod
    def _segments_match(saved, current, tol=1e-6):
        """True when checkpoint segments ``saved`` equal ``current`` within ``tol`` s."""
        if not isinstance(saved, list) or len(saved) != len(current):
            return False
        try:
            return all(
                len(old) == 2
                and abs(float(old[0]) - float(new[0])) <= tol
                and abs(float(old[1]) - float(new[1])) <= tol
                for old, new in zip(saved, current)
            )
        except (TypeError, ValueError):
            return False

    @classmethod
    def _step3_checkpoint_payload(cls, *, completed, segment_texts, language,
                                  total_duration, refined_segments, file_uuid,
                                  max_speakers, quality_threshold,
                                  progress=None):
        """Build the JSON-serializable Step 3 checkpoint dict.

        ``step3_progress`` is included only when ``progress`` is given (partial
        checkpoints). Each segment text is reduced to text/language/lang_prob.
        """
        payload = {
            "checkpoint_version": cls.CHECKPOINT_VERSION,
            "step3_completed": completed,
        }
        if progress is not None:
            payload["step3_progress"] = progress
        payload.update({
            "language": language,
            "total_duration": total_duration,
            "refined_segments": [[s, e] for s, e in refined_segments],
            "segment_texts": [{
                "text": st["text"],
                "language": st["language"],
                "lang_prob": st["lang_prob"],
            } for st in segment_texts],
            "file_uuid": file_uuid,
            "max_speakers": max_speakers,
            "quality_threshold": quality_threshold,
        })
        return payload

    @staticmethod
    def _slice_audio(wav, sample_rate, start_sec, end_sec):
        """Return the samples of ``wav`` between ``start_sec`` and ``end_sec``."""
        start = int(start_sec * sample_rate)
        end = int(end_sec * sample_rate)
        return wav[start:min(end, len(wav))]

    def _extract_embeddings(self, wav, sample_rate, refined_segments):
        """Step 4: one normalized speaker embedding per refined segment."""
        from speaker_encoder import extract_speaker_embeddings_batch, normalize_embeddings
        audio_segments = [self._slice_audio(wav, sample_rate, s, e)
                          for s, e in refined_segments]
        embeddings = extract_speaker_embeddings_batch(
            self.speaker_encoder, audio_segments, sample_rate
        )
        return normalize_embeddings(embeddings)

    @classmethod
    def _build_segments(cls, refined_segments, segment_texts, speaker_labels,
                        qualities):
        """Assemble the output segment dicts (times rounded to ms)."""
        segments = []
        for i, ((start_sec, end_sec), label) in enumerate(
                zip(refined_segments, speaker_labels)):
            text_info = segment_texts[i]
            speaker = f"SPEAKER_{int(label)}"
            segments.append({
                "start": round(start_sec, 3),
                "end": round(end_sec, 3),
                "start_frame": int(start_sec * cls.FRAME_RATE),
                "end_frame": int(end_sec * cls.FRAME_RATE),
                "text": text_info["text"],
                "language": text_info["language"],
                "lang_prob": text_info["lang_prob"],
                "speaker": speaker,
                "speaker_id": speaker,
                "quality": float(qualities[i]),
            })
        return segments

    @staticmethod
    def _compute_speaker_stats(segments):
        """Per speaker_id: number of segments and total (rounded) duration."""
        stats = {}
        for seg in segments:
            entry = stats.setdefault(seg["speaker_id"],
                                     {"count": 0, "duration": 0})
            entry["count"] += 1
            entry["duration"] += seg["end"] - seg["start"]
        return stats

    def _vad_scan_segment(self, wav, sample_rate, start_sec, end_sec):
        """Split one rough segment into VAD sub-segments (project ``vad`` module)."""
        from vad import scan_within_segment
        return scan_within_segment(
            wav, sample_rate, start_sec, end_sec,
            self.vad_model, self.vad_utils
        )

    def _transcribe_segment(self, wav, sample_rate, start_sec, end_sec):
        """Transcribe one refined segment with the loaded whisper model."""
        from whisper_local import transcribe_segment
        return transcribe_segment(wav, sample_rate, start_sec, end_sec, self.whisper)

    def _store_speaker_embeddings(self, segments, embeddings, labels, file_uuid,
                                  batch_size=256):
        """Step 6: upsert one Qdrant point per segment embedding.

        ``labels`` is accepted for interface compatibility; the speaker ID is
        taken from each segment. Points are sent in batches of ``batch_size``,
        so long recordings stay well below Qdrant's request-size limit.

        Returns True when every batch was stored. Returns False when Qdrant is
        unavailable or any batch failed.
        """
        if not self._ensure_qdrant():
            return False

        points = [
            {
                "id": _hash_point_id(file_uuid, f"{i}"),
                "vector": emb.tolist(),
                "payload": {
                    "type": "speaker_embedding",
                    "file_uuid": file_uuid,
                    "speaker_id": seg["speaker_id"],
                    "text": seg["text"],
                    "language": seg["language"],
                    "start_time": seg["start"],
                    "end_time": seg["end"],
                },
            }
            for i, (seg, emb) in enumerate(zip(segments, embeddings))
        ]

        stored = 0
        for start in range(0, len(points), batch_size):
            batch = points[start:start + batch_size]
            if _qdrant_upsert(self.qdrant_url, self.qdrant_api_key,
                              self.qdrant_collection, batch):
                stored += len(batch)
        if stored:
            print(f" Stored {stored} speaker embeddings to Qdrant")
        return stored == len(points)

    def _classify_gender(self, wav, sample_rate, seg):
        """Return ``(gender, confidence)`` for one segment's audio.

        Gives ("unknown", 0.0) when no classifier is loaded, when it fails, or
        when it yields fewer than two classes.
        NOTE(review): index 0 is assumed to be "male" and index 1 "female".
        Confirm against the model's label encoder.
        """
        if self.gender_classifier is None:
            return "unknown", 0.0
        try:
            import torch
            seg_wav = self._slice_audio(wav, sample_rate, seg["start"], seg["end"])
            seg_tensor = torch.from_numpy(seg_wav).float().unsqueeze(0)
            out = self.gender_classifier.classify_batch(seg_tensor)
            probs = torch.softmax(out[0], dim=-1).squeeze().cpu().detach().numpy()
            probs = np.atleast_1d(probs)
            if len(probs) >= 2:
                idx = int(np.argmax(probs))
                return ("male" if idx == 0 else "female"), float(probs[idx])
            return "unknown", 0.0
        except Exception as e:
            print(f"[Gender] Classify error: {e}")
            return "unknown", 0.0

    def _classify_high_quality_speakers(self, segments, embeddings, labels,
                                        file_uuid, wav, sample_rate,
                                        threshold=0.85, qualities=None):
        """Step 7: build one Qdrant reference per speaker from good embeddings.

        In each cluster, embeddings with quality >= ``threshold`` are averaged.
        Quality is the cosine similarity to the cluster centroid. The average
        is L2-normalized to form the reference vector. The highest-quality
        segment is gender-classified. When Qdrant is available, the reference
        is upserted with type "speaker_reference".

        Args:
            qualities: Precomputed per-embedding qualities; computed from
                ``embeddings``/``labels`` when None.

        Returns:
            List of ``{speaker_id, n_segments, avg_quality, gender}`` dicts,
            ordered by cluster label.
        """
        labels = np.asarray(labels)
        embeddings = np.asarray(embeddings)
        if qualities is None:
            qualities = compute_embedding_quality(embeddings, labels)
        qualities = np.asarray(qualities)
        high_mask = qualities >= threshold

        if not np.any(high_mask):
            print(" No high-quality embeddings found")
            return []

        can_store = self._ensure_qdrant()
        references = []
        for label in np.unique(labels):
            high_indices = np.flatnonzero((labels == label) & high_mask)
            if high_indices.size == 0:
                continue
            high_segs = [segments[i] for i in high_indices]
            # Highest-quality segment of this speaker.
            best_seg = segments[high_indices[int(np.argmax(qualities[high_indices]))]]

            centroid = embeddings[high_indices].mean(axis=0)
            norm = np.linalg.norm(centroid)
            if norm > 0:
                centroid = centroid / norm

            n_high = int(high_indices.size)
            avg_quality = float(np.mean(qualities[high_indices]))
            speaker_id = f"SPEAKER_{int(label)}"
            gender, gender_conf = self._classify_gender(wav, sample_rate, best_seg)

            ref_payload = {
                "type": "speaker_reference",
                "file_uuid": file_uuid,
                "speaker_id": speaker_id,
                "n_segments": n_high,
                "avg_quality": avg_quality,
                "total_duration": round(sum(s["end"] - s["start"] for s in high_segs), 2),
                "language": best_seg.get("language", ""),
                "text_samples": [s["text"] for s in high_segs[:5] if s["text"]],
                "gender": gender,
                "gender_conf": gender_conf,
            }
            if can_store:
                _qdrant_upsert(self.qdrant_url, self.qdrant_api_key,
                               self.qdrant_collection, [{
                                   "id": _hash_point_id(file_uuid, f"ref_{label}"),
                                   "vector": centroid.tolist(),
                                   "payload": ref_payload,
                               }])

            references.append({
                "speaker_id": speaker_id,
                "n_segments": n_high,
                "avg_quality": avg_quality,
                "gender": gender,
            })
            print(f" Ref: {speaker_id}, gender={gender}"
                  f" ({gender_conf:.2f}), q={avg_quality:.3f}")

        return references

    def _ensure_qdrant(self):
        """Return True if the Qdrant collection is usable; caches success."""
        if not self._qdrant_ok:
            self._qdrant_ok = _ensure_speaker_collection(
                self.qdrant_url, self.qdrant_api_key, self.qdrant_collection
            )
        return self._qdrant_ok


def main():
    """Command-line entry point for the SelfASRX diarization pipeline.

    Runs the full pipeline on ``audio_path``, or with ``--resume`` runs Steps
    4-7 from a Step 3 checkpoint. Then prints a short summary. Exits with
    status 1 when the checkpoint or the audio file does not exist.

    When resuming, file UUID, max speakers and quality threshold are read
    from the checkpoint; the corresponding CLI options are not used.
    """
    import argparse

    parser = argparse.ArgumentParser(description="SelfASRX - Hybrid Speaker Diarization")
    parser.add_argument("audio_path", help="Path to audio file (WAV)")
    parser.add_argument("-o", "--output", help="Output JSON path")
    parser.add_argument("--file-uuid", help="File UUID for Qdrant storage")
    parser.add_argument("--max-speakers", type=int, default=10)
    parser.add_argument("--quality-threshold", type=float, default=0.85)
    parser.add_argument("--resume", help="Checkpoint path to resume from")
    parser.add_argument("--checkpoint", help="Save checkpoint path after Step 3")
    args = parser.parse_args()

    # Validate inputs before SelfASRXFixed() loads every model (slow).
    # Resuming also re-reads the audio, so the audio path is checked in both modes.
    if args.resume and not Path(args.resume).exists():
        print(f"Error: Checkpoint not found: {args.resume}")
        sys.exit(1)
    if not Path(args.audio_path).exists():
        print(f"Error: Audio file not found: {args.audio_path}")
        sys.exit(1)

    asrx = SelfASRXFixed()

    if args.resume:
        result = asrx.resume_from_checkpoint(
            args.resume, args.audio_path,
            output_path=args.output,
        )
    else:
        result = asrx.process(
            args.audio_path,
            output_path=args.output,
            file_uuid=args.file_uuid,
            max_speakers=args.max_speakers,
            quality_threshold=args.quality_threshold,
            checkpoint_path=args.checkpoint,
        )

    if "error" in result:
        return

    print("\n[Summary]")
    print(f" Duration: {result['total_duration']:.2f}s")
    print(f" Segments: {result['n_segments']}")
    print(f" Speakers: {result['n_speakers']}")
    for ref in result.get("references", []):
        print(f" {ref['speaker_id']}: gender={ref['gender']}, "
              f"quality={ref['avg_quality']:.3f}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()