doum
3 天以前 ce44d803b73a65b2cc31db5bcc662139029463d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# -*- coding: utf-8 -*-
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
 
from app.asr import asr_available, match_keywords, transcribe
from app.frame_sampler import cleanup_frames, sample_frames
from app.fusion import fuse_results
from app.onnx_infer import ModelRegistry, score_frames
from app.quality import refine_time_in_window
from app.schemas import AnalyzeRequest, AnalyzeResponse, KeywordConfig
from app.temporal import find_peaks_ordered
from app.video_io import temp_video
 
logger = logging.getLogger(__name__)

# Process-wide model registry; stays None until get_registry() first builds it.
_registry: "ModelRegistry | None" = None
 
 
def get_registry() -> ModelRegistry:
    """Return the module-wide ModelRegistry, building it on the first call.

    The model directory is taken from the ``SNAPSHOT_MODEL_DIR`` environment
    variable; when unset it defaults to ``../models`` relative to this file.
    The path is made absolute before being handed to ``ModelRegistry``.
    Later calls return the cached instance.
    """
    global _registry
    # NOTE(review): no lock guards this lazy init; if it can be reached from
    # several threads at once, two registries may be built. Confirm whether
    # callers are serialized.
    if _registry is not None:
        return _registry
    fallback_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    configured_dir = os.environ.get("SNAPSHOT_MODEL_DIR", fallback_dir)
    _registry = ModelRegistry(os.path.abspath(configured_dir))
    return _registry
 
 
def _resolve_sample_fps(requested) -> float:
    """Return the frame-sampling rate (frames per second) for one request.

    A positive ``requested`` value is used as-is. Otherwise the
    ``SNAPSHOT_SAMPLE_FPS`` environment variable is consulted. If that value
    is not a number, or is non-positive, NaN or infinite, fall back to 0.5
    instead of raising, so a bad deployment setting cannot break every request.
    """
    if requested and requested > 0:
        return requested
    raw = os.environ.get("SNAPSHOT_SAMPLE_FPS", "0.5")
    try:
        fps = float(raw)
    except ValueError:
        fps = float("nan")  # rejected by the range check below
    # `0 < fps < inf` is False for NaN, so this single check covers every bad value.
    if not 0 < fps < float("inf"):
        logger.warning("Invalid SNAPSHOT_SAMPLE_FPS=%r, falling back to 0.5", raw)
        return 0.5
    return fps


def _vision_peaks(registry, video_path, sample_fps, duration):
    """Sample frames, score them with the registry's models, and find peaks.

    Returns whatever ``find_peaks_ordered`` returns (the caller unpacks it as
    a storefront/handover pair). Sampled frames are always cleaned up, even
    if scoring or peak finding raises.
    """
    frames = sample_frames(video_path, sample_fps, duration)
    try:
        sf_scores, ho_scores = score_frames(registry, frames)
        return find_peaks_ordered(sf_scores, ho_scores, duration)
    finally:
        cleanup_frames(frames)


def _asr_hits(video_path, keywords):
    """Transcribe the video and return keyword hits, or ``[]`` if ASR fails.

    ASR is an optional signal (it is skipped entirely when disabled or
    unavailable), so a transcription error is logged and treated as "no
    hits" rather than failing the whole analysis.
    """
    try:
        return match_keywords(transcribe(video_path), keywords)
    except Exception:
        logger.exception("ASR failed, continuing with vision results only")
        return []


def _refine_in_place(video_path, point) -> None:
    """Replace ``point.time_sec`` with the refined time; no-op for a falsy point."""
    if point:
        point.time_sec, _ = refine_time_in_window(video_path, point.time_sec)


def run_analyze(req: AnalyzeRequest) -> AnalyzeResponse:
    """Locate the storefront and handover moments in ``req.video_url``.

    The pipeline runs vision (frame sampling + ONNX scoring + peak finding)
    and, when ``req.enable_asr`` is set and ASR is available, keyword-matched
    transcription, concurrently. It fuses both signals with
    ``fuse_results`` and refines each detected time with
    ``refine_time_in_window``.

    Args:
        req: The analysis request. ``req.duration_sec`` (when > 0) overrides
            the duration reported by ``temp_video``. ``req.sample_fps`` (when
            > 0) overrides the ``SNAPSHOT_SAMPLE_FPS`` setting.

    Returns:
        An ``AnalyzeResponse``. ``success`` is True only when both moments
        were found. Failures inside the pipeline are logged and reported via
        ``message`` rather than raised.
    """
    registry = get_registry()
    if not registry.ready:
        return AnalyzeResponse(
            success=False,
            model_version=registry.version,
            message="ONNX 模型未加载,请将 storefront/handover ONNX 放入 models/ 目录",
        )
    if not req.video_url:
        return AnalyzeResponse(success=False, model_version=registry.version, message="video_url 不能为空")

    keywords = req.keywords or KeywordConfig()

    try:
        sample_fps = _resolve_sample_fps(req.sample_fps)
        with temp_video(req.video_url, req.duration_sec or 0.0) as (video_path, duration):
            if req.duration_sec and req.duration_sec > 0:
                duration = req.duration_sec

            # Vision and ASR are independent, so run them side by side.
            with ThreadPoolExecutor(max_workers=2) as pool:
                vision_future = pool.submit(_vision_peaks, registry, video_path, sample_fps, duration)
                asr_future = None
                if req.enable_asr and asr_available():
                    asr_future = pool.submit(_asr_hits, video_path, keywords)
                vision_result = vision_future.result()
                asr_hits = asr_future.result() if asr_future is not None else []

            sf_vision, ho_vision = vision_result if vision_result else (None, None)
            storefront, handover = fuse_results(sf_vision, ho_vision, asr_hits, keywords, duration)

            _refine_in_place(video_path, storefront)
            _refine_in_place(video_path, handover)

            # Report the duration the same way on the success and failure paths.
            duration_sec = round(duration, 2) if duration is not None else None

            if not storefront or not handover:
                return AnalyzeResponse(
                    success=False,
                    model_version=registry.version,
                    duration_sec=duration_sec,
                    storefront=storefront,
                    handover=handover,
                    asr_hits=asr_hits,
                    message="未能检测到门头或交付时刻",
                )

            return AnalyzeResponse(
                success=True,
                model_version=registry.version,
                duration_sec=duration_sec,
                storefront=storefront,
                handover=handover,
                asr_hits=asr_hits,
            )
    except Exception as e:
        logger.exception("分析失败 media_id=%s", req.media_id)
        return AnalyzeResponse(
            success=False,
            model_version=registry.version,
            message=str(e),
        )