- Update ASR, face, OCR, pose processors - Add release pre-flight check script - Add synonym generation, chunk processing scripts - Add face recognition, stamp search utilities
115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
#!/opt/homebrew/bin/python3.11
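
"""Test OWL-ViT for "Stamps" Detection.

Scans the ASR transcript for mentions of "stamp" and, if any are found, runs
OWL-ViT open-vocabulary detection on video frames at candidate timestamps,
saving annotated frames that contain a detected stamp.
"""
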
import json
import os
import sys

import cv2
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

# Per-video inputs and outputs; everything lives under output/<UUID>/.
UUID = "384b0ff44aaaa1f1"
BASE_DIR = f"output/{UUID}"
VIDEO_PATH = f"{BASE_DIR}/{UUID}.mp4"
ASR_PATH = f"{BASE_DIR}/{UUID}.asr.json"
OUTPUT_DIR = f"{BASE_DIR}/owl_vit_results"

# Create the results directory up front; an existing directory is fine.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 1. Find timestamps where "stamp" is mentioned in the ASR transcript.
print("🔍 Analyzing ASR for 'stamp' mentions...")

# JSON text is UTF-8; don't depend on the platform's default encoding.
with open(ASR_PATH, encoding="utf-8") as f:
    asr_data = json.load(f)

target_times = []
for seg in asr_data.get("segments", []):
    # `or ""` also covers an explicit null "text" value.
    seg_text = seg.get("text") or ""
    if "stamp" not in seg_text.lower():
        continue
    start = seg.get("start", 0)
    target_times.append(start)
    # Print the same defaulted values used above. Indexing seg["start"] here
    # would raise KeyError for a segment without a start time.
    print(f" 🗣️ Found: '{seg_text}' @ {start:.2f}s")

if not target_times:
    print("❌ No mentions of 'stamp' found.")
    # sys.exit() rather than the site-injected exit() helper, which is not
    # guaranteed to exist (e.g. when running under `python -S`).
    sys.exit()

# Prioritize timestamps around the "Stamps" chunk (Chunk 833, ~5851s) and the
# final confrontation (~6700s+), because early mentions might be just dialogue
# about stamps without showing them.
#
# NOTE: the hand-picked times below REPLACE the ASR-derived list rather than
# being merged with it. The ASR scan above therefore only decides whether the
# script continues at all. The values are specific to this video (UUID).
asr_mention_count = len(target_times)
priority_times = [5851.6, 5860.4, 6756.6, 6846.0]
print(f"🔥 Prioritizing high-probability timestamps: {priority_times}")
target_times = priority_times

# Report the ASR mention count separately from the number of timestamps that
# will actually be checked, since the two differ after the override above.
print(
    f"✅ Found {asr_mention_count} ASR mention(s); "
    f"checking {len(target_times)} candidate timestamps."
)

# 2. Load the model and its processor. The base/patch32 checkpoint is used
# for speed; the large checkpoint is more accurate but slower.
print("🧠 Loading OWL-ViT model...")
MODEL_ID = "google/owlvit-base-patch32"
processor = OwlViTProcessor.from_pretrained(MODEL_ID)
model = OwlViTForObjectDetection.from_pretrained(MODEL_ID)

# 3. Process frames. For each candidate time: grab the frame, run OWL-ViT with
# the text queries below, draw boxes for confident hits, and save the frame
# if it contains at least one detection.

# The processor takes one list of text queries per image, and a single image
# is sent per call. The queries never change, so build them once.
texts = [["a postage stamp", "a stamp on a letter", "a stamp in an album"]]
queries = texts[0]

# Post-processing first drops boxes at <= 0.1. Only boxes scoring above this
# stricter value are reported and drawn.
SCORE_THRESHOLD = 0.15

cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    # Without this check every cap.read() fails and each timestamp is
    # skipped silently, yet the script still ends with "Done".
    print(f"❌ Could not open video: {VIDEO_PATH}")
    sys.exit(1)

try:
    for t in target_times:
        # Seek by timestamp (milliseconds) and decode the frame there.
        cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
        ret, frame = cap.read()
        if not ret:
            print(f" ⚠️ Could not read a frame at {t:.2f}s; skipping.")
            continue

        # OpenCV frames are BGR; convert to an RGB PIL image for the processor.
        image_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        inputs = processor(text=texts, images=image_pil, return_tensors="pt")
        # Inference only: disable autograd so no gradient graph is built.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL's size is (width, height). Post-processing takes (height, width)
        # per image so the boxes come back in this frame's pixel coordinates.
        target_sizes = torch.Tensor([image_pil.size[::-1]])
        results = processor.post_process_object_detection(
            outputs=outputs, target_sizes=target_sizes, threshold=0.1
        )
        # One image in, one result dict out.
        detections = results[0]

        box_found = False
        for box, score, label in zip(
            detections["boxes"], detections["scores"], detections["labels"]
        ):
            if score > SCORE_THRESHOLD:
                box_found = True
                x_min, y_min, x_max, y_max = box.int().tolist()
                label_text = queries[label.item()]
                print(f" ✅ Detected '{label_text}' ({score.item():.2f}) at {t:.2f}s")

                # Draw the box and its caption in place on the BGR frame.
                cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                cv2.putText(
                    frame,
                    f"{label_text} {score.item():.2f}",
                    # Clamp so captions of boxes touching the top edge stay visible.
                    (x_min, max(y_min - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (0, 255, 0),
                    2,
                )

        if not box_found:
            # Frames without detections are only reported, never written to
            # disk, so there is no point annotating them.
            print(f" ❌ No stamp detected at {t:.2f}s")
            continue

        save_path = os.path.join(OUTPUT_DIR, f"stamp_detect_{int(t)}.jpg")
        # cv2.imwrite signals a failed write by returning False.
        if cv2.imwrite(save_path, frame):
            print(f" 💾 Saved to {save_path}")
        else:
            print(f" ❌ Failed to write {save_path}")
finally:
    # Release the capture even if inference raises part-way through.
    cap.release()

print("🏁 Done.")