我上個月寫了一個簡單的畫作辨識Linebot,其中發現了一些有意思的東西與處理過程。怕之後忘記了,趕緊寫下來供未來檢視。
在處理大量圖片上,我第一步通常是會先將圖片轉換為TFRecord,TFRecord是Tensorflow官方通用的輸入格式。將圖片轉換為TFRecord後,除了能夠有效的紀錄數據,更重要的是能夠使用Tensorflow各項數據輸入API (包含圖片預處理、多執行序圖片輸入pipeline),讓數據輸入的效率提升,訓練速度更快。
想直接看程式碼可以到:
- 轉為 TFRecord
2. 讀取TFRecord
一、圖片轉換為TFRecord
TFRecord轉換相當簡單,只要將圖片依照TFRecord規定的格式進行轉換即可。TFRecord相關資料規定:
所有資料需要轉換為tf.train.Feature1. int value 轉換為tf.train.Feature(int64_list=tf.train.Int64List(value=value))
通常用於圖片label(onehot_encoded)、圖片width、圖片height、圖片channel2. Bytes 轉換
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
通常適用於圖片名稱、"圖片本身"
轉為TFRecord步驟為下:
- 建立 TFWriter = tf.python_io.TFRecordWriter(“file.tfrecords”)
- 將要保存資料轉為 tf.train.Feature
- 將所有資料合併為 tf.train.Features
- 將tf.train.Features 轉為tf.train.Example 這會對tf.train.Feature進行封裝
- TFWriter寫入
詳細轉換程式如下,可以依照您的需求做修改 :
值得注意的是,當你的圖片量龐大,一個TFRecord有可能會塞不下,此時可以拆分為多個TFRecord,詳細作法可以參考:
二、 讀取TFRecord
讀取TFRecord也很簡單,剛剛怎麼encode為TFRecord,則就怎麼去decode他。步驟如下:
1. 建立圖片輸入Queue
# TF檔
filename = './data/Train.tfrecords'# 產生文件名隊列
filename_queue = tf.train.string_input_producer([filename],
shuffle=True,
num_epochs=3)
#由於可能不只一個TFRecords,因此用[filename]
#shuffle參數會隨機排多個TFRecords的順序
#num_pochs會決定一個TFRecords輸入多少次 (!!重要!!)
2. TFRecordReader() Read file_queue
# 數據讀取器
reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)
3. 解析Example
img_features = tf.parse_single_example(
serialized_example,
features={'Label':tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([],tf.string),
'height':tf.FixedLenFeature([],tf.int64),
'width':tf.FixedLenFeature([],tf.int64),
'channel':tf.FixedLenFeature([],tf.int64)})#根據當初如何轉為TFRecord個各項標籤(如圖片放在image_raw),反轉回去height = tf.cast(img_features['height'], tf.int64)
width = tf.cast(img_features['width'], tf.int64)
channel = tf.cast(img_features['channel'], tf.int64)#還需要 tf.cast 轉換才可以獲得int格式的數字image = tf.decode_raw(img_features['image_raw'], tf.uint8)
image = tf.reshape(image, [height,width,channel])
label = tf.cast(img_features['Label'], tf.int64)#圖片利用reshape可以轉變回原圖(224,224,3)。
#在Session 中Session.run([image,label])即可獲得圖片
4. 放入Session中
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) #建立多執行序處理輸入數據
coord = tf.train.Coordinator()
# 啟動文件隊列,開始讀取文件
threads = tf.train.start_queue_runners(coord=coord) count = 0
try:
#印十張
while count<10:
# 這邊讀取
image_data, label_data = sess.run([image, label]
#打印出圖片與標籤
plt.imshow(image_data)
plt.show()
print(label_data)
count +=1
except tf.errors.OutOfRangeError:
print('Done!') finally:
coord.request_stop()
coord.join(threads)
如此一來就能看到成功的還原了圖片:
三、參考資料
夏恩寫超好!!!