#!/usr/bin/env python3
|
# -*- coding: utf-8 -*-
|
"""训练 storefront / handover 二分类模型(MobileNetV3-Small)。"""
|
import argparse
|
import csv
|
import os
|
from pathlib import Path
|
|
import torch
|
import torch.nn as nn
|
import yaml
|
from PIL import Image
|
from torch.utils.data import DataLoader, Dataset
|
from torchvision import models, transforms
|
from tqdm import tqdm
|
|
|
class FrameDataset(Dataset):
    """Map-style dataset yielding ``(transformed_image, label)`` pairs.

    ``rows`` is a sequence of ``(relative_path, label)`` tuples. Each path is
    joined onto ``frames_dir`` and opened with PIL when the item is requested;
    ``transform`` is applied to the RGB image before it is returned.
    """

    def __init__(self, rows, frames_dir, transform):
        """Store the sample list, the frame root directory and the transform."""
        self.rows, self.frames_dir, self.transform = rows, frames_dir, transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.rows)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` with the label as a float32 tensor."""
        rel_path, target = self.rows[idx]
        full_path = os.path.join(self.frames_dir, rel_path)
        rgb_image = Image.open(full_path).convert("RGB")
        return self.transform(rgb_image), torch.tensor(target, dtype=torch.float32)
|
|
|
def load_config(path):
    """Load the YAML config at ``path`` and make its path entries absolute.

    ``data.frames_dir``, ``data.labels_csv`` and ``export.output_dir`` are
    rewritten in place: relative values are resolved against the directory
    that contains the config file, absolute values are left untouched.

    Returns the parsed config mapping.
    """
    with open(path, encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    config_dir = Path(path).resolve().parent
    path_entries = (
        ("data", "frames_dir"),
        ("data", "labels_csv"),
        ("export", "output_dir"),
    )
    for section, key in path_entries:
        value = config[section][key]
        if os.path.isabs(value):
            continue
        config[section][key] = str((config_dir / value).resolve())
    return config
|
|
|
def read_labels(csv_path, task, split=None):
    """Collect ``(frame_path, label)`` samples for ``task`` from the labels CSV.

    A row is kept when its ``task`` column equals ``task`` (with its own
    integer label), or when it is a legacy ``task=other`` row with label 0,
    which is kept as a negative sample. When ``split`` is given, rows whose
    ``split`` column is non-empty and different from it are skipped; rows
    without a ``split`` value are never filtered out.
    """
    samples = []
    with open(csv_path, newline="", encoding="utf-8") as fh:
        for record in csv.DictReader(fh):
            row_split = record.get("split")
            if split and row_split and row_split != split:
                continue
            row_task = record["task"]
            label = int(record["label"])
            is_task_row = row_task == task
            is_legacy_negative = row_task == "other" and label == 0
            if is_task_row or is_legacy_negative:
                # For a legacy negative the label is 0 by definition, so the
                # parsed label can be stored in both cases.
                samples.append((record["frame_path"], label))
    return samples
|
|
|
def count_labels(rows):
    """Return ``(positives, negatives)`` for a list of ``(path, label)`` rows.

    A row is positive when its label equals 1; every other row counts as
    negative.
    """
    positives = len([label for _, label in rows if label == 1])
    return positives, len(rows) - positives
|
|
|
def build_model(freeze_backbone=False):
    """Build a MobileNetV3-Small with a single-logit classification head.

    The network is created with torchvision's ``DEFAULT`` pretrained weights.
    When ``freeze_backbone`` is true, gradients are disabled for every
    parameter under ``features``. The stock classifier is replaced by
    Linear(in, 128) -> Hardswish -> Dropout(0.2) -> Linear(128, 1).
    """
    pretrained = models.MobileNet_V3_Small_Weights.DEFAULT
    net = models.mobilenet_v3_small(weights=pretrained)

    if freeze_backbone:
        net.features.requires_grad_(False)

    # Read the input width from the stock head before it is replaced.
    head_in = net.classifier[0].in_features
    head_layers = [
        nn.Linear(head_in, 128),
        nn.Hardswish(),
        nn.Dropout(0.2),
        nn.Linear(128, 1),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
|
|
|
def build_transforms(size, augment=False):
    """Return the image preprocessing pipeline for ``size`` x ``size`` inputs.

    Every pipeline resizes, converts to a tensor and normalizes per channel.
    With ``augment=True`` two random steps are inserted after the resize:
    colour jitter (applied with probability 0.7) and a 3x3 Gaussian blur
    (applied with probability 0.2).
    """
    steps = [transforms.Resize((size, size))]
    if augment:
        jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2)
        blur = transforms.GaussianBlur(kernel_size=3)
        steps.append(transforms.RandomApply([jitter], p=0.7))
        steps.append(transforms.RandomApply([blur], p=0.2))
    steps.append(transforms.ToTensor())
    # Presumably the ImageNet channel statistics matching the pretrained
    # backbone -- confirm against the weights in use.
    steps.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    return transforms.Compose(steps)
|
|
|
def _load_task_rows(task, labels_csv):
    """Return ``(train_rows, val_rows)`` for ``task``, applying the split fallbacks."""
    train_rows = read_labels(labels_csv, task, "train")
    val_rows = read_labels(labels_csv, task, "val")
    if not train_rows:
        # Nothing tagged "train": fall back to every row of the task.
        train_rows = read_labels(labels_csv, task)
    if not train_rows:
        # Fail early with a clear message instead of an obscure sampler error
        # raised later by the shuffling DataLoader.
        raise ValueError(f"[{task}] 在 {labels_csv} 中没有找到任何训练样本")
    if not val_rows:
        # No validation rows: reuse the first ~10% of the training rows.
        val_rows = train_rows[: max(1, len(train_rows) // 10)]
    return train_rows, val_rows


def _make_loader(rows, frames_dir, size, train_cfg, training):
    """Build a DataLoader over ``rows``; augmentation and shuffling only when training."""
    dataset = FrameDataset(rows, frames_dir, build_transforms(size, augment=training))
    return DataLoader(
        dataset,
        batch_size=train_cfg["batch_size"],
        shuffle=training,
        num_workers=train_cfg["num_workers"],
    )


def _build_optimizer(model, lr, freeze):
    """Return Adam over the trainable parameters, with a 10x lower backbone LR when unfrozen."""
    if freeze:
        trainable = filter(lambda p: p.requires_grad, model.parameters())
        return torch.optim.Adam(trainable, lr=lr)
    return torch.optim.Adam([
        {"params": model.features.parameters(), "lr": lr * 0.1},
        {"params": model.classifier.parameters(), "lr": lr},
    ])


def _validation_loss(model, loader, criterion, device):
    """Return the sample-weighted mean loss of ``model`` over ``loader``."""
    model.eval()
    total, seen = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device).unsqueeze(1)
            total += criterion(model(x), y).item() * x.size(0)
            seen += x.size(0)
    return total / max(seen, 1)


def train_task(task, cfg):
    """Train the binary classifier for ``task`` and checkpoint the best epoch.

    Reads ``cfg["data"]`` (``frames_dir``, ``labels_csv``), ``cfg["model"]``
    (``image_size``, optional ``pos_weight``), ``cfg["train"]``
    (``batch_size``, ``num_workers``, ``lr``, ``epochs``,
    ``early_stop_patience``, optional ``freeze_backbone``) and
    ``cfg["export"]["output_dir"]``.

    The model state dict is saved to ``<output_dir>/<task>.pt`` each time the
    validation loss improves; training stops early after
    ``early_stop_patience`` consecutive epochs without improvement.

    Returns:
        The checkpoint path (the file exists only if some epoch improved).

    Raises:
        ValueError: if the labels CSV yields no training rows for ``task``.
    """
    frames_dir = cfg["data"]["frames_dir"]
    train_rows, val_rows = _load_task_rows(task, cfg["data"]["labels_csv"])

    train_pos, train_neg = count_labels(train_rows)
    val_pos, val_neg = count_labels(val_rows)
    print(f"[{task}] train pos={train_pos} neg={train_neg} | val pos={val_pos} neg={val_neg}")
    if train_neg == 0:
        print(f"WARN: [{task}] 训练集无负样本!请重新运行 prepare_dataset.py 生成 labels.csv")
    if train_pos < 10:
        print(f"WARN: [{task}] 正样本过少(<10),建议标注至少 80+ 条视频")

    # Early stopping relies on val_loss, so make any train/val overlap visible
    # (it happens with the 10% fallback and when rows carry no split value).
    overlap = len(set(train_rows) & set(val_rows))
    if overlap:
        print(f"WARN: [{task}] 验证集与训练集有 {overlap} 条样本重叠,val_loss 会偏乐观")

    size = cfg["model"]["image_size"]
    train_loader = _make_loader(train_rows, frames_dir, size, cfg["train"], training=True)
    val_loader = _make_loader(val_rows, frames_dir, size, cfg["train"], training=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Freeze the backbone only when it is enabled (default) and positives are scarce.
    freeze = bool(cfg["train"].get("freeze_backbone", True)) and train_pos < 200
    model = build_model(freeze_backbone=freeze).to(device)
    if freeze:
        print(f"[{task}] 小样本模式:冻结 backbone,仅训练分类头")

    # Class-imbalance weight: the neg/pos ratio capped at 10, unless the config
    # provides a non-zero override.
    auto_pos_weight = min(max(train_neg, 1) / max(train_pos, 1), 10.0)
    pos_weight_val = float(cfg["model"].get("pos_weight") or auto_pos_weight)
    # Explicit float dtype: an integer config value would otherwise create an
    # integer tensor.
    pos_weight = torch.tensor([pos_weight_val], dtype=torch.float32, device=device)
    print(f"[{task}] pos_weight={pos_weight_val:.2f}")

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = _build_optimizer(model, cfg["train"]["lr"], freeze)

    best_loss = float("inf")
    stale_epochs = 0
    saved = False
    os.makedirs(cfg["export"]["output_dir"], exist_ok=True)
    ckpt_path = os.path.join(cfg["export"]["output_dir"], f"{task}.pt")

    for epoch in range(cfg["train"]["epochs"]):
        # NOTE(review): train() is applied to the whole model even when the
        # backbone is frozen; if the backbone contains normalization layers
        # their running statistics may still change -- confirm this is intended.
        model.train()
        for x, y in tqdm(train_loader, desc=f"{task} epoch {epoch+1}"):
            x, y = x.to(device), y.to(device).unsqueeze(1)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        val_loss = _validation_loss(model, val_loader, criterion, device)
        print(f"{task} epoch {epoch+1} val_loss={val_loss:.4f}")
        if val_loss < best_loss:
            best_loss = val_loss
            stale_epochs = 0
            torch.save(model.state_dict(), ckpt_path)
            saved = True
        else:
            stale_epochs += 1
            if stale_epochs >= cfg["train"]["early_stop_patience"]:
                break

    if saved:
        print(f"保存 {ckpt_path}")
    else:
        # Zero epochs, or val_loss never dropped below inf (e.g. NaN): nothing
        # was written during this run.
        print(f"WARN: [{task}] 本次训练未保存任何 checkpoint(epochs 为 0 或 val_loss 无效)")
    return ckpt_path
|
|
|
def main():
    """Parse the command line, load the config, and train each requested task.

    ``--config`` defaults to ``config.yaml`` next to this script; ``--task``
    selects ``storefront``, ``handover`` or ``both`` (the default).
    """
    default_config = Path(__file__).parent / "config.yaml"
    cli = argparse.ArgumentParser()
    cli.add_argument("-c", "--config", default=str(default_config))
    cli.add_argument("--task", choices=["storefront", "handover", "both"], default="both")
    opts = cli.parse_args()

    cfg = load_config(opts.config)
    if opts.task == "both":
        selected = ["storefront", "handover"]
    else:
        selected = [opts.task]
    for name in selected:
        train_task(name, cfg)


if __name__ == "__main__":
    main()
|