激情久久久_欧美视频区_成人av免费_不卡视频一二三区_欧美精品在欧美一区二区少妇_欧美一区二区三区的

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - keras 回調函數Callbacks 斷點ModelCheckpoint教程

keras 回調函數Callbacks 斷點ModelCheckpoint教程

2020-06-18 10:42jieshaoxiansen Python

這篇文章主要介紹了keras 回調函數Callbacks 斷點ModelCheckpoint教程,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

整理自kerashttps://keras-cn.readthedocs.io/en/latest/other/callbacks/

回調函數Callbacks

回調函數是一個函數的合集,會在訓練的階段中所使用。你可以使用回調函數來查看訓練模型的內在狀態和統計。你可以傳遞一個列表的回調函數(作為 callbacks 關鍵字參數)到 Sequential 或 Model 類型的 .fit() 方法。在訓練時,相應的回調函數的方法就會被在各自的階段被調用。

Callback

keras.callbacks.Callback()

這是回調函數的抽象類,定義新的回調函數必須繼承自該類

類屬性

params:字典,訓練參數集(如信息顯示方法verbosity,batch大小,epoch數)

model:keras.models.Model對象,為正在訓練的模型的引用

回調函數以字典logs為參數,該字典包含了一系列與當前batch或epoch相關的信息。

目前,模型的.fit()中有下列參數會被記錄到logs中:

在每個epoch的結尾處(on_epoch_end),logs將包含訓練的正確率和誤差,acc和loss,如果指定了驗證集,還會包含驗證集正確率和誤差val_acc)和val_loss,val_acc還額外需要在.compile中啟用metrics=['accuracy']。

在每個batch的開始處(on_batch_begin):logs包含size,即當前batch的樣本數

在每個batch的結尾處(on_batch_end):logs包含loss,若啟用accuracy則還包含acc

ModelCheckpoint

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

該回調函數將在每個epoch后保存模型到filepath

filepath 可以包括命名格式選項,可以由 epoch 的值和 logs 的鍵(由 on_epoch_end 參數傳遞)來填充。

參數:

filepath: 字符串,保存模型的路徑。

monitor: 被監測的數據。val_acc或這val_loss

verbose: 詳細信息模式,0 或者 1 。0為不打印輸出信息,1打印

save_best_only: 如果 save_best_only=True, 將只保存在驗證集上性能最好的模型

mode: {auto, min, max} 的其中之一。 如果 save_best_only=True,那么是否覆蓋保存文件的決定就取決于被監測數據的最大或者最小值。 對于 val_acc,模式就會是 max,而對于 val_loss,模式就需要是 min,等等。 在 auto 模式中,方向會自動從被監測的數據的名字中判斷出來。

save_weights_only: 如果 True,那么只有模型的權重會被保存 (model.save_weights(filepath)), 否則的話,整個模型會被保存 (model.save(filepath))。

period: 每個檢查點之間的間隔(訓練輪數)。

代碼實現過程:

① 從keras.callbacks導入ModelCheckpoint類

from keras.callbacks import ModelCheckpoint

② 在訓練階段的model.compile之后加入下列代碼實現每一次epoch(period=1)保存最好的參數

checkpoint = ModelCheckpoint(filepath,
monitor='val_loss', save_weights_only=True,verbose=1,save_best_only=True, period=1)

③ 在訓練階段的model.fit之前加載先前保存的參數

?
1
2
3
4
if os.path.exists(filepath):
 model.load_weights(filepath)
 # 若成功加載前面保存的參數,輸出下列信息
 print("checkpoint_loaded")

④ 在model.fit添加callbacks=[checkpoint]實現回調

?
1
2
3
4
5
6
7
model.fit_generator(data_generator_wrap(lines[:num_train], batch_size, input_shape, anchors, num_classes),
 steps_per_epoch=max(1, num_train//batch_size),
 validation_data=data_generator_wrap(lines[num_train:], batch_size, input_shape, anchors, num_classes),
 validation_steps=max(1, num_val//batch_size),
 epochs=3,
 initial_epoch=0,
 callbacks=[checkpoint])

補充知識:keras之多輸入多輸出(多任務)模型

keras多輸入多輸出模型,以keras官網的demo為例,分析keras多輸入多輸出的適用。

主要輸入(main_input): 新聞標題本身,即一系列詞語。

輔助輸入(aux_input): 接受額外的數據,例如新聞標題的發布時間等。

該模型將通過兩個損失函數進行監督學習。

較早地在模型中使用主損失函數,是深度學習模型的一個良好正則方法。

完整過程圖示如下:

keras 回調函數Callbacks 斷點ModelCheckpoint教程

其中,紅圈中的操作為將輔助數據與LSTM層的輸出連接起來,輸入到模型中。

代碼實現:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import keras
from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model
 
# 定義網絡模型
# 標題輸入:接收一個含有 100 個整數的序列,每個整數在 1 到 10000 之間
# 注意我們可以通過傳遞一個 `name` 參數來命名任何層
main_input = Input(shape=(100,), dtype='int32', name='main_input')
 
# Embedding 層將輸入序列編碼為一個稠密向量的序列,每個向量維度為 512
x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input)
 
# LSTM 層把向量序列轉換成單個向量,它包含整個序列的上下文信息
lstm_out = LSTM(32)(x)
 
# 在這里我們添加輔助損失,使得即使在模型主損失很高的情況下,LSTM層和Embedding層都能被平穩地訓練
auxiliary_output = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)
 
# 此時,我們將輔助輸入數據與LSTM層的輸出連接起來,輸入到模型中
auxiliary_input = Input(shape=(5,), name='aux_input')
x = keras.layers.concatenate([lstm_out, auxiliary_output])
 
# 再添加剩余的層
# 堆疊多個全連接網絡層
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
 
# 最后添加主要的邏輯回歸層
main_output = Dense(1, activation='sigmoid', name='main_output')(x)
 
# 定義這個具有兩個輸入和輸出的模型
model = Model(inputs=[main_input, auxiliary_input], outputs=[main_output, auxiliary_output])
 
# 編譯模型時候分配損失函數權重:編譯模型的時候,給 輔助損失 分配一個0.2的權重
model.compile(optimizer='rmsprop', loss='binary_crossentropy', loss_weights=[1., 0.2])
 
# 訓練模型:我們可以通過傳遞輸入數組和目標數組的列表來訓練模型
model.fit([headline_data, additional_data], [labels, labels], epochs=50, batch_size=32)
 
# 另外一種利用字典的編譯、訓練方式
# 由于輸入和輸出均被命名了(在定義時傳遞了一個 name 參數),我們也可以通過以下方式編譯模型
model.compile(optimizer='rmsprop',
    loss={'main_output': 'binary_crossentropy', 'aux_output': 'binary_crossentropy'},
    loss_weights={'main_output': 1., 'aux_output': 0.2})
# 然后使用以下方式訓練:
model.fit({'main_input': headline_data, 'aux_input': additional_data},
   {'main_output': labels, 'aux_output': labels},
   epochs=50, batch_size=32)

相關參考:https://keras.io/zh/getting-started/functional-api-guide/

以上這篇keras 回調函數Callbacks 斷點ModelCheckpoint教程就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://blog.csdn.net/jieshaoxiansen/article/details/82762922

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 色婷婷久久久亚洲一区二区三区 | 成人在线视频精品 | 亚洲3p激情在线观看 | 欧美一级淫片a免费播放口 91九色蝌蚪国产 | 91午夜免费视频 | 史上最强炼体老祖动漫在线观看 | 日韩黄色三级视频 | 欧美黄色大片免费观看 | 欧美 日韩 三区 | 黄色羞羞视频在线观看 | 色av成人天堂桃色av | 国产91丝袜在线播放 | 精品国产一区二区久久 | 九九热在线精品视频 | 一级黄色毛片a | 成人免费福利 | 又黄又爽免费无遮挡在线观看 | 成人电影毛片 | av播放在线| 色淫湿视频 | 92自拍视频 | 精品999www | 羞羞视频免费网站日本动漫 | 在线小视频国产 | 久久久久成人免费 | 国产成人精品区一区二区不卡 | 国产美女视频一区二区三区 | 欧美一级毛片一级毛片 | 国产一区免费在线 | 成码无人av片在线观看网站 | 免费中文视频 | 欧美一级精品片在线看 | 国产精品视频成人 | free国产hd老熟bbw | 911色_911色sss主站色播 | 国产欧美亚洲精品 | 蜜桃成品人免费视频 | 深夜影院a | 成人超碰97 | 欧美一区二区三区免费不卡 | 亚州综合一区 |