如何用ONNX Runtime运行PyTorch导出的模型并解决类型不兼容问题?(导出.如何用.不兼容.模型.运行...)
利用ONNX Runtime高效运行PyTorch模型
本文将指导您如何使用ONNX Runtime运行经torch.onnx.export导出的PyTorch模型,并重点解决PyTorch张量与ONNX Runtime所需NumPy数组类型不兼容的问题。
首先,我们来看一个PyTorch模型导出示例:
PHP
import torch
class SumModule(torch.nn.Module):
def forward(self, x):
return torch.sum(x, dim=1)
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"]
)
这段代码定义了一个简单的PyTorch模型SumModule,并将其导出为名为onnx.pb的ONNX模型文件。
直接使用PyTorch张量作为ONNX Runtime的输入会导致错误,因为ONNX Runtime期望的是NumPy数组。 错误信息通常提示输入类型错误。
为了解决这个问题,我们需要将PyTorch张量转换为NumPy数组。 正确的代码如下:
PHP
import onnxruntime
import numpy as np
import torch
ort_session = onnxruntime.InferenceSession("onnx.pb")
# 关键修改:将torch.Tensor转换为np.ndarray
x = np.ones((2, 2), dtype=np.float32)
inputs = {ort_session.get_inputs()[0].name: x}
print(ort_session.run(None, inputs))
这段代码加载onnx.pb文件,创建一个形状为(2, 2),数据类型为float32的NumPy数组作为模型输入。 ort_session.get_inputs()[0].name 获取输入张量的名称,确保输入数据与模型定义匹配。 ort_session.run 函数运行模型并打印输出结果。
更简洁的等效代码:
PHP
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession("onnx.pb")
input_data = np.ones((2, 2)).astype(np.float32)
output_data = sess.run(None, {"x": input_data})[0]
print(output_data)
这段代码功能相同,但更简洁易读。 关键在于使用NumPy数组作为输入。
通过以上方法,您可以成功加载并运行使用torch.onnx.export导出的PyTorch模型。 请确保输入数据的类型和形状与模型的预期输入相匹配。
以上就是如何用ONNX Runtime运行PyTorch导出的模型并解决类型不兼容问题?的详细内容,更多请关注知识资源分享宝库其它相关文章!