机器学习-模型转换(1)


格式汇总

重现模型的关键是模型结构/模型参数/数据集, 当提供或者希望被提供这些信息, 需要一个交流的规范, 而 .pt/.pth/.bin/.onnx 就是约定的格式.

使用 torch.load() 方法加载模型信息时, 不根据文件后缀进行模型读取, 而是根据模型文件的内容自动识别.

格式 解释 适用场景 对应后缀
.pt/.pth PyTorch 的默认模型文件格式, 保存和加载完整的 PyTorch 模型, 包含模型的结构和参数等信息。 需要保存和加载完整的 PyTorch 模型, 例如在训练中保存最佳的模型, 或在部署中加载训练好的模型. .pt/.pth
.bin 通用二进制格式, 保存和加载各种类型的模型和数据. 需要将 PyTorch 模型转换为通用的二进制格式. .bin
ONNX 通用模型交换格式, 将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台. PyTorch 中使用 torch.onnx.export 将模型转换为 ONNX 格式. 需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式. .onnx

.pt/.pth格式

完整的 Pytorch 模型文件, 包含如下参数:

  • model_state_dict: 模型参数
  • optimizer_state_dict: 优化器状态
  • epoch: 训练轮数
  • loss: 损失值

其中 model.state_dict() 返回包含所有参数和持久化缓存的字典, torch.save() 将所有组件保存到文件中,

模型保存

import torch
import torch.nn as nn

class Net(nn.Module):
    ...

model = Net()
torch.save({
    'epoch': 10,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss
}, PATH)

模型加载

import torch
import torch.nn as nn

class Net(nn.Module):
    ...

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()

.bin 格式

二进制文件, 保存模型的参数和持久化缓存. 文件大小较小, 加载速度较快, 在生产环境中使用较多.

保存模型

import torch
import torch.nn as nn

class Net(nn.Module):
	...

model = Net()
torch.save(model.state_dict(), PATH)

加载模型

import torch
import torch.nn as nn

class Net(nn.Module):
    ...

model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

.onnx 格式

PyTorch 提供 torch.onnx.export 函数将模型转化为 ONNX 格式, 在其他深度学习框架中使用.

保存模型

import torch
import torch.onnx

model = torch.nn.Linear(3, 1)
torch.save(model.state_dict(), "model.bin")

model = torch.nn.Linear(3, 1)
model.load_state_dict(torch.load("model.bin"))

example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"])

加载模型

import onnx
import onnxruntime

# 将 ONNX 文件转化为 ORT 格式
onnx_model = onnx.load("model.onnx")
ort_session = onnxruntime.InferenceSession("model.onnx")
input_data = np.random.random(size=(1, 3)).astype(np.float32)
outputs = ort_session.run(None, {"input": input_data})

保存完整模型

之前的方式都保存了 .state_dict(), 但没有保存模型的结构, 使用时必须重新定义相同结构的模型, 才能够加载模型参数进行使用. 想把整个模型都保存下来, 可以按如下操作:

保存模型

PATH = "entire_model.pt"
torch.save(model, PATH)

加载模型

model = torch.load("entire_model.pt")
model.eval()

文章作者: Chengsx
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Chengsx !
  目录