激情久久久_欧美视频区_成人av免费_不卡视频一二三区_欧美精品在欧美一区二区少妇_欧美一区二区三区的

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - TensorFlow——Checkpoint為模型添加檢查點的實例

TensorFlow——Checkpoint為模型添加檢查點的實例

2020-04-05 12:40Baby-Lily Python

今天小編就為大家分享一篇TensorFlow——Checkpoint為模型添加檢查點的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

1.檢查點

保存模型并不限于在訓練模型后,在訓練模型之中也需要保存,因為TensorFlow訓練模型時難免會出現中斷的情況,我們自然希望能夠將訓練得到的參數保存下來,否則下次又要重新訓練。

這種在訓練中保存模型,習慣上稱之為保存檢查點。

2.添加保存點

通過添加檢查點,可以生成載入檢查點文件,并能夠指定生成檢查文件的個數,例如使用saver的另一個參數——max_to_keep=1,表明最多只保存一個檢查點文件,在保存時使用如下的代碼傳入迭代次數。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
 
train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5
 
plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()
 
tf.reset_default_graph()
 
X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)
 
w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')
 
z = tf.multiply(X, w) + b
 
cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
 
init = tf.global_variables_initializer()
 
training_epochs = 20
display_step = 2
 
 
saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"
 
 
if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})
 
   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)
 
   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
 
   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)
 
  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()
 
 load_epoch = 10
 
 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代碼中,我們使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)將訓練的參數傳入檢查點進行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一個文件,這樣在訓練過程中得到的新的模型就會覆蓋以前的模型。

?
1
2
3
4
5
6
cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
  saver.restore(sess2, cpkt.model_checkpoint_path)
 
kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的兩種方法也可以對checkpoint文件進行加載,tf.train.latest_checkpoint(savedir)為加載最后的檢查點文件。這種方式,我們可以通過保存指定訓練次數的檢查點,比如保存5的倍數次保存一下檢查點。

3.簡便保存檢查點

我們還可以用更加簡單的方法進行檢查點的保存,tf.train.MonitoredTrainingSession()函數,該函數可以直接實現保存載入檢查點模型的文件,與前面的方法不同的是,它是按照訓練時間來保存檢查點的,可以通過指定save_checkpoint_secs參數的具體秒數,設置多久保存一次檢查點。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
 
train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5
 
# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()
 
tf.reset_default_graph()
 
X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)
 
w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')
 
z = tf.multiply(X, w) + b
 
cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
 
init = tf.global_variables_initializer()
 
training_epochs = 30
display_step = 2
 
 
global_step = tf.train.get_or_create_global_step()
 
step = tf.assign_add(global_step, 1)
 
saver = tf.train.Saver()
 
savedir = "check-point/"
 
if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})
 
   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)
 
   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)
 
  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()
 
 load_epoch = 10
 
 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
 
  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))
 
  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')
 
  saver.restore(sess2, kpt)
 
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代碼中,我們設置了沒訓練了5秒中之后,就保存一次檢查點,它默認的保存時間間隔是10分鐘,這種按照時間的保存模式更適合使用大型數據集訓練復雜模型的情況,注意在使用上述的方法時,要定義global_step變量,在訓練完一個批次或者一個樣本之后,要將其進行加1的操作,否則將會報錯。

TensorFlow——Checkpoint為模型添加檢查點的實例

以上這篇TensorFlow——Checkpoint為模型添加檢查點的實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://www.cnblogs.com/baby-lily/p/10930591.html

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 亚洲国产精品久久久久婷婷老年 | 中日无线码1区 | 黄色免费av网站 | 久久久久久久久久综合 | 成人免费毛片明星色大师 | 91精品国产手机 | 精品久久久久久久久久中出 | 久久久在线 | 成人毛片av在线 | 久久久久久三区 | 亚洲aⅴ免费在线观看 | 久久精品视频首页 | 法国性hdfreexxxx人妖 | 羞羞的动漫在线观看 | 毛片av网址 | 草免费视频 | 九九热视频免费观看 | 欧美精品一区二区性色 | 国产一级做a爱片在线看免 2019天天干夜夜操 | 国产精品视频一区二区三区综合 | 日本免费大片免费视频 | 九九热精品免费视频 | 正在播放91精 | 亚洲成人午夜精品 | 国产精品久久久久久久久粉嫩 | 国产成人在线免费视频 | 国产毛毛片一区二区三区四区 | 久久手机在线视频 | 媚药按摩痉挛w中文字幕 | 国产午夜精品一区二区三区在线观看 | 久久金品 | 免费亚洲视频在线观看 | 国产成人av免费 | 欧美xxxxx视频 | javhdfreejaⅴhd| 日韩精品久久久久久久电影99爱 | 曰韩黄色片 | 色中色综合 | 成人不卡 | 久久九九热re6这里有精品 | 国产1区2区在线 |