PyTorch 模型转换示例
概述
本文档以 ResNet50 为例,演示如何将 PyTorch 模型转换为 ONNX 格式,再使用 Netrans 进行完整转换。由于 PyTorch 的动态图特性,需要先将 PyTorch 模型导出为 ONNX 格式后再转换。
快速开始
1. 环境准备
确保已安装 Netrans,如未安装请参考 安装指南。
2. 进入示例目录
1cd /home/xj/work/nudt/netrans/examples/pytorch
2mamba activate netrans # 激活 netrans 环境
3. 数据准备
本示例包含以下文件:
resnet50/
├── dataset.txt # 数据集路径配置
├── dog.jpg # 测试图像
├── export_resnet50_2_onnx.py # PyTorch转ONNX脚本
└── resnet50.onnx # 转换后的ONNX模型(生成)
注意:resnet50.onnx 需要通过导出脚本生成,详见下一步。
4. PyTorch转ONNX
1# 执行导出脚本,将PyTorch模型转为ONNX格式
2python export_resnet50_2_onnx.py
5. 一体化转换(推荐)
1# CLI 方式 - 一体化流程(推荐)
2netrans load resnet50 --mean 128 128 128 --scale 1 1 1
3netrans quantize resnet50 asymu8 --pre --post
4netrans export resnet50 asymu8
分步详细流程
1# 1. PyTorch转ONNX
2python export_resnet50_2_onnx.py
3
4# 2. 模型导入
5netrans load resnet50 --mean 128 128 128 --scale 1 1 1
6
7# 3. 模型量化
8netrans quantize resnet50 asymu8
9
10# 4. 前后处理集成
11netrans add_pre_post resnet50 asymu8 --preprocess --postprocess
12
13# 5. 模型导出
14netrans export resnet50 asymu8
Python API 示例
完整流程
from netrans import Netrans
# 初始化
model = Netrans()
# 加载模型(ImageNet标准化参数)
model.load('resnet50', mean=[128,128,128], scale=[1,1,1])
# 量化并集成前后处理
model.quantize('asymu8', pre=True, post=True)
# 导出NBG格式
model.export('asymu8')
分步API使用
from netrans import Netrans
model = Netrans()
# 分步执行
model.load('resnet50', mean=[128,128,128], scale=[1,1,1])
model.quantize('asymu8')
model.add_pre_post('asymu8', pre=True, post=True)
model.export('asymu8')
输出说明
环境准备完成后,执行转换流程会生成:
resnet50/
├── dataset.txt
├── dog.jpg
├── export_resnet50_2_onnx.py
├── resnet50.onnx
├── resnet50.data # 模型权重数据(生成)
├── resnet50.json # 模型结构描述(生成)
├── resnet50_asymu8.quantize # 量化配置(生成)
├── resnet50_inputmeta.yml # 输入元数据(生成)
├── resnet50_postprocess_file.yml # 后处理配置(生成)
└── wksp/ # 工作空间(生成)
├── resnet50_asymu8/ # 量化模型工程
└── resnet50_asymu8_nbg_unify/ # 最终NBG输出
├── network_binary.nb # NBG模型文件(核心输出)
├── nbg_meta.json # NBG元数据
├── main.c # 测试程序
└── ... # 其他部署文件
核心输出文件:
network_binary.nb- 最终NBG模型文件,可直接部署到PNNA芯片nbg_meta.json- 模型元数据,包含输入输出信息
相关文档
版本: 6.42.4+
更新日期: 2025-12-17
测试模型: ResNet50 (PyTorch→ONNX)