激情久久久_欧美视频区_成人av免费_不卡视频一二三区_欧美精品在欧美一区二区少妇_欧美一区二区三区的

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - python使用tensorflow保存、加載和使用模型的方法

python使用tensorflow保存、加載和使用模型的方法

2021-01-11 00:07LordofRobots Python

本篇文章主要介紹了python使用tensorflow保存、加載和使用模型的方法,小編覺得挺不錯的,現在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧

使用Tensorflow進行深度學習訓練的時候,需要對訓練好的網絡模型和各種參數進行保存,以便在此基礎上繼續訓練或者使用。介紹這方面的博客有很多,我發現寫的最好的是這一篇官方英文介紹:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我對這篇文章進行了整理和匯總。

首先是模型的保存。直接上代碼:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut1_save.py
#Author: Wang 
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 11:04:25
############################
 
import tensorflow as tf
 
# prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2')
b1 = tf.Variable(2.0, name = 'bias1')
feed_dict = {w1:[10,3], w2:[5,5]}
 
# define a test operation that will be restored
w3 = tf.add(w1, w2) # without name, w3 will not be stored
w4 = tf.multiply(w3, b1, name = "op_to_restore")
 
#saver = tf.train.Saver()
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print sess.run(w4, feed_dict)
#saver.save(sess, 'my_test_model', global_step = 100)
saver.save(sess, 'my_test_model')
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要說明的有以下幾點:

1. 創建saver的時候可以指明要存儲的tensor,如果不指明,就會全部存下來。在這里也可以指明最大存儲數量和checkpoint的記錄時間。具體細節看英文博客。

2. saver.save()函數里面可以設定global_step和write_meta_graph,meta存儲的是網絡結構,只在開始運行程序的時候存儲一次即可,后續可以通過設置write_meta_graph = False加以限制。

3. 這個程序執行結束后,會在程序目錄下生成四個文件,分別是.meta(存儲網絡結構)、.data和.index(存儲訓練好的參數)、checkpoint(記錄最新的模型)。

下面是如何加載已經保存的網絡模型。這里有兩種方法,第一種是saver.restore(sess, 'aaaa.ckpt'),這種方法的本質是讀取全部參數,并加載到已經定義好的網絡結構上,因此相當于給網絡的weights和biases賦值并執行tf.global_variables_initializer()。這種方法的缺點是使用前必須重寫網絡結構,而且網絡結構要和保存的參數完全對上。第二種就比較高端了,直接把網絡結構加載進來(.meta),上代碼:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut2_import.py
#Author: Wang 
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:16:38
############################ 
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run('w1:0')

使用加載的模型,輸入新數據,計算輸出,還是直接上代碼:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut3_reuse.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:33:35
############################
 
import tensorflow as tf
 
sess = tf.Session()
 
# First, load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
 
# Second, access and create placeholders variables and create feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1], w2:[4,6]}
 
# Access the op that want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已經加載的網絡后繼續加入新的網絡層:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import tensorflow as tf
sess=tf.Session()  
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
 
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
 
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
 
#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
 
#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)
 
print sess.run(add_on_op,feed_dict)
#This will print 120.

對加載的網絡進行局部修改和處理(這個最麻煩,我還沒搞太明白,后續會繼續補充):

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 
 
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
 
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
 
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
 
# Now, you run this with fine-tuning data in sess.run()

有了這樣的方法,無論是自行訓練、加載模型繼續訓練、使用經典模型還是finetune經典模型抑或是加載網絡跑前項,效果都是杠杠的。

以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持服務器之家。

原文鏈接:http://blog.csdn.net/LordofRobots/article/details/77719020

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 久久久久国产成人精品亚洲午夜 | 久久免费视频一区 | 妇女毛片| 国产激情精品一区二区三区 | 久久久一区二区精品 | 精品久久久久久 | 国产精品啪一品二区三区粉嫩 | 国产精品伦视频看免费三 | 二区三区偷拍浴室洗澡视频 | 99视频有精品视频高清 | 免费国产 | 在线一级片 | 国产精品1区,2区,3区 | a级在线 | 欧美成人小视频 | 在线香蕉视频 | 亚洲一区国产二区 | 日韩欧美电影一区二区三区 | 日本教室三级在线看 | 艹艹艹逼 | 日本成人在线免费 | 精品av在线播放 | av在线播放电影 | 第一区免费在线观看 | 日本不卡一区二区三区在线 | 日韩在线欧美在线 | 91精品久久久久久久 | 国产精品久久久久久久久久东京 | 九草在线视频 | 欧美成人精品欧美一级乱黄 | 国产精品爱久久久久久久 | 一二区成人影院电影网 | 最新中文字幕在线视频 | 日韩精品久久久 | 娇妻被各种姿势c到高潮小说 | 嫩草影院在线观看网站成人 | 精品国产一级毛片 | 久久99精品久久久久久小说 | 特级西西444www大精品视频免费看 | 午夜精品久久久久久毛片 | 欧美日韩手机在线观看 |