# -*- coding: utf-8 -*-
|
from typing import List, Optional, Tuple
|
|
from app.schemas import AsrHit, KeywordConfig, SnapshotHit
|
|
|
def fuse_hit(
    vision: Optional[Tuple[float, float]],
    asr_time: Optional[float],
    asr_weight: float = 0.7,
    window: float = 2.0,
) -> Optional[SnapshotHit]:
    """Fuse a vision detection and an ASR keyword time into one snapshot hit.

    Args:
        vision: ``(time_sec, confidence)`` from the vision side, or ``None``.
        asr_time: Time in seconds of the ASR keyword match, or ``None``.
        asr_weight: Weight of ``asr_time`` when blending the two timestamps;
            the vision time gets ``1 - asr_weight``.
        window: Base agreement window in seconds. The two sources count as
            agreeing when their times differ by at most ``3 * window``.

    Returns:
        A ``SnapshotHit`` whose ``source`` is ``"hybrid"`` when both inputs
        are present, ``"asr"`` when only ASR is present, or ``"ai"`` when only
        vision is present. Returns ``None`` when neither input is present.
        Times are rounded to 2 decimals and confidences to 4.
    """
    if vision and asr_time is not None:
        vt, vc = vision
        # Both the agreeing and disagreeing cases use the same weighted blend.
        # NOTE(review): when the sources disagree (outside the tolerance), the
        # blend can land between two unrelated moments -- confirm intended.
        t = round(asr_weight * asr_time + (1 - asr_weight) * vt, 2)
        if abs(vt - asr_time) <= window * 3:
            # Agreement corroborates the vision hit: boost by 0.1, capped at
            # 0.99, but never below the vision confidence itself (a bare cap
            # would turn e.g. 1.0 into 0.99, i.e. agreement would hurt).
            conf = max(vc, min(0.99, vc + 0.1))
        else:
            conf = vc
        return SnapshotHit(time_sec=t, confidence=round(conf, 4), source="hybrid")

    if asr_time is not None:
        # ASR-only hits carry a fixed confidence.
        return SnapshotHit(time_sec=round(asr_time, 2), confidence=0.75, source="asr")

    if vision:
        return SnapshotHit(time_sec=round(vision[0], 2), confidence=round(vision[1], 4), source="ai")

    return None
|
|
|
def fuse_results(
    sf_vision: Optional[Tuple[float, float]],
    ho_vision: Optional[Tuple[float, float]],
    asr_hits: List[AsrHit],
    keywords: KeywordConfig,
    duration: float,
) -> Tuple[Optional[SnapshotHit], Optional[SnapshotHit]]:
    """Fuse vision and ASR evidence into storefront and handover hits.

    The best ASR time for each keyword set (``keywords.storefront`` and
    ``keywords.handover``) is looked up with ``best_asr_time``. Each ASR time
    is then merged with the matching vision candidate via ``fuse_hit``.
    Handover candidates earlier than 30 s after the fused storefront time are
    discarded before the handover is fused.

    Args:
        sf_vision: Storefront vision ``(time_sec, confidence)`` or ``None``.
        ho_vision: Handover vision ``(time_sec, confidence)`` or ``None``.
        asr_hits: ASR hits searched for the keyword sets.
        keywords: Provides the ``storefront`` and ``handover`` keyword sets.
        duration: Media duration, in the same units as ``time_sec``; only
            used to bound the fallback handover time.

    Returns:
        ``(storefront, handover)``; either element may be ``None``. If the
        fused handover is not after the storefront, its ``time_sec`` is
        overwritten in place with a fallback time.
    """
    # Function-level import kept from the original (presumably avoids an
    # import cycle with app.asr -- verify before hoisting).
    from app.asr import best_asr_time

    storefront_asr = best_asr_time(asr_hits, keywords.storefront)
    handover_asr = best_asr_time(asr_hits, keywords.handover)

    storefront = fuse_hit(sf_vision, storefront_asr)

    # A handover may only be accepted 30 s or more after the storefront;
    # with no storefront hit there is no lower bound.
    earliest_handover = 0.0 if not storefront else storefront.time_sec + 30.0

    handover_vision = (
        None if (ho_vision and ho_vision[0] < earliest_handover) else ho_vision
    )
    handover_asr = (
        None
        if (handover_asr is not None and handover_asr < earliest_handover)
        else handover_asr
    )

    handover = fuse_hit(handover_vision, handover_asr)

    if not (storefront and handover):
        return storefront, handover

    # Defensive ordering check: push the handover 60 s past the storefront,
    # bounded by one second before the end of the media.
    # NOTE(review): if the storefront lies within the last second, this
    # fallback can still be at or before the storefront time.
    if handover.time_sec <= storefront.time_sec:
        fallback = min(duration - 1, storefront.time_sec + 60)
        handover.time_sec = round(fallback, 2)

    return storefront, handover
|