#!/usr/bin/env python3
|
# -*- coding: utf-8 -*-
|
"""导出 PyTorch 权重为 ONNX,并可选 INT8 动态量化。"""
|
import argparse
|
import json
|
import os
|
import shutil
|
import subprocess
|
import sys
|
from datetime import datetime
|
from pathlib import Path
|
|
import torch
|
import yaml
|
from train import build_model
|
|
|
def load_config(path):
    """Load the YAML export configuration and absolutize its output directory.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        The parsed configuration. ``cfg["export"]["output_dir"]`` is rewritten
        in place to an absolute path when the file holds a relative one;
        relative values are resolved against the directory that contains the
        configuration file, not against the current working directory.

    Raises:
        KeyError: If the ``export`` section or its ``output_dir`` entry is
            missing (assuming the YAML document is a mapping).
    """
    with open(path, encoding="utf-8") as stream:
        config = yaml.safe_load(stream)

    export_section = config["export"]
    output_dir = export_section["output_dir"]
    if os.path.isabs(output_dir):
        # Already absolute: leave the configured value untouched.
        return config

    config_dir = Path(path).resolve().parent
    export_section["output_dir"] = str((config_dir / output_dir).resolve())
    return config
|
|
|
def _discard_quantized_output(int8_path: str, onnx_path: str) -> None:
    """Best-effort removal of a stale or partially written INT8 model file."""

    def _canonical(p: str) -> str:
        return os.path.normcase(os.path.abspath(p))

    if _canonical(int8_path) == _canonical(onnx_path):
        # Output aliases the input: never delete the float model.
        return
    try:
        os.remove(int8_path)
    except FileNotFoundError:
        pass  # Nothing to clean up.
    except OSError as exc:
        # Cleanup is best-effort; report it but do not abort the export.
        print(f"WARN: 无法删除 {int8_path}: {exc}")


def try_quantize_subprocess(onnx_path: str, int8_path: str) -> bool:
    """Run ONNX Runtime dynamic INT8 quantization in a child interpreter.

    The quantizer runs in a separate process so that a hard crash inside
    ONNX Runtime (noted by this script for Windows) cannot take down the
    exporting process.

    Args:
        onnx_path: Path of the float ONNX model to quantize.
        int8_path: Path where the quantized model is written. Any file
            already at this path is removed before the attempt, and a
            partial file is removed again if the attempt fails, so a file
            found here afterwards always belongs to this invocation.

    Returns:
        True only when the child exits with status 0 and a non-empty file
        exists at ``int8_path``; False on timeout, non-zero exit, or missing
        or empty output.
    """
    code = (
        "import sys\n"
        "from onnxruntime.quantization import QuantType, quantize_dynamic\n"
        "quantize_dynamic(sys.argv[1], sys.argv[2], weight_type=QuantType.QUInt8)\n"
        "print('OK')\n"
    )

    # Drop the output of any earlier run so the final existence check can
    # only be satisfied by a file produced during this call.
    _discard_quantized_output(int8_path, onnx_path)

    try:
        result = subprocess.run(
            [sys.executable, "-c", code, onnx_path, int8_path],
            capture_output=True,
            text=True,
            # A crashing native library may emit bytes that are not valid in
            # the locale encoding; never let decoding raise in the parent.
            errors="replace",
            timeout=600,
        )
    except subprocess.TimeoutExpired:
        print(f"量化超时: {int8_path}")
        # The killed child may have left a truncated model behind.
        _discard_quantized_output(int8_path, onnx_path)
        return False

    if result.returncode != 0:
        err = (result.stderr or result.stdout or "").strip()
        if err:
            # Keep the tail: a Python traceback ends with the actual error.
            print(f"量化失败 exit={result.returncode}: {err[-500:]}")
        else:
            print(f"量化失败 exit={result.returncode}(Windows 上 ORT 量化器可能崩溃)")
        _discard_quantized_output(int8_path, onnx_path)
        return False

    return os.path.isfile(int8_path) and os.path.getsize(int8_path) > 0
|
|
|
def export_one(task, cfg, do_quantize: bool):
    """Export one task's checkpoint to ONNX and optionally quantize it.

    Args:
        task: Task name; used as the file stem for ``<task>.pt``,
            ``<task>.onnx`` and ``<task>_int8.onnx`` in the output directory.
        cfg: Configuration mapping. Reads ``cfg["export"]["output_dir"]``,
            ``cfg["export"]["opset"]``, the optional
            ``cfg["export"]["quantize"]`` and ``cfg["model"]["image_size"]``.
        do_quantize: When False, quantization is skipped regardless of the
            configuration.

    Returns:
        The base name of the model file to use: the INT8 file when
        quantization was requested, enabled in the configuration and
        succeeded; otherwise the float ONNX file.

    Raises:
        FileNotFoundError: If ``<task>.pt`` is not present in the output
            directory.
    """
    export_cfg = cfg["export"]
    target_dir = export_cfg["output_dir"]
    os.makedirs(target_dir, exist_ok=True)
    image_size = cfg["model"]["image_size"]

    checkpoint = os.path.join(target_dir, f"{task}.pt")
    if not os.path.isfile(checkpoint):
        raise FileNotFoundError(f"未找到权重 {checkpoint},请先运行 train.py")

    net = build_model()
    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    weights = torch.load(checkpoint, map_location="cpu")
    net.load_state_dict(weights)
    net.eval()

    # A single random image traces the graph; the batch axis stays dynamic.
    sample = torch.randn(1, 3, image_size, image_size)
    float_path = os.path.join(target_dir, f"{task}.onnx")
    torch.onnx.export(
        net,
        sample,
        float_path,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=export_cfg["opset"],
    )
    print(f"导出 float ONNX: {float_path}")
    float_name = os.path.basename(float_path)

    if not do_quantize:
        return float_name
    if not export_cfg.get("quantize"):
        print("跳过量化(config quantize=false)")
        return float_name

    quantized_path = os.path.join(target_dir, f"{task}_int8.onnx")
    if try_quantize_subprocess(float_path, quantized_path):
        print(f"量化 INT8: {quantized_path}")
        return os.path.basename(quantized_path)

    # Quantization failed: fall back to the float model and tell the user.
    print(
        f"WARN: {task}_int8.onnx 量化未成功,推理服务将使用 float 模型 {task}.onnx\n"
        " 常见原因: Windows 上 onnxruntime.quantization 崩溃。"
        " 可升级/降级 onnxruntime,或使用 --no-quantize 跳过量化。"
    )
    return float_name
|
|
|
def main():
    """Command-line entry point: export both models and write ``version.json``.

    Options:
        -c / --config: Path of the YAML configuration (defaults to
            ``config.yaml`` next to this script).
        --version: Model version string recorded in ``version.json``
            (default ``1.0.0``).
        --no-quantize: Export float ONNX only and skip the INT8 attempt.

    Exports the ``storefront`` and ``handover`` tasks in that order, then
    writes ``version.json`` into the configured output directory with the
    chosen model file names, the image size, whether the exported models are
    quantized, and a UTC export timestamp.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    parser = argparse.ArgumentParser(description="导出 storefront/handover ONNX 模型")
    parser.add_argument("-c", "--config", default=str(Path(__file__).parent / "config.yaml"))
    parser.add_argument("--version", default="1.0.0")
    parser.add_argument("--no-quantize", action="store_true", help="仅导出 float .onnx,不尝试 INT8")
    args = parser.parse_args()

    cfg = load_config(args.config)
    do_quantize = not args.no_quantize

    paths = {
        task: export_one(task, cfg, do_quantize)
        for task in ("storefront", "handover")
    }

    # Report "quantized" only when every exported model is INT8; checking a
    # single task would misreport a run where one task fell back to float.
    all_int8 = all(name.endswith("_int8.onnx") for name in paths.values())

    # Timezone-aware replacement for the deprecated datetime.utcnow(); the
    # tzinfo is dropped again so the rendered text keeps the "...Z" format.
    exported_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    version_info = {
        "model_version": args.version,
        "storefront_model": paths["storefront"],
        "handover_model": paths["handover"],
        "image_size": cfg["model"]["image_size"],
        "quantized": do_quantize and all_int8,
        "exported_at": exported_at,
    }
    version_path = os.path.join(cfg["export"]["output_dir"], "version.json")
    with open(version_path, "w", encoding="utf-8") as f:
        json.dump(version_info, f, indent=2, ensure_ascii=False)

    print(f"写入 {version_path}")
    print(f" storefront -> {paths['storefront']}")
    print(f" handover -> {paths['handover']}")
|
|
|
# Run the export only when this file is executed as a script, so importing
# the module (for example to reuse load_config) has no side effects.
if __name__ == "__main__":
    main()
|