# -*- coding: utf-8 -*-
|
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

from app.asr import asr_available, match_keywords, transcribe
from app.frame_sampler import cleanup_frames, sample_frames
from app.fusion import fuse_results
from app.onnx_infer import ModelRegistry, score_frames
from app.quality import refine_time_in_window
from app.schemas import AnalyzeRequest, AnalyzeResponse, KeywordConfig
from app.temporal import find_peaks_ordered
from app.video_io import temp_video
|
|
logger = logging.getLogger(__name__)


# Process-wide model registry, created lazily by get_registry().
# It starts out as None, so the annotation must admit None.
_registry: Optional[ModelRegistry] = None
|
|
|
def get_registry() -> ModelRegistry:
    """Return the shared ModelRegistry, constructing it on the first call.

    The model directory is read from the ``SNAPSHOT_MODEL_DIR`` environment
    variable. When that is unset, it falls back to a ``models`` directory one
    level above this module. The chosen path is made absolute before being
    passed to ModelRegistry. Later calls return the cached instance unchanged.
    """
    global _registry
    if _registry is not None:
        return _registry

    fallback_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    configured_dir = os.environ.get("SNAPSHOT_MODEL_DIR", fallback_dir)
    _registry = ModelRegistry(os.path.abspath(configured_dir))
    return _registry
|
|
|
def _resolve_sample_fps(requested_fps) -> float:
    """Return *requested_fps* if it is truthy and positive, else the env default.

    The default comes from ``SNAPSHOT_SAMPLE_FPS`` and falls back to "0.5"
    when that variable is unset.

    Raises:
        ValueError: if ``SNAPSHOT_SAMPLE_FPS`` is set to a non-numeric string.
    """
    if requested_fps and requested_fps > 0:
        return requested_fps
    return float(os.environ.get("SNAPSHOT_SAMPLE_FPS", "0.5"))


def _run_detectors(registry, req, keywords, video_path, sample_fps, duration):
    """Run vision scoring and, when enabled, ASR keyword matching concurrently.

    Returns:
        ``(sf_vision, ho_vision, asr_hits)``. The vision pair is
        ``(None, None)`` when ``find_peaks_ordered`` returns a falsy result.
        ``asr_hits`` is ``[]`` when ASR is disabled, unavailable, or fails.

    Raises:
        Any exception from the vision pipeline. ASR exceptions are logged
        and swallowed, because ASR is an optional signal and must not discard
        a usable vision result.
    """

    def vision_task():
        frames = sample_frames(video_path, sample_fps, duration)
        try:
            sf_scores, ho_scores = score_frames(registry, frames)
            return find_peaks_ordered(sf_scores, ho_scores, duration)
        finally:
            # Sampled frames are temporary; remove them even if scoring fails.
            cleanup_frames(frames)

    def asr_task():
        if not asr_available():
            return []
        return match_keywords(transcribe(video_path), keywords)

    asr_hits = []
    with ThreadPoolExecutor(max_workers=2) as pool:
        vision_future = pool.submit(vision_task)
        asr_future = pool.submit(asr_task) if req.enable_asr else None

        # A vision failure propagates. Leaving the `with` block still waits
        # for any running ASR task before the exception reaches the caller.
        vision_result = vision_future.result()

        if asr_future is not None:
            try:
                asr_hits = asr_future.result()
            except Exception:
                logger.exception("ASR 失败,仅使用视觉结果 media_id=%s", req.media_id)
                asr_hits = []

    if vision_result:
        sf_vision, ho_vision = vision_result
    else:
        sf_vision = ho_vision = None
    return sf_vision, ho_vision, asr_hits


def run_analyze(req: AnalyzeRequest) -> AnalyzeResponse:
    """Detect the storefront and handover moments in the video at ``req.video_url``.

    Steps:
        1. Download or open the video via ``temp_video``.
        2. Score sampled frames with the ONNX models and optionally run ASR
           keyword matching, both in parallel.
        3. Fuse both signals.
        4. Refine each detected timestamp with ``refine_time_in_window``.

    Args:
        req: The analysis request. ``duration_sec``, when positive, overrides
            the duration reported by ``temp_video``. ``sample_fps``, when
            positive, overrides the ``SNAPSHOT_SAMPLE_FPS`` default.

    Returns:
        An AnalyzeResponse. ``success`` is True only when both storefront and
        handover are found. Missing models, an empty ``video_url``, partial
        detection, and any exception raised during processing are all
        reported as ``success=False`` responses; this function does not raise.
    """
    registry = get_registry()

    if not registry.ready:
        return AnalyzeResponse(
            success=False,
            model_version=registry.version,
            message="ONNX 模型未加载,请将 storefront/handover ONNX 放入 models/ 目录",
        )
    if not req.video_url:
        return AnalyzeResponse(success=False, model_version=registry.version, message="video_url 不能为空")

    keywords = req.keywords or KeywordConfig()

    try:
        # Resolved inside the try so that a malformed SNAPSHOT_SAMPLE_FPS yields a
        # failure response like every other error, instead of escaping to the caller.
        sample_fps = _resolve_sample_fps(req.sample_fps)

        with temp_video(req.video_url, req.duration_sec or 0.0) as (video_path, duration):
            if req.duration_sec and req.duration_sec > 0:
                duration = req.duration_sec

            sf_vision, ho_vision, asr_hits = _run_detectors(
                registry, req, keywords, video_path, sample_fps, duration
            )
            storefront, handover = fuse_results(sf_vision, ho_vision, asr_hits, keywords, duration)

            # Refine while still inside `temp_video`, because refinement reads video_path.
            for moment in (storefront, handover):
                if moment:
                    moment.time_sec, _ = refine_time_in_window(video_path, moment.time_sec)

            # Round the same way for both success and failure responses.
            duration_sec = round(duration, 2) if duration is not None else None

            if not storefront or not handover:
                return AnalyzeResponse(
                    success=False,
                    model_version=registry.version,
                    duration_sec=duration_sec,
                    storefront=storefront,
                    handover=handover,
                    asr_hits=asr_hits,
                    message="未能检测到门头或交付时刻",
                )

            return AnalyzeResponse(
                success=True,
                model_version=registry.version,
                duration_sec=duration_sec,
                storefront=storefront,
                handover=handover,
                asr_hits=asr_hits,
            )

    except Exception as e:
        # API boundary: log the full traceback, then report the error to the client.
        logger.exception("分析失败 media_id=%s", req.media_id)
        return AnalyzeResponse(
            success=False,
            model_version=registry.version,
            message=str(e),
        )
|