激情久久久_欧美视频区_成人av免费_不卡视频一二三区_欧美精品在欧美一区二区少妇_欧美一区二区三区的

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - pytorch 利用lstm做mnist手寫數字識別分類的實例

pytorch 利用lstm做mnist手寫數字識別分類的實例

2020-04-29 09:44xckkcxxck Python

今天小編就為大家分享一篇pytorch 利用lstm做mnist手寫數字識別分類的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

代碼如下,U我認為對于新手來說最重要的是學會rnn讀取數據的格式。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定義數據
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定義模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用兩層lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#將最后一個的rnn使用全連接的到最后的輸出結果
     
   def forward(self, x):
     #x的大小為(batch,1,28,28),所以我們需要將其轉化為rnn的輸入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,變成(batch, 28,28)
     x = x.permute(2, 0, 1)#將最后一維放到第一維,變成(batch,28,28)
     out, _ = self.rnn(x) #使用默認的隱藏狀態,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一個,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分類結果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定義訓練過程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)   

以上這篇pytorch 利用lstm做mnist手寫數字識別分類的實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://blog.csdn.net/xckkcxxck/article/details/82978942

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 永久免费不卡在线观看黄网站 | 最新欧美精品一区二区三区 | ,欧美特黄特色三级视频在线观看 | 久久影城 | tube69xxxxxhd| 欧美成人视 | 黄色免费入口 | 日韩av成人 | 成人午夜影院 | 在线播放视频一区二区 | 欧美成人一级片 | 久久免费视频1 | 成人在线影视 | 中日无线码1区 | 成人av一区二区免费播放 | 黄色片网站在线播放 | 免费观看国产精品视频 | 毛片视频网址 | 成人在线网站 | 国产精品亚洲欧美一级在线 | 青青草免费观看 | 亚洲最新无码中文字幕久久 | 黄色片观看 | 日韩视频在线一区二区三区 | 国产美女爽到喷白浆的 | 日本欧美在线播放 | 午夜在线视频一区二区三区 | 亚洲第一成人在线观看 | 中文字幕精品亚洲 | 成人区一区二区三区 | 一本色道久久综合狠狠躁篇适合什么人看 | 黄色aaa视频 | www.99久久久| 黄网站在线观 | 九九热在线精品视频 | 国产一区二区精彩视频 | 成人偷拍片视频在线观看 | 一级黄色影片在线观看 | 久久久久久久免费看 | 国产亚洲精品久久久久久久 | 国产久草视频在线 |