doum
2 天以前 ce44d803b73a65b2cc31db5bcc662139029463d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -*- coding: utf-8 -*-
from typing import List, Optional, Tuple
 
from app.schemas import AsrHit, KeywordConfig, SnapshotHit
 
 
def fuse_hit(
    vision: Optional[Tuple[float, float]],
    asr_time: Optional[float],
    asr_weight: float = 0.7,
    window: float = 2.0,
) -> Optional[SnapshotHit]:
    if vision and asr_time is not None:
        vt, vc = vision
        if abs(vt - asr_time) <= window * 3:
            t = round(asr_weight * asr_time + (1 - asr_weight) * vt, 2)
            conf = min(0.99, vc + 0.1)
            return SnapshotHit(time_sec=t, confidence=round(conf, 4), source="hybrid")
        t = round(asr_weight * asr_time + (1 - asr_weight) * vt, 2)
        return SnapshotHit(time_sec=t, confidence=round(vc, 4), source="hybrid")
    if asr_time is not None:
        return SnapshotHit(time_sec=round(asr_time, 2), confidence=0.75, source="asr")
    if vision:
        return SnapshotHit(time_sec=round(vision[0], 2), confidence=round(vision[1], 4), source="ai")
    return None
 
 
def fuse_results(
    sf_vision: Optional[Tuple[float, float]],
    ho_vision: Optional[Tuple[float, float]],
    asr_hits: List[AsrHit],
    keywords: KeywordConfig,
    duration: float,
) -> Tuple[Optional[SnapshotHit], Optional[SnapshotHit]]:
    from app.asr import best_asr_time
 
    sf_asr = best_asr_time(asr_hits, keywords.storefront)
    ho_asr = best_asr_time(asr_hits, keywords.handover)
 
    storefront = fuse_hit(sf_vision, sf_asr)
    min_ho = (storefront.time_sec + 30.0) if storefront else 0.0
    if ho_vision and ho_vision[0] < min_ho:
        ho_vision = None
    if ho_asr is not None and ho_asr < min_ho:
        ho_asr = None
 
    handover = fuse_hit(ho_vision, ho_asr)
 
    if storefront and handover and handover.time_sec <= storefront.time_sec:
        handover.time_sec = round(min(duration - 1, storefront.time_sec + 60), 2)
 
    return storefront, handover