Files
momentry_core/scripts/asrx_self/main_fixed.py

730 lines
28 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
SelfASRXFixed - 7-step hybrid speaker diarization pipeline.

Pipeline:
1. whisper.transcribe(full_audio) → rough segments + text + language
2. VAD scan each rough segment → refined segments
3. whisper per refined segment → {text, language, lang_prob}
4. ECAPA-TDNN per refined segment → 192-dim embeddings
5. AgglomerativeClustering → speaker_labels
6. Store all embeddings in Qdrant (payload: file_uuid, speaker_id, text, ...)
7. High-quality embeddings → gender classify + store reference in Qdrant
"""
import sys
import json
import time
import os
import numpy as np
from pathlib import Path
from urllib.request import Request, urlopen
from urllib.error import URLError
def _load_audio(path):
    """Read an audio file and return ``(samples, sample_rate)``.

    The file is decoded with ``soundfile.read``. When the decoded array has
    more than one dimension it is collapsed to a single channel by averaging
    along axis 1; one-dimensional audio is returned exactly as read.
    """
    import soundfile as sf

    samples, rate = sf.read(path)
    if samples.ndim > 1:
        # Down-mix multi-channel audio by averaging the channels.
        samples = np.mean(samples, axis=1)
    return samples, rate
def _load_whisper_model(size="small"):
    """Return the whisper model built by ``whisper_local.load_model(size)``.

    The import is deferred to call time so the module can be imported without
    the whisper backend being available.
    """
    import whisper_local

    return whisper_local.load_model(size)
def _load_vad():
    """Return whatever ``vad.load_vad_model()`` produces.

    The caller in this file unpacks the result as ``(model, utils)``.
    """
    import vad

    return vad.load_vad_model()
def _load_speaker_encoder():
    """Return the encoder built by ``speaker_encoder.load_speaker_encoder()``."""
    import speaker_encoder

    return speaker_encoder.load_speaker_encoder()
def _load_gender_classifier():
try:
from speechbrain.inference.classifiers import EncoderClassifier
classifier = EncoderClassifier.from_hparams(
source="speechbrain/gender-recognition-ecapa",
run_opts={"device": "cpu"},
)
print("[Gender] Classifier loaded: speechbrain/gender-recognition-ecapa")
return classifier
except Exception as e:
print(f"[Gender] Classifier not available: {e}")
return None
def _ensure_speaker_collection(qdrant_url, api_key, collection):
"""確認 Qdrant speaker collection 存在,不存在則建立 (dim=192, cosine)"""
try:
url = f"{qdrant_url}/collections/{collection}"
req = Request(url, method="GET",
headers={"api-key": api_key} if api_key else {})
try:
urlopen(req)
return True
except URLError as e:
if getattr(e, "code", None) == 404:
body = json.dumps({
"vectors": {
"size": 192,
"distance": "Cosine"
}
}).encode()
req = Request(url, data=body, method="PUT",
headers={"Content-Type": "application/json",
**({"api-key": api_key} if api_key else {})})
urlopen(req)
print(f"[Qdrant] Created collection: {collection} (dim=192)")
return True
raise
except Exception as e:
print(f"[Qdrant] Cannot access Qdrant: {e}")
return False
def _qdrant_upsert(qdrant_url, api_key, collection, points):
"""批量寫入 Qdrant points"""
try:
url = f"{qdrant_url}/collections/{collection}/points?wait=true"
body = json.dumps({"points": points}).encode()
headers = {"Content-Type": "application/json"}
if api_key:
headers["api-key"] = api_key
req = Request(url, data=body, headers=headers, method="PUT")
urlopen(req)
return True
except Exception as e:
print(f"[Qdrant] Upsert failed: {e}")
return False
def _hash_point_id(file_uuid, label):
"""產生一致的 point ID"""
s = f"{file_uuid}_{label}"
return hash(s) & 0x7FFFFFFFFFFFFFFF
def _save_checkpoint(path: str, data: dict):
"""原子寫入 checkpoint先 .tmp 再 rename"""
tmp = path + ".tmp"
Path(tmp).parent.mkdir(parents=True, exist_ok=True)
with open(tmp, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
os.replace(tmp, path)
def compute_embedding_quality(embeddings, labels):
    """Cosine similarity of each embedding to its own cluster centroid.

    For every distinct label the centroid is the mean of that cluster's
    embeddings, scaled to unit length when its norm is non-zero. The quality
    of an embedding is the cosine similarity between it and the centroid of
    the cluster it belongs to.

    Args:
        embeddings: Array-like of shape (n, dim).
        labels: Array-like of n cluster labels (list or NumPy array).

    Returns:
        1-D float NumPy array of length n. An all-zero embedding or an
        all-zero centroid yields a similarity of 0. Empty input yields an
        empty array.
    """
    # Accept plain lists too: `labels == label` only builds a boolean mask
    # when `labels` is an ndarray.
    embeddings = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    if embeddings.shape[0] == 0:
        return np.array([])

    # Unit-normalize every embedding once; zero vectors are left as zeros.
    emb_norms = np.linalg.norm(embeddings, axis=1)
    unit = embeddings / np.where(emb_norms > 0, emb_norms, 1.0)[:, None]

    qualities = np.zeros(len(labels), dtype=float)
    for label in np.unique(labels):
        mask = labels == label
        centroid = embeddings[mask].mean(axis=0)
        norm = np.linalg.norm(centroid)
        if norm > 0:
            centroid = centroid / norm
        # Both sides are unit length (or zero), so the dot product is the
        # cosine similarity for the whole cluster at once.
        qualities[mask] = unit[mask] @ centroid
    return qualities
class SelfASRXFixed:
    """Seven-step hybrid speaker diarization pipeline.

    Steps 1-3 (rough transcription, VAD refinement, per-segment
    transcription) are implemented in :meth:`process`. Steps 4-7 (speaker
    embeddings, clustering, Qdrant storage, high-quality reference
    classification) are shared by :meth:`process` and
    :meth:`resume_from_checkpoint` through :meth:`_run_steps_4_to_7`.
    """

    # Frames per second used to derive start_frame / end_frame from seconds.
    FRAME_RATE = 30
    # Number of segments transcribed between partial Step 3 checkpoints.
    CHECKPOINT_INTERVAL = 50
    # Only partial checkpoints carrying this version are resumed in Step 3.
    CHECKPOINT_VERSION = 2

    def __init__(self):
        """Load all models and read the Qdrant settings from the environment.

        Environment variables read: ``QDRANT_URL``, ``QDRANT_API_KEY``,
        ``DATABASE_SCHEMA`` and ``QDRANT_SPEAKER_COLLECTION``.
        """
        print("[SelfASRX] Initializing models...")
        print("[SelfASRX] Loading whisper model...")
        self.whisper = _load_whisper_model("small")
        print("[SelfASRX] Loading VAD model (Silero)...")
        self.vad_model, self.vad_utils = _load_vad()
        print("[SelfASRX] Loading speaker encoder (ECAPA-TDNN)...")
        self.speaker_encoder = _load_speaker_encoder()
        print("[SelfASRX] Loading gender classifier...")
        self.gender_classifier = _load_gender_classifier()
        # Qdrant settings
        self.qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333")
        self.qdrant_api_key = os.environ.get("QDRANT_API_KEY", "")
        schema = os.environ.get("DATABASE_SCHEMA", "public")
        self.qdrant_collection = os.environ.get(
            "QDRANT_SPEAKER_COLLECTION",
            f"momentry_{schema}_speaker"
        )
        # Set to True once the collection has been verified or created.
        self._qdrant_ok = False
        print("[SelfASRX] Models loaded successfully")

    def process(self, audio_path, output_path=None, file_uuid=None,
                max_speakers=10, quality_threshold=0.85,
                checkpoint_path=None):
        """Run the full seven-step speaker diarization pipeline.

        Args:
            audio_path: Path to the audio file (documented expectation:
                WAV, 16 kHz, mono).
            output_path: Optional path for the result JSON.
            file_uuid: File UUID; Steps 6-7 (Qdrant) run only when given.
            max_speakers: Upper bound passed to the clustering step.
            quality_threshold: Minimum quality (0-1) for an embedding to
                count as high quality in Step 7.
            checkpoint_path: Optional checkpoint file. It is updated every
                ``CHECKPOINT_INTERVAL`` segments during Step 3 and once more
                when Step 3 completes. An existing, matching partial
                checkpoint is used to skip already transcribed segments.

        Returns:
            dict with ``language``, ``segments``, ``n_speakers``,
            ``speaker_stats``, ``total_duration``, ``n_segments``,
            ``processing_time`` and, when available, ``realtime_factor`` and
            ``references``. When no speech or no segments are found, a dict
            with ``error`` and an empty ``segments`` list is returned.
        """
        start_time = time.time()
        print(f"\n[SelfASRX] Processing: {audio_path}")
        print("=" * 60)
        # Load audio
        wav, sample_rate = _load_audio(audio_path)
        total_duration = len(wav) / sample_rate
        print(f"  Audio: {total_duration:.2f}s, {sample_rate}Hz")

        # ── Step 1: rough localisation with whisper (faster-whisper) ──
        print("\n[Step 1] Initial whisper transcription...")
        t1 = time.time()
        seg_gen, info = self.whisper.transcribe(audio_path)
        rough_segments = [
            {"start": seg.start, "end": seg.end, "text": seg.text}
            for seg in seg_gen
        ]
        language = info.language if info else None
        print(f"  Rough segments: {len(rough_segments)}")
        print(f"  Language: {language}")
        print(f"  Step 1 time: {time.time() - t1:.2f}s")
        if not rough_segments:
            print("[SelfASRX] No speech detected by whisper!")
            return {"error": "No speech detected", "segments": []}

        # ── Step 2: VAD scan of every rough segment ──
        print("\n[Step 2] VAD scan for refined segmentation...")
        t2 = time.time()
        refined_segments = []
        for seg in rough_segments:
            s = seg["start"]
            e = seg["end"]
            sub = self._vad_scan_segment(wav, sample_rate, s, e)
            if sub:
                refined_segments.extend(sub)
            else:
                # VAD found nothing: keep the rough segment unchanged.
                refined_segments.append((s, e))
        print(f"  Refined segments: {len(refined_segments)}")
        print(f"  Step 2 time: {time.time() - t2:.2f}s")
        if not refined_segments:
            return {"error": "No segments after VAD scan", "segments": []}

        # ── Step 3: whisper per refined segment ──
        print("\n[Step 3] Per-segment transcription...")
        t3 = time.time()
        # Pick up an interrupted run when a matching partial checkpoint exists.
        segment_texts = self._load_partial_texts(checkpoint_path,
                                                 refined_segments)
        resume_from = len(segment_texts)
        if resume_from:
            print(f"[Step 3] Resuming from #{resume_from}/{len(refined_segments)}")
        for i in range(resume_from, len(refined_segments)):
            start_sec, end_sec = refined_segments[i]
            segment_texts.append(
                self._transcribe_segment(wav, sample_rate, start_sec, end_sec)
            )
            done = i + 1
            if checkpoint_path and done % self.CHECKPOINT_INTERVAL == 0:
                self._write_checkpoint(
                    checkpoint_path, False, language, total_duration,
                    refined_segments, segment_texts, file_uuid,
                    max_speakers, quality_threshold, progress=done,
                )
                print(f"[Checkpoint] Step 3: {done}/{len(refined_segments)}")
        print(f"  Step 3 time: {time.time() - t3:.2f}s")

        # ── Save final checkpoint after Step 3 ──
        if checkpoint_path:
            self._write_checkpoint(
                checkpoint_path, True, language, total_duration,
                refined_segments, segment_texts, file_uuid,
                max_speakers, quality_threshold,
            )
            print(f"[Checkpoint] Step 3 complete, saved to {checkpoint_path}")

        return self._run_steps_4_to_7(
            wav, sample_rate, refined_segments, segment_texts, language,
            total_duration, file_uuid, max_speakers, quality_threshold,
            output_path, start_time,
        )

    def resume_from_checkpoint(self, checkpoint_path, audio_path,
                               output_path=None):
        """Load the Step 1-3 results from a checkpoint and run Steps 4-7.

        Args:
            checkpoint_path: Checkpoint JSON written by :meth:`process` with
                ``step3_completed`` set.
            audio_path: Path to the audio the checkpoint was created from.
            output_path: Optional path for the result JSON.

        Returns:
            The same result dict as :meth:`process`. A dict with ``error``
            and an empty ``segments`` list is returned when Step 3 was not
            completed or the stored segments and texts do not line up.
        """
        print(f"\n[SelfASRX] Resuming from checkpoint: {checkpoint_path}")
        print("=" * 60)
        with open(checkpoint_path, "r", encoding="utf-8") as f:
            cp = json.load(f)
        if not cp.get("step3_completed"):
            error_msg = f"Checkpoint step3 not completed (progress: {cp.get('step3_progress', '?')})"
            print(f"[SelfASRX] {error_msg}")
            return {"error": error_msg, "segments": []}

        # Validate before the expensive audio load / embedding work: a
        # mismatch would otherwise only surface as an IndexError (or silent
        # truncation) after Steps 4-5 have already run.
        refined_segments = [tuple(s) for s in cp.get("refined_segments") or []]
        segment_texts = cp.get("segment_texts") or []
        if not refined_segments or len(refined_segments) != len(segment_texts):
            error_msg = (f"Checkpoint is inconsistent: {len(refined_segments)} "
                         f"refined segments but {len(segment_texts)} segment texts")
            print(f"[SelfASRX] {error_msg}")
            return {"error": error_msg, "segments": []}

        wav, sample_rate = _load_audio(audio_path)
        language = cp.get("language", "")
        total_duration = cp.get("total_duration", 0)
        file_uuid = cp.get("file_uuid")
        max_speakers = cp.get("max_speakers", 10)
        quality_threshold = cp.get("quality_threshold", 0.85)
        print(f"  Loaded checkpoint: {len(refined_segments)} segments, "
              f"language={language}, duration={total_duration:.2f}s")
        start_time = time.time()
        return self._run_steps_4_to_7(
            wav, sample_rate, refined_segments, segment_texts, language,
            total_duration, file_uuid, max_speakers, quality_threshold,
            output_path, start_time,
        )

    # ── Internal helpers ──
    def _load_partial_texts(self, checkpoint_path, refined_segments):
        """Return the per-segment texts saved by an interrupted Step 3.

        An empty list is returned (start from scratch) when there is no
        checkpoint, it cannot be read, it has a different version, Step 3 is
        already marked complete, or it does not describe the same refined
        segments as the current run.
        """
        if not checkpoint_path or not os.path.exists(checkpoint_path):
            return []
        try:
            # Checkpoints are written as UTF-8, so read them the same way.
            with open(checkpoint_path, "r", encoding="utf-8") as f:
                cp = json.load(f)
        except (OSError, ValueError) as exc:
            print(f"[Checkpoint] Ignoring unreadable checkpoint: {exc}")
            return []
        if not isinstance(cp, dict):
            return []
        if (cp.get("checkpoint_version") != self.CHECKPOINT_VERSION
                or cp.get("step3_completed")):
            return []
        saved = cp.get("segment_texts") or []
        if not isinstance(saved, list) or not saved:
            return []

        required = {"text", "language", "lang_prob"}
        well_formed = all(
            isinstance(st, dict) and required <= st.keys() for st in saved
        )
        # Saved texts are positional, so they are only usable when the
        # checkpoint was produced for exactly this segmentation.
        current = [[s, e] for s, e in refined_segments]
        saved_bounds = cp.get("refined_segments")
        same_segments = saved_bounds is None or saved_bounds == current
        if not well_formed or not same_segments or len(saved) > len(current):
            print("[Checkpoint] Partial checkpoint does not match current "
                  "segmentation; restarting Step 3")
            return []
        return saved

    def _write_checkpoint(self, checkpoint_path, completed, language,
                          total_duration, refined_segments, segment_texts,
                          file_uuid, max_speakers, quality_threshold,
                          progress=None):
        """Persist the Step 3 state; ``progress`` is stored for partial saves."""
        payload = {
            "checkpoint_version": self.CHECKPOINT_VERSION,
            "step3_completed": completed,
        }
        if progress is not None:
            payload["step3_progress"] = progress
        payload.update({
            "language": language,
            "total_duration": total_duration,
            "refined_segments": [[s, e] for s, e in refined_segments],
            "segment_texts": [{
                "text": st["text"],
                "language": st["language"],
                "lang_prob": st["lang_prob"],
            } for st in segment_texts],
            "file_uuid": file_uuid,
            "max_speakers": max_speakers,
            "quality_threshold": quality_threshold,
        })
        _save_checkpoint(checkpoint_path, payload)

    def _run_steps_4_to_7(self, wav, sample_rate, refined_segments,
                          segment_texts, language, total_duration, file_uuid,
                          max_speakers, quality_threshold, output_path,
                          start_time):
        """Run Steps 4-7 and assemble, optionally save, and return the result.

        ``start_time`` is the ``time.time()`` reference used for the reported
        ``processing_time`` and ``realtime_factor``.
        """
        # ── Step 4: ECAPA-TDNN per refined segment ──
        print("\n[Step 4] Speaker embedding extraction...")
        t4 = time.time()
        n_samples = len(wav)
        audio_segments = [
            wav[int(start_sec * sample_rate):
                min(int(end_sec * sample_rate), n_samples)]
            for start_sec, end_sec in refined_segments
        ]
        from speaker_encoder import extract_speaker_embeddings_batch, normalize_embeddings
        embeddings = extract_speaker_embeddings_batch(
            self.speaker_encoder, audio_segments, sample_rate
        )
        embeddings = normalize_embeddings(embeddings)
        print(f"  Embeddings: {embeddings.shape}")
        print(f"  Step 4 time: {time.time() - t4:.2f}s")

        # ── Step 5: AgglomerativeClustering ──
        print("\n[Step 5] Speaker clustering...")
        t5 = time.time()
        from speaker_cluster_fixed import robust_speaker_clustering
        speaker_labels, estimated_n_speakers = robust_speaker_clustering(
            embeddings, n_speakers=None, max_speakers=max_speakers
        )
        # Boolean masking below needs an ndarray, whatever the helper returns.
        speaker_labels = np.asarray(speaker_labels)
        print(f"  Speakers: {estimated_n_speakers}")
        print(f"  Step 5 time: {time.time() - t5:.2f}s")

        # Quality of every embedding relative to its cluster centroid
        qualities = compute_embedding_quality(embeddings, speaker_labels)

        # Build output segments
        segments = []
        for (start_sec, end_sec), texts, label, quality in zip(
                refined_segments, segment_texts, speaker_labels, qualities):
            speaker = f"SPEAKER_{int(label)}"
            segments.append({
                "start": round(start_sec, 3),
                "end": round(end_sec, 3),
                "start_frame": int(start_sec * self.FRAME_RATE),
                "end_frame": int(end_sec * self.FRAME_RATE),
                "text": texts["text"],
                "language": texts["language"],
                "lang_prob": texts["lang_prob"],
                "speaker": speaker,
                "speaker_id": speaker,
                "quality": float(quality),
            })

        # Per-speaker statistics
        speaker_stats = {}
        for seg in segments:
            stats = speaker_stats.setdefault(
                seg["speaker_id"], {"count": 0, "duration": 0}
            )
            stats["count"] += 1
            stats["duration"] += seg["end"] - seg["start"]

        result = {
            "language": language or "",
            "segments": segments,
            "n_speakers": int(estimated_n_speakers),
            "speaker_stats": speaker_stats,
            "total_duration": total_duration,
            "n_segments": len(segments),
        }

        if file_uuid:
            # ── Step 6: Store embeddings in Qdrant ──
            print("\n[Step 6] Storing embeddings in Qdrant...")
            t6 = time.time()
            self._store_speaker_embeddings(segments, embeddings,
                                           speaker_labels, file_uuid)
            print(f"  Step 6 time: {time.time() - t6:.2f}s")

            # ── Step 7: High-quality classification ──
            print("\n[Step 7] Classifying high-quality embeddings...")
            t7 = time.time()
            references = self._classify_high_quality_speakers(
                segments, embeddings, speaker_labels, file_uuid,
                wav, sample_rate, quality_threshold, qualities=qualities,
            )
            if references:
                result["references"] = references
            print(f"  Step 7 time: {time.time() - t7:.2f}s")

        total_time = time.time() - start_time
        result["processing_time"] = round(total_time, 2)
        if total_duration > 0 and total_time > 0:
            result["realtime_factor"] = round(total_duration / total_time, 2)

        # Save output
        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=2, ensure_ascii=False)
            print(f"\n[SelfASRX] Saved to: {output_path}")
        print(f"\n[SelfASRX] Done! {len(segments)} segments, "
              f"{estimated_n_speakers} speakers, "
              f"{total_time:.2f}s")
        return result

    def _vad_scan_segment(self, wav, sample_rate, start_sec, end_sec):
        """Split one rough segment with VAD via ``vad.scan_within_segment``."""
        from vad import scan_within_segment
        return scan_within_segment(
            wav, sample_rate, start_sec, end_sec,
            self.vad_model, self.vad_utils
        )

    def _transcribe_segment(self, wav, sample_rate, start_sec, end_sec):
        """Transcribe one segment via ``whisper_local.transcribe_segment``.

        The result is used in this file as a mapping with ``text``,
        ``language`` and ``lang_prob`` keys.
        """
        from whisper_local import transcribe_segment
        return transcribe_segment(wav, sample_rate, start_sec, end_sec, self.whisper)

    def _store_speaker_embeddings(self, segments, embeddings, labels, file_uuid):
        """Step 6: upsert every segment embedding into Qdrant.

        Returns:
            True when the points were written, False when Qdrant is
            unavailable or the upsert failed.
        """
        if not self._ensure_qdrant():
            return False
        points = []
        for i, (seg, emb) in enumerate(zip(segments, embeddings)):
            points.append({
                "id": _hash_point_id(file_uuid, str(i)),
                "vector": emb.tolist(),
                "payload": {
                    "type": "speaker_embedding",
                    "file_uuid": file_uuid,
                    "speaker_id": seg["speaker_id"],
                    "text": seg["text"],
                    "language": seg["language"],
                    "start_time": seg["start"],
                    "end_time": seg["end"],
                }
            })
        ok = _qdrant_upsert(self.qdrant_url, self.qdrant_api_key,
                            self.qdrant_collection, points)
        if ok:
            print(f"  Stored {len(points)} speaker embeddings to Qdrant")
        return ok

    def _classify_gender(self, seg_wav):
        """Return ``(gender, confidence)`` for one audio snippet.

        ``("unknown", 0.0)`` is returned when no classifier is loaded, the
        snippet is empty, the classifier output has fewer than two classes,
        or classification raises.
        """
        if self.gender_classifier is None or len(seg_wav) == 0:
            return "unknown", 0.0
        try:
            import torch
            seg_tensor = torch.from_numpy(seg_wav).float().unsqueeze(0)
            # NOTE(review): the audio is passed at the file's own sample
            # rate; confirm the classifier expects that rate.
            out = self.gender_classifier.classify_batch(seg_tensor)
            probs = torch.softmax(out[0], dim=-1).squeeze().cpu().detach().numpy()
            if len(probs) >= 2:
                idx = int(np.argmax(probs))
                # NOTE(review): index 0 is treated as "male" — confirm
                # against the classifier's label encoder.
                return ("male" if idx == 0 else "female"), float(probs[idx])
        except Exception as e:
            # Gender is optional metadata; never let it abort the pipeline.
            print(f"[Gender] Classify error: {e}")
        return "unknown", 0.0

    def _classify_high_quality_speakers(self, segments, embeddings, labels,
                                        file_uuid, wav, sample_rate,
                                        threshold=0.85, qualities=None):
        """Step 7: build one reference per speaker from high-quality embeddings.

        For each speaker with at least one embedding whose quality is
        ``>= threshold``, the unit-normalized centroid of those embeddings is
        stored in Qdrant (when the collection is available) together with a
        gender estimate taken from the best-quality segment.

        Args:
            segments: Output segment dicts, aligned with ``embeddings``.
            embeddings: Per-segment speaker embeddings.
            labels: Per-segment cluster labels.
            file_uuid: File UUID stored in the payload and used for the ID.
            wav: Full audio samples.
            sample_rate: Sample rate of ``wav``.
            threshold: Minimum quality for an embedding to be used.
            qualities: Optional precomputed qualities; computed when omitted.

        Returns:
            List of dicts with ``speaker_id``, ``n_segments``,
            ``avg_quality`` and ``gender`` (empty when nothing qualifies).
        """
        labels = np.asarray(labels)
        if qualities is None:
            qualities = compute_embedding_quality(embeddings, labels)
        qualities = np.asarray(qualities)
        high_mask = qualities >= threshold
        if not np.any(high_mask):
            print("  No high-quality embeddings found")
            return []

        # Check once, like Step 6, instead of issuing an upsert per speaker
        # against a server that is known to be unreachable.
        qdrant_ready = self._ensure_qdrant()
        references = []
        # np.unique yields the labels in sorted order, so the reference list
        # is deterministic.
        for label in np.unique(labels):
            mask = (labels == label) & high_mask
            if not np.any(mask):
                continue
            high_indices = np.flatnonzero(mask)
            high_segs = [segments[i] for i in high_indices]
            # Segment with the highest quality within this speaker
            best_idx = int(high_indices[int(np.argmax(qualities[mask]))])
            best_seg = segments[best_idx]
            centroid = np.mean(embeddings[mask], axis=0)
            norm = np.linalg.norm(centroid)
            if norm > 0:
                centroid = centroid / norm
            avg_quality = float(np.mean(qualities[mask]))
            n_high = int(np.sum(mask))
            speaker_id = f"SPEAKER_{int(label)}"
            text_samples = [s["text"] for s in high_segs[:5] if s["text"]]
            total_dur = sum(s["end"] - s["start"] for s in high_segs)

            # Gender classification uses the audio of the best segment.
            s = int(best_seg["start"] * sample_rate)
            e = int(best_seg["end"] * sample_rate)
            gender, gender_conf = self._classify_gender(wav[s:min(e, len(wav))])

            ref_payload = {
                "type": "speaker_reference",
                "file_uuid": file_uuid,
                "speaker_id": speaker_id,
                "n_segments": n_high,
                "avg_quality": avg_quality,
                "total_duration": round(total_dur, 2),
                "language": best_seg.get("language", ""),
                "text_samples": text_samples,
                "gender": gender,
                "gender_conf": gender_conf,
            }
            if qdrant_ready:
                _qdrant_upsert(self.qdrant_url, self.qdrant_api_key,
                               self.qdrant_collection, [{
                                   "id": _hash_point_id(file_uuid, f"ref_{label}"),
                                   "vector": centroid.tolist(),
                                   "payload": ref_payload,
                               }])
            references.append({
                "speaker_id": speaker_id,
                "n_segments": n_high,
                "avg_quality": avg_quality,
                "gender": gender,
            })
            print(f"  Ref: {speaker_id}, gender={gender}"
                  f" ({gender_conf:.2f}), q={avg_quality:.3f}")
        return references

    def _ensure_qdrant(self):
        """Return True when the Qdrant collection is usable.

        The check is repeated on each call until it succeeds once; after
        that the cached positive result is returned.
        """
        if not self._qdrant_ok:
            self._qdrant_ok = _ensure_speaker_collection(
                self.qdrant_url, self.qdrant_api_key, self.qdrant_collection
            )
        return self._qdrant_ok
def main():
    """Command-line entry point.

    Parses the arguments, builds the pipeline, then either resumes from a
    checkpoint (``--resume``) or processes the audio file from scratch.
    Exits with status 1 when the required input file does not exist, and
    prints a short summary when the run did not report an error.
    """
    import argparse

    cli = argparse.ArgumentParser(description="SelfASRX - Hybrid Speaker Diarization")
    cli.add_argument("audio_path", help="Path to audio file (WAV)")
    cli.add_argument("-o", "--output", help="Output JSON path")
    cli.add_argument("--file-uuid", help="File UUID for Qdrant storage")
    cli.add_argument("--max-speakers", type=int, default=10)
    cli.add_argument("--quality-threshold", type=float, default=0.85)
    cli.add_argument("--resume", help="Checkpoint path to resume from")
    cli.add_argument("--checkpoint", help="Save checkpoint path after Step 3")
    opts = cli.parse_args()

    # Models are loaded before the input paths are checked.
    pipeline = SelfASRXFixed()

    # Pick the file that must exist for the chosen mode and its error text.
    if opts.resume:
        required_path = opts.resume
        missing_msg = f"Error: Checkpoint not found: {opts.resume}"
    else:
        required_path = opts.audio_path
        missing_msg = f"Error: Audio file not found: {opts.audio_path}"
    if not Path(required_path).exists():
        print(missing_msg)
        sys.exit(1)

    if opts.resume:
        result = pipeline.resume_from_checkpoint(
            opts.resume, opts.audio_path,
            output_path=opts.output,
        )
    else:
        result = pipeline.process(
            opts.audio_path,
            output_path=opts.output,
            file_uuid=opts.file_uuid,
            max_speakers=opts.max_speakers,
            quality_threshold=opts.quality_threshold,
            checkpoint_path=opts.checkpoint,
        )

    if "error" in result:
        return

    print("\n[Summary]")
    print(f"  Duration: {result['total_duration']:.2f}s")
    print(f"  Segments: {result['n_segments']}")
    print(f"  Speakers: {result['n_speakers']}")
    for ref in result.get("references", []):
        print(f"  {ref['speaker_id']}: gender={ref['gender']}, "
              f"quality={ref['avg_quality']:.3f}")


# Run the CLI only when executed as a script.
if __name__ == "__main__":
    main()