激情久久久_欧美视频区_成人av免费_不卡视频一二三区_欧美精品在欧美一区二区少妇_欧美一区二区三区的

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - PyTorch深度學習模型的保存和加載流程詳解

PyTorch深度學習模型的保存和加載流程詳解

2022-02-14 20:46軟耳朵DONG Python

PyTorch是一個開源的Python機器學習庫,基于Torch,用于自然語言處理等應用程序。2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch,這篇文章主要介紹了PyTorch模型的保存和加載流程

一、模型參數的保存和加載

  •  torch.save(module.state_dict(), path):使用module.state_dict()函數獲取各層已經訓練好的參數和緩沖區,然后將參數和緩沖區保存到path所指定的文件存放路徑(常用文件格式為.pt.pth.pkl)。
  • torch.nn.Module.load_state_dict(state_dict):從state_dict中加載參數和緩沖區到Module及其子類中 。
  • torch.nn.Module.state_dict()函數返回python中的一個OrderedDict類型字典對象,該對象將每一層與它的對應參數和緩沖區建立映射關系,字典的鍵值是參數或緩沖區的名稱。只有那些參數可以訓練的層才會被保存到OrderedDict中,例如:卷積層、線性層等。
  • Python中的字典類以“鍵:值”方式存取數據,OrderedDict是它的一個子類,實現了對字典對象中元素的排序(OrderedDict根據放入元素的先后順序進行排序)。由于進行了排序,所以順序不同的兩個OrderedDict字典對象會被當做是兩個不同的對象。
  • 示例:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

# 初始化網絡
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 獲取state_dict
state_dict = net.state_dict()
# 字典的遍歷默認是遍歷key,所以param_tensor實際上是鍵值
for param_tensor in state_dict: 
    print(param_tensor,":
",state_dict[param_tensor])
# 保存模型參數
torch.save(state_dict,"net_params.pth")
# 通過加載state_dict獲取模型參數
net.load_state_dict(state_dict)

輸出:

PyTorch深度學習模型的保存和加載流程詳解

二、完整模型的保存和加載

  •  torch.save(module, path):將訓練完的整個網絡模型module保存到path所指定的文件存放路徑(常用文件格式為.pt.pth)。
  • torch.load(path):加載保存到path中的整個神經網絡模型。
  • 示例:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

# 初始化網絡
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 保存整個網絡
torch.save(net,"net.pth")
# 加載網絡
net = torch.load("net.pth")

到此這篇關于PyTorch深度學習模型的保存和加載流程詳解的文章就介紹到這了,更多相關PyTorch 模型的保存 內容請搜索服務器之家以前的文章或繼續瀏覽下面的相關文章希望大家以后多多支持服務器之家!

原文鏈接:https://blog.csdn.net/m0_52650517/article/details/120836999

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 黄色伊人网站 | 日韩视频一区在线 | 久草在线综合 | 国产精品久久久久久久久久iiiii | 美女啪网站 | 成人啪啪18免费网站 | 久久成人国产精品 | 99视频观看 | 久色porn | 久久久看 | 毛片视频播放 | 国产又白又嫩又紧又爽18p | 久久久久久久久久久久99 | 91麻豆精品国产91久久久更新资源速度超快 | 曰韩av在线| 久草最新在线 | www.91sao| 欧美a级大胆视频 | 狠狠操视频网站 | 久久2019中文字幕 | 成人不卡一区二区 | 99久久久精品视频 | 久久精品成人影院 | 黄色av网站免费 | 最新欧美精品一区二区三区 | 日韩欧美电影一区二区三区 | 92看片淫黄大片欧美看国产片 | 国产精品久久久久久久久久尿 | 毛片在线免费观看视频 | 午夜在线观看视频网站 | 国产做爰 | 亚洲成人在线视频网 | 久久久久久久久久久av | 成人羞羞在线观看网站 | 国产一区视频在线免费观看 | 中文字幕亚洲情99在线 | 日韩激情一区二区三区 | 视频一区国产精品 | 国产精品久久久久久久久久 | 国产毛片自拍 | 国产精品三级a三级三级午夜 |