直接看代碼例子,有詳細注釋!!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
import tensorflow as tf import numpy as np d = np.arange( 0 , 60 ).reshape([ 6 , 10 ]) # 將array轉化為tensor data = tf.data.Dataset.from_tensor_slices(d) # 從data數據集中按順序抽取buffer_size個樣本放在buffer中,然后打亂buffer中的樣本 # buffer中樣本個數不足buffer_size,繼續從data數據集中安順序填充至buffer_size, # 此時會再次打亂 data = data.shuffle(buffer_size = 3 ) # 每次從buffer中抽取4個樣本 data = data.batch( 4 ) # 將data數據集重復,其實就是2個epoch數據集 data = data.repeat( 2 ) # 構造獲取數據的迭代器 iters = data.make_one_shot_iterator() # 每次從迭代器中獲取一批數據 batch = iters.get_next() sess = tf.Session() sess.run(batch) # 數據集完成遍歷完之后,繼續抽取的話會報錯:OutOfRangeError |
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
In [ 21 ]: d Out[ 21 ]: array([[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ]]) In [ 22 ]: sess.run(batch) Out[ 22 ]: array([[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ]]) In [ 23 ]: sess.run(batch) Out[ 23 ]: array([[ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ]]) |
從輸出結果可以看出:
shuffle是按順序將數據放入buffer里面的;
當repeat函數在shuffle之后的話,是將一個epoch的數據集抽取完畢,再進行下一個epoch的。
那么,當repeat函數在shuffle之前會怎么樣呢?如下:
1
2
3
4
5
|
data = data.repeat( 2 ) data = data.shuffle(buffer_size = 3 ) data = data.batch( 4 ) |
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
In [ 25 ]: sess.run(batch) Out[ 25 ]: array([[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ]]) In [ 26 ]: sess.run(batch) Out[ 26 ]: array([[ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ], [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ]]) In [ 27 ]: sess.run(batch) Out[ 27 ]: array([[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ]]) |
可以看出,其實它就是先將數據集復制一遍,然后把兩個epoch當成同一個新的數據集,一直shuffle和batch下去。
以上這篇TensorFlow dataset.shuffle、batch、repeat的使用詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。
原文鏈接:https://blog.csdn.net/sgyuanshi/article/details/90183610