#!/usr/bin/env python3
|
# -*- coding: utf-8 -*-
|
"""视频级评估:与推理 pipeline 一致的时序平滑 + 顺序约束。"""
|
import argparse
|
import json
|
import os
|
import subprocess
|
import sys
|
import tempfile
|
from pathlib import Path
|
|
import numpy as np
|
import onnxruntime as ort
|
import yaml
|
from PIL import Image
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
from app.temporal import find_peaks_ordered, smooth_scores
|
|
|
def load_config(path):
    """Parse the YAML config file at *path* and return the loaded object.

    ``yaml.safe_load`` is used so the file cannot construct arbitrary Python
    objects. The file is read as UTF-8.
    """
    with open(path, "r", encoding="utf-8") as stream:
        config = yaml.safe_load(stream)
    return config
|
|
|
def resolve_model(models_dir, name):
    """Locate the ONNX file for model *name* inside *models_dir*.

    Candidates are tried in this order, and the first one that exists as a
    regular file wins:

    1. the file named by key ``"<name>_model"`` in ``models_dir/version.json``
       (only if that file exists and the value is truthy);
    2. ``<name>_int8.onnx``;
    3. ``<name>.onnx``.

    *models_dir* must support the ``/`` operator (e.g. ``pathlib.Path``).
    Returns the matching path, or ``None`` when no candidate exists.
    """
    candidates = []
    manifest = models_dir / "version.json"
    if manifest.is_file():
        pinned = json.loads(manifest.read_text(encoding="utf-8")).get(f"{name}_model")
        if pinned:
            candidates.append(models_dir / pinned)
    candidates.append(models_dir / f"{name}_int8.onnx")
    candidates.append(models_dir / f"{name}.onnx")
    return next((candidate for candidate in candidates if candidate.is_file()), None)
|
|
|
def ffprobe_duration(video_path):
    """Return the container duration of *video_path* in seconds, via ffprobe.

    Best-effort: any failure (ffprobe not installed, non-zero exit status,
    output that is not a float) yields ``0.0`` instead of raising. Empty
    output also maps to ``0.0``.
    """
    probe_cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        video_path,
    ]
    try:
        raw = subprocess.check_output(probe_cmd, stderr=subprocess.DEVNULL, text=True)
        text_value = raw.strip()
        if not text_value:
            return 0.0
        return float(text_value)
    except Exception:
        return 0.0
|
|
|
def download_if_needed(video_path, cache_dir):
    """Return a local path for *video_path*, downloading it first if it is a URL.

    Returns ``(local_path, tmp_path)``:

    * *video_path* is an existing local file -> ``(video_path, None)``;
      there is nothing to clean up.
    * *video_path* is an ``http://`` / ``https://`` URL -> it is streamed to
      ``<cache_dir>/eval_tmp.mp4`` and ``(local, local)`` is returned, so the
      caller can delete the copy afterwards.
    * anything else, or a download that fails -> ``(None, None)``, which
      callers treat as "video not accessible".

    The cache file name is fixed, so every download overwrites the previous
    one. This assumes videos are processed one at a time.
    """
    if os.path.isfile(video_path):
        return video_path, None
    # Require a real URL scheme; a bare "http" prefix would also match local
    # relative paths such as "httpdata/clip.mp4".
    if not video_path.startswith(("http://", "https://")):
        return None, None
    import httpx  # only needed for remote videos

    os.makedirs(cache_dir, exist_ok=True)
    local = os.path.join(cache_dir, "eval_tmp.mp4")
    try:
        with httpx.stream("GET", video_path, timeout=600.0, follow_redirects=True) as r:
            r.raise_for_status()
            with open(local, "wb") as f:
                for chunk in r.iter_bytes():
                    f.write(chunk)
    except (httpx.HTTPError, OSError) as exc:
        # Drop any partially written file and report the video as inaccessible,
        # so the caller skips it instead of aborting the whole evaluation.
        print(f"下载失败 {video_path}: {exc}", file=sys.stderr)
        if os.path.isfile(local):
            os.remove(local)
        return None, None
    return local, local
|
|
|
def frame_times(duration, sample_fps):
    """Return the timestamps (seconds) at which frames are sampled.

    Timestamps are ``i / sample_fps`` for ``i = 0, 1, ...`` up to and including
    *duration*. Each timestamp is computed from its index rather than by
    repeatedly adding a step, so floating-point drift cannot drop (or add) the
    final sample. The tiny epsilon absorbs rounding in
    ``duration * sample_fps``.

    Raises:
        ValueError: if *sample_fps* is not a positive number. A negative rate
            would otherwise never terminate.
    """
    if not sample_fps > 0:
        raise ValueError(f"sample_fps must be positive, got {sample_fps!r}")
    if duration < 0:
        return []
    # int() truncation equals floor here because the operand is non-negative.
    count = int(duration * sample_fps + 1e-9) + 1
    return [i / sample_fps for i in range(count)]


def sample_scores(video_path, session, sample_fps, image_size):
    """Score sampled frames of *video_path* with a single-logit ONNX model.

    Steps for each timestamp from ``frame_times(duration, sample_fps)``:

    1. ffmpeg grabs one JPEG frame.
    2. The frame is converted to RGB, resized to ``image_size x image_size``
       and normalized with the per-channel mean/std below (the usual
       ImageNet constants).
    3. It is fed as a 1xCxHxW float32 tensor under the input name ``"input"``.
    4. The model's first output value is passed through a sigmoid (assumes an
       output of shape (1, 1) or wider -- TODO confirm).

    Returns:
        ``(pairs, duration)``, where *pairs* is a list of
        ``(time_sec rounded to 0.01, probability)``. When ffprobe reports no
        positive duration, returns ``([], duration)``. Timestamps for which
        ffmpeg produces no frame (e.g. seeking to the very end) are skipped.
    """
    duration = ffprobe_duration(video_path)
    if duration <= 0:
        return [], duration

    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    tmp.close()
    pairs = []
    try:
        for t in frame_times(duration, sample_fps):
            # Remove the previous frame first: when ffmpeg cannot decode at t
            # it writes nothing, and a leftover file would be re-scored and
            # recorded under the wrong timestamp.
            if os.path.isfile(tmp.name):
                os.remove(tmp.name)
            subprocess.run(
                ["ffmpeg", "-y", "-ss", str(t), "-i", video_path, "-frames:v", "1", "-q:v", "2", tmp.name],
                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
            )
            if not (os.path.isfile(tmp.name) and os.path.getsize(tmp.name) > 0):
                continue
            with Image.open(tmp.name) as frame:
                img = frame.convert("RGB").resize((image_size, image_size))
            arr = (np.array(img).astype(np.float32) / 255.0 - mean) / std
            arr = arr.transpose(2, 0, 1)[None].astype(np.float32)  # HWC -> 1xCxHxW
            logit = session.run(None, {"input": arr})[0][0][0]
            prob = float(1.0 / (1.0 + np.exp(-logit)))
            pairs.append((round(t, 2), prob))
    finally:
        if os.path.isfile(tmp.name):
            os.remove(tmp.name)
    return pairs, duration
|
|
|
def load_annotations(path):
    """Parse a JSON-Lines annotation file.

    Returns one decoded object per non-blank line, in file order.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def select_split(items, split):
    """Return the annotations whose ``"split"`` field equals *split*.

    If none match, all *items* are returned, as before. A warning is printed
    to stderr in that case, because the metrics may then include
    non-held-out videos.
    """
    selected = [item for item in items if item.get("split") == split]
    if selected:
        return selected
    if items:
        print(
            f"警告: 没有 split={split} 的样本,改为评估全部 {len(items)} 条标注",
            file=sys.stderr,
        )
    return items


def main():
    """Evaluate storefront/handover timestamp prediction on annotated videos.

    For each annotation in the selected split:

    1. Frames are scored with both ONNX models.
    2. ``find_peaks_ordered`` picks one time per event.
    3. The predictions are compared with ``storefront_time_sec`` and
       ``handover_time_sec``.

    Prints per-video errors, then MAE for both events, the fraction of videos
    where the handover is predicted after the storefront, and the fraction
    where both errors are within ``--tolerance`` seconds.

    Relative ``--annotations`` / ``--models-dir`` paths are resolved against
    the directory of ``--config``, not the current working directory.

    Returns:
        A process exit status: 0 on success, 1 when a model is missing or no
        sample could be evaluated.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", default=str(Path(__file__).parent / "config.yaml"))
    parser.add_argument("--annotations", default="../data/annotations.jsonl")
    parser.add_argument("--models-dir", default="../models")
    parser.add_argument("--sample-fps", type=float, default=0.5)
    parser.add_argument("--split", default="val")
    parser.add_argument(
        "--tolerance", type=float, default=5.0,
        help="hit threshold in seconds applied to both events (default: 5)",
    )
    args = parser.parse_args()

    base_dir = Path(args.config).resolve().parent
    cfg = load_config(args.config)
    size = cfg["model"]["image_size"]

    models_dir = base_dir / args.models_dir
    sf_path = resolve_model(models_dir, "storefront")
    ho_path = resolve_model(models_dir, "handover")
    if not sf_path or not ho_path:
        print("未找到 ONNX 模型,请先 export_onnx.py")
        return 1

    sf_sess = ort.InferenceSession(str(sf_path), providers=["CPUExecutionProvider"])
    ho_sess = ort.InferenceSession(str(ho_path), providers=["CPUExecutionProvider"])

    val_items = select_split(load_annotations(base_dir / args.annotations), args.split)
    cache_dir = str(base_dir / "../data/eval_cache")

    sf_mae = ho_mae = order_ok = hits = n = 0
    for item in val_items:
        local, tmp = download_if_needed(item["video_path"], cache_dir)
        if not local:
            print(f"跳过 media_id={item['media_id']}: 视频不可访问")
            continue
        try:
            # NOTE(review): frames are extracted twice (once per model); sharing
            # one extraction would halve the ffmpeg work.
            sf_scores, duration = sample_scores(local, sf_sess, args.sample_fps, size)
            ho_scores, _ = sample_scores(local, ho_sess, args.sample_fps, size)
        finally:
            # Delete a downloaded copy even if scoring raised.
            if tmp and os.path.isfile(tmp):
                os.remove(tmp)

        sf_peak, ho_peak = find_peaks_ordered(sf_scores, ho_scores, duration)
        # Assumes each peak is falsy when absent, otherwise indexable with the
        # time in seconds at [0] -- TODO confirm against app.temporal.
        # A missing peak is scored as a prediction at 0 s.
        pred_sf = sf_peak[0] if sf_peak else 0.0
        pred_ho = ho_peak[0] if ho_peak else 0.0
        gt_sf = float(item["storefront_time_sec"])
        gt_ho = float(item["handover_time_sec"])
        sf_err = abs(pred_sf - gt_sf)
        ho_err = abs(pred_ho - gt_ho)

        sf_mae += sf_err
        ho_mae += ho_err
        if pred_ho > pred_sf:
            order_ok += 1
        if sf_err <= args.tolerance and ho_err <= args.tolerance:
            hits += 1
        n += 1
        print(
            f"media_id={item['media_id']} gt_sf={gt_sf}s pred_sf={pred_sf:.1f}s err={sf_err:.1f}s | "
            f"gt_ho={gt_ho}s pred_ho={pred_ho:.1f}s err={ho_err:.1f}s"
        )

    if n == 0:
        print("无可用验证样本")
        return 1
    print("---")
    print(f"样本数={n}")
    print(f"门头 MAE={sf_mae/n:.2f}s 交付 MAE={ho_mae/n:.2f}s")
    # With the default tolerance of 5.0 the label renders as "双5秒命中率".
    print(f"顺序正确率={order_ok/n*100:.1f}% 双{args.tolerance:g}秒命中率={hits/n*100:.1f}%")
    return 0
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None -> 0), so
    # callers such as CI scripts can detect a failed evaluation.
    sys.exit(main())
|