doum
2 天以前 ce44d803b73a65b2cc31db5bcc662139029463d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""导出 PyTorch 权重为 ONNX,并可选 INT8 动态量化。"""
import argparse
import json
import os
import shutil
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path

import torch
import yaml

from train import build_model
 
 
def load_config(path):
    """Load the YAML config at *path*, anchoring a relative output dir to it.

    A relative ``export.output_dir`` is interpreted relative to the directory
    containing the config file (not the current working directory) and is
    rewritten in place as an absolute path string. An absolute value is kept
    as-is.

    Args:
        path: Filesystem path of the YAML config file.

    Returns:
        The parsed configuration mapping.
    """
    with open(path, encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    export_section = config["export"]
    target = export_section["output_dir"]
    if os.path.isabs(target):
        return config

    config_dir = Path(path).resolve().parent
    export_section["output_dir"] = str((config_dir / target).resolve())
    return config
 
 
def _discard_file(path: str) -> None:
    """Best-effort delete of *path*; a missing or locked file is ignored."""
    try:
        os.remove(path)
    except OSError:
        # Missing is the common case; a locked file (e.g. on Windows) must not
        # turn a quantization fallback into a crash.
        pass


def try_quantize_subprocess(onnx_path: str, int8_path: str) -> bool:
    """Dynamically quantize an ONNX model to INT8 in a child Python process.

    ``onnxruntime.quantization.quantize_dynamic`` (QUInt8 weights) runs in a
    separate interpreter, so a native crash of the ORT quantizer (seen on
    Windows) only kills the child, not the exporter.

    Args:
        onnx_path: Path of the float ONNX model to quantize.
        int8_path: Destination path for the quantized model. Any existing file
            there is removed first, so a stale model from an earlier run can't
            be mistaken for this run's output.

    Returns:
        True if the child exited with code 0 and left a non-empty file at
        ``int8_path``. False on timeout (600 s) or non-zero exit; in both
        cases a message is printed and any partial output is removed. False
        also if the output file is missing or empty.
    """
    code = (
        "import sys\n"
        "from onnxruntime.quantization import QuantType, quantize_dynamic\n"
        "quantize_dynamic(sys.argv[1], sys.argv[2], weight_type=QuantType.QUInt8)\n"
        "print('OK')\n"
    )
    # Start from a clean slate so the success check below reflects this attempt only.
    _discard_file(int8_path)
    try:
        result = subprocess.run(
            [sys.executable, "-c", code, onnx_path, int8_path],
            capture_output=True,
            text=True,
            # Crash output from native code may not be valid in the locale
            # encoding; never let decoding raise in the parent process.
            errors="replace",
            timeout=600,
        )
    except subprocess.TimeoutExpired:
        print(f"量化超时: {int8_path}")
        _discard_file(int8_path)
        return False
    if result.returncode != 0:
        err = (result.stderr or result.stdout or "").strip()
        if err:
            # A Python traceback ends with the actual exception; keep the tail.
            print(f"量化失败 exit={result.returncode}: {err[-500:]}")
        else:
            print(f"量化失败 exit={result.returncode}(Windows 上 ORT 量化器可能崩溃)")
        _discard_file(int8_path)
        return False
    return os.path.isfile(int8_path) and os.path.getsize(int8_path) > 0
 
 
def export_one(task, cfg, do_quantize: bool):
    """Export the ``<task>.pt`` checkpoint to ONNX, optionally quantizing it.

    Loads ``<output_dir>/<task>.pt`` into a freshly built model and writes
    ``<output_dir>/<task>.onnx``. The batch axis of ``input`` and ``logits``
    is dynamic. When *do_quantize* is set and the config's
    ``export.quantize`` is truthy, it also tries to write
    ``<output_dir>/<task>_int8.onnx``.

    Args:
        task: Task name used as the checkpoint/ONNX file stem.
        cfg: Parsed config. Reads ``export.output_dir``, ``export.opset``,
            ``export.quantize`` and ``model.image_size``.
        do_quantize: False disables quantization regardless of the config.

    Returns:
        The base file name of the model the inference service should load:
        the INT8 model if quantization succeeded, otherwise the float one.

    Raises:
        FileNotFoundError: If the checkpoint does not exist.
    """
    export_cfg = cfg["export"]
    out_dir = export_cfg["output_dir"]
    os.makedirs(out_dir, exist_ok=True)
    side = cfg["model"]["image_size"]

    weights_path = os.path.join(out_dir, f"{task}.pt")
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"未找到权重 {weights_path},请先运行 train.py")

    net = build_model()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    net.load_state_dict(torch.load(weights_path, map_location="cpu"))
    net.eval()

    sample = torch.randn(1, 3, side, side)
    float_path = os.path.join(out_dir, f"{task}.onnx")
    torch.onnx.export(
        net,
        sample,
        float_path,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=export_cfg["opset"],
    )
    print(f"导出 float ONNX: {float_path}")

    int8_path = os.path.join(out_dir, f"{task}_int8.onnx")
    float_name = os.path.basename(float_path)

    # Quantization disabled from the command line: silently keep float.
    if not do_quantize:
        return float_name
    # Quantization disabled in the config: say so, keep float.
    if not export_cfg.get("quantize"):
        print("跳过量化(config quantize=false)")
        return float_name

    if try_quantize_subprocess(float_path, int8_path):
        print(f"量化 INT8: {int8_path}")
        return os.path.basename(int8_path)

    print(
        f"WARN: {task}_int8.onnx 量化未成功,推理服务将使用 float 模型 {task}.onnx\n"
        "      常见原因: Windows 上 onnxruntime.quantization 崩溃。"
        " 可升级/降级 onnxruntime,或使用 --no-quantize 跳过量化。"
    )
    return float_name
 
 
def main():
    """Command-line entry point: export both task models and write version.json.

    Parses the CLI, loads the config, exports the ``storefront`` and
    ``handover`` models via ``export_one`` (INT8 is attempted unless
    ``--no-quantize`` is given), and records the chosen file names plus
    export metadata in ``<output_dir>/version.json``.
    """
    parser = argparse.ArgumentParser(description="导出 storefront/handover ONNX 模型")
    parser.add_argument("-c", "--config", default=str(Path(__file__).parent / "config.yaml"))
    parser.add_argument("--version", default="1.0.0")
    parser.add_argument("--no-quantize", action="store_true", help="仅导出 float .onnx,不尝试 INT8")
    args = parser.parse_args()
    cfg = load_config(args.config)
    do_quantize = not args.no_quantize

    # Each value is the base file name the inference service should load.
    paths = {}
    for task in ("storefront", "handover"):
        paths[task] = export_one(task, cfg, do_quantize)

    version_info = {
        "model_version": args.version,
        "storefront_model": paths["storefront"],
        "handover_model": paths["handover"],
        "image_size": cfg["model"]["image_size"],
        # Quantization can succeed for one task and fall back to float for
        # another, so only report True when every exported model is INT8.
        "quantized": all(name.endswith("_int8.onnx") for name in paths.values()),
        # Timezone-aware UTC (datetime.utcnow() is deprecated since 3.12),
        # serialized in the same naive-ISO + "Z" format as before.
        "exported_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
    }
    version_path = os.path.join(cfg["export"]["output_dir"], "version.json")
    with open(version_path, "w", encoding="utf-8") as f:
        json.dump(version_info, f, indent=2, ensure_ascii=False)
    print(f"写入 {version_path}")
    print(f"  storefront -> {paths['storefront']}")
    print(f"  handover   -> {paths['handover']}")
 
 
if __name__ == "__main__":
    # Run the exporter only when executed as a script, not when imported.
    main()