格式汇总
重现模型的关键是模型结构/模型参数/数据集, 当提供或者希望被提供这些信息, 需要一个交流的规范, 而 .pt/.pth/.bin/.onnx 就是约定的格式.
使用 torch.load()
方法加载模型信息时, 不根据文件后缀进行模型读取, 而是根据模型文件的内容自动识别.
格式 | 解释 | 适用场景 | 对应后缀 |
---|---|---|---|
.pt/.pth | PyTorch 的默认模型文件格式, 保存和加载完整的 PyTorch 模型, 包含模型的结构和参数等信息。 | 需要保存和加载完整的 PyTorch 模型, 例如在训练中保存最佳的模型, 或在部署中加载训练好的模型. | .pt/.pth |
.bin | 通用二进制格式, 保存和加载各种类型的模型和数据. | 需要将 PyTorch 模型转换为通用的二进制格式. | .bin |
ONNX | 通用模型交换格式, 将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台. PyTorch 中使用 torch.onnx.export 将模型转换为 ONNX 格式. | 需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式. | .onnx |
.pt/.pth格式
完整的 Pytorch 模型文件, 包含如下参数:
model_state_dict
: 模型参数optimizer_state_dict
: 优化器状态epoch
: 训练轮数loss
: 损失值
其中 model.state_dict()
返回包含所有参数和持久化缓存的字典, torch.save()
将所有组件保存到文件中,
模型保存
import torch
import torch.nn as nn
class Net(nn.Module):
...
model = Net()
torch.save({
'epoch': 10,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, PATH)
模型加载
import torch
import torch.nn as nn
class Net(nn.Module):
...
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
.bin 格式
二进制文件, 保存模型的参数和持久化缓存. 文件大小较小, 加载速度较快, 在生产环境中使用较多.
保存模型
import torch
import torch.nn as nn
class Net(nn.Module):
...
model = Net()
torch.save(model.state_dict(), PATH)
加载模型
import torch
import torch.nn as nn
class Net(nn.Module):
...
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()
.onnx 格式
PyTorch 提供 torch.onnx.export
函数将模型转化为 ONNX 格式, 在其他深度学习框架中使用.
保存模型
import torch
import torch.onnx
model = torch.nn.Linear(3, 1)
torch.save(model.state_dict(), "model.bin")
model = torch.nn.Linear(3, 1)
model.load_state_dict(torch.load("model.bin"))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"])
加载模型
import onnx
import onnxruntime
# 将 ONNX 文件转化为 ORT 格式
onnx_model = onnx.load("model.onnx")
ort_session = onnxruntime.InferenceSession("model.onnx")
input_data = np.random.random(size=(1, 3)).astype(np.float32)
outputs = ort_session.run(None, {"input": input_data})
保存完整模型
之前的方式都保存了 .state_dict()
, 但没有保存模型的结构, 使用时必须重新定义相同结构的模型, 才能够加载模型参数进行使用. 想把整个模型都保存下来, 可以按如下操作:
保存模型
PATH = "entire_model.pt"
torch.save(model, PATH)
加载模型
model = torch.load("entire_model.pt")
model.eval()