激情久久久_欧美视频区_成人av免费_不卡视频一二三区_欧美精品在欧美一区二区少妇_欧美一区二区三区的

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - 解決Keras中循環使用K.ctc_decode內存不釋放的問題

解決Keras中循環使用K.ctc_decode內存不釋放的問題

2020-06-29 12:13愛明_愛夏 Python

這篇文章主要介紹了解決Keras中循環使用K.ctc_decode內存不釋放的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

如下一段代碼,在多次調用了K.ctc_decode時,會發現程序占用的內存會越來越高,執行速度越來越慢。

?
1
2
3
4
5
6
7
8
9
data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次執行ctc_decode時都會向計算圖中添加一個節點,這樣會導致計算圖逐漸變大,從而影響計算速度和內存。

PS:有資料說是由于get_value導致的,其中也給出了解決方案。

但是我將ctc_decode放在循環體之外就不再出現內存和速度問題,這是否說明get_value影響其實不大呢?

解決方案

通過K.function封裝K.ctc_decode,只需初始化一次,只向計算圖中添加一個計算節點,然后多次調用該節點(函數)

?
1
2
3
4
5
6
7
8
9
data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

 

補充知識:CTC_loss和CTC_decode的模型封裝代碼避免節點不斷增加

該問題可以參考上面的描述,無論是CTC_decode還是CTC_loss,每次運行都會創建節點,避免的方法是將其封裝到model中,這樣就固定了計算節點。

測試方法: 在初始化節點后(注意是在運行fit/predict至少一次后,因為這些方法也會更改計算圖狀態),運行K.get_session().graph.finalize()鎖定節點,此時如果圖節點變了會報錯并提示出錯代碼。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于計算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.
 
    # Arguments
      y_true: tensor `(samples, max_string_length)` 真實標簽
      y_pred: tensor `(samples, time_steps, num_categories)` 預測前未經過softmax的向量
      input_length: tensor `(samples, 1)` 每一個y_pred的長度
      label_length: tensor `(samples, 1)` 每一個y_true的長度
 
      # Returns
        Tensor with shape (samples,1) 包含了每一個樣本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args
 
    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)
 
  def __call__(self, args):
    '''
    ctc_decode 每次創建會生成一個節點,這里參考了上面的內容
    將ctc封裝成模型,是否會解決這個問題還沒有測試過這種方法是否還會出現創建節點的問題
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))
 
    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")
 
    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)
 
  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.
 
    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.
 
    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))
 
    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)
 
    # 注意這里的True是為了忽略解碼失敗的情況,此時loss會變成nan直到下一個個batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)
 
# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
 
class CTCDecodeLayer(Layer):
 
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
 
  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)
 
    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]
 
  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)
 
  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]
 
class CTCDecode():
  '''用與CTC 解碼,得到真實語音序列
      2019年7月18日所寫,對ctc_decode使用模型進行了封裝,從而在初始化完成后不會再有新節點的產生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass
 
  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result
 
  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)
 
 
# 使用方法:(注意shape,是batch級的輸入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上這篇解決Keras中循環使用K.ctc_decode內存不釋放的問題就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://blog.csdn.net/u014484783/article/details/88849971

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 久久久久九九九女人毛片 | 成人午夜高清 | 人人看人人舔 | 国产精品成人免费一区久久羞羞 | 国产1区2区在线观看 | 久久久电影电视剧免费看 | 精品国产91久久久久 | 操你视频| 久久欧美亚洲另类专区91大神 | 羞羞电影在线观看 | 懂色av懂色aⅴ精彩av | 久久最新视频 | 99亚洲 | 羞羞的动漫在线观看 | 国产精品探花在线观看 | 欧美片一区二区 | 欧产日产国产精品乱噜噜 | 欧美日韩亚州综合 | 国产三级国产精品国产普男人 | 欧美1—12sexvideos| 依人九九宗合九九九 | 免费一区二区三区 | 99精品视频在线观看免费播放 | 国产欧美精品一区二区三区四区 | 免费黄色大片在线观看 | 91高清观看| 99热久草 | 久久精品一区视频 | 成年免费观看视频 | 毛片大全免费看 | 亚洲一区二区不卡视频 | 久久久久一区 | 国产精品视频一区二区三区综合 | 久草在线视频看看 | 欧美精品电影一区 | 欧美大逼网 | 91免费在线视频 | 亚洲成人免费网站 | 国内精品久久久久久影视8 国产一区二区成人在线 | 国产成人在线免费观看视频 | 国产一级不卡毛片 |