Tensorflow 圖片數據增強(一): TFRecord 轉換

LUFOR129
7 min readJun 10, 2019

我上個月寫了一個簡單的畫作辨識Linebot,其中發現了一些有意思的東西與處理過程。怕之後忘記了,趕緊寫下來供未來檢視。

在處理大量圖片上,我第一步通常是會先將圖片轉換為TFRecord,TFRecord是Tensorflow官方通用的輸入格式。將圖片轉換為TFRecord後,除了能夠有效的紀錄數據,更重要的是能夠使用Tensorflow各項數據輸入API (包含圖片預處理、多執行序圖片輸入pipeline),讓數據輸入的效率提升,訓練速度更快。

想直接看程式碼可以到:

  1. 轉為 TFRecord

2. 讀取TFRecord

一、圖片轉換為TFRecord

TFRecord轉換相當簡單,只要將圖片依照TFRecord規定的格式進行轉換即可。TFRecord相關資料規定:

所有資料需要轉換為tf.train.Feature1. int value 轉換為tf.train.Feature(int64_list=tf.train.Int64List(value=value))
通常用於圖片label(onehot_encoded)、圖片width、圖片height、圖片channel
2. Bytes 轉換
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
通常適用於圖片名稱、"圖片本身"

轉為TFRecord步驟為下:

  1. 建立 TFWriter = tf.python_io.TFRecordWriter(“file.tfrecords”)
  2. 將要保存資料轉為 tf.train.Feature
  3. 將所有資料合併為 tf.train.Features
  4. 將tf.train.Features 轉為tf.train.Example 這會對tf.train.Feature進行封裝
  5. TFWriter寫入

詳細轉換程式如下,可以依照您的需求做修改 :

值得注意的是,當你的圖片量龐大,一個TFRecord有可能會塞不下,此時可以拆分為多個TFRecord,詳細作法可以參考:

二、 讀取TFRecord

讀取TFRecord也很簡單,剛剛怎麼encode為TFRecord,則就怎麼去decode他。步驟如下:

1. 建立圖片輸入Queue

# TF檔
filename = './data/Train.tfrecords'
# 產生文件名隊列
filename_queue = tf.train.string_input_producer([filename],
shuffle=True,
num_epochs=3)
#由於可能不只一個TFRecords,因此用[filename]
#shuffle參數會隨機排多個TFRecords的順序
#num_pochs會決定一個TFRecords輸入多少次 (!!重要!!)

2. TFRecordReader() Read file_queue

# 數據讀取器
reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)

3. 解析Example

img_features = tf.parse_single_example(
serialized_example,
features={'Label':tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([],tf.string),
'height':tf.FixedLenFeature([],tf.int64),
'width':tf.FixedLenFeature([],tf.int64),
'channel':tf.FixedLenFeature([],tf.int64)})
#根據當初如何轉為TFRecord個各項標籤(如圖片放在image_raw),反轉回去height = tf.cast(img_features['height'], tf.int64)
width = tf.cast(img_features['width'], tf.int64)
channel = tf.cast(img_features['channel'], tf.int64)
#還需要 tf.cast 轉換才可以獲得int格式的數字image = tf.decode_raw(img_features['image_raw'], tf.uint8)
image = tf.reshape(image, [height,width,channel])
label = tf.cast(img_features['Label'], tf.int64)
#圖片利用reshape可以轉變回原圖(224,224,3)。
#在Session 中Session.run([image,label])即可獲得圖片

4. 放入Session中

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
#建立多執行序處理輸入數據
coord = tf.train.Coordinator()

# 啟動文件隊列,開始讀取文件
threads = tf.train.start_queue_runners(coord=coord)
count = 0
try:
#印十張
while count<10:
# 這邊讀取
image_data, label_data = sess.run([image, label]

#打印出圖片與標籤
plt.imshow(image_data)
plt.show()
print(label_data)
count +=1

except tf.errors.OutOfRangeError:
print('Done!')
finally:
coord.request_stop()

coord.join(threads)

如此一來就能看到成功的還原了圖片:

三、參考資料

夏恩寫超好!!!

四、下一篇

--

--