TensorFlow中讀取圖像數據的三種方式

  本文面對三種常常遇到的情況,總結三種讀取數據的方式,分別用於處理單張圖片、大量圖片,和TFRecorder讀取方式。並且還補充了功能相近的tf函數。

1、處理單張圖片

  我們訓練完模型之後,常常要用圖片測試,有的時候,我們並不需要對很多圖像做測試,可能就是幾張甚至一張。這種情況下沒有必要用隊列機制。

import tensorflow as tf
import matplotlib.pyplot as plt

def read_image(file_name):
    img = tf.read_file(filename=file_name)     # 默認讀取格式為uint8
    print("img 的類型是",type(img));
    img = tf.image.decode_jpeg(img,channels=0) # channels 為1得到的是灰度圖,為0則按照圖片格式來讀
    return img

def main( ):
    with tf.device("/cpu:0"):
         # img_path是文件所在地址包括文件名稱,地址用相對地址或者絕對地址都行 
            img_path='./1.jpg'
            img=read_image(img_path)
            with tf.Session() as sess:
            image_numpy=sess.run(img)
            print(image_numpy)
            print(image_numpy.dtype)
            print(image_numpy.shape)
            plt.imshow(image_numpy)
            plt.show()

if __name__=="__main__":
    main()

"""
輸出結果為:

img 的類型是 <class 'tensorflow.python.framework.ops.Tensor'>
[[[196 219 209]
  [196 219 209]
  [196 219 209]
  ...

 [[ 71 106  42]
  [ 59  89  39]
  [ 34  63  19]
  ...
  [ 21  52  46]
  [ 15  45  43]
  [ 22  50  53]]]
uint8
(675, 1200, 3)
"""

   和tf.read_file用法相似的函數還有tf.gfile.FastGFile  tf.gfile.GFile,只是要指定讀取方式是’r’ 還是’rb’ 。

2、需要讀取大量圖像用於訓練

  這種情況就需要使用Tensorflow隊列機制。首先是獲得每張圖片的路徑,把他們都放進一個list裏面,然後用string_input_producer創建隊列,再用tf.WholeFileReader讀取。具體請看下例:

def get_image_batch(data_file,batch_size):
    data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
 
    #這個num_epochs函數在整個Graph是local Variable,所以在sess.run全局變量的時候也要加上局部變量。  
    filenames_queue=tf.train.string_input_producer(data_names,num_epochs=50,shuffle=True,capacity=512)
    reader=tf.WholeFileReader()
    _,img_bytes=reader.read(filenames_queue)
    image=tf.image.decode_png(img_bytes,channels=1)    #讀取的是什麼格式,就decode什麼格式
    #解碼成單通道的,並且獲得的結果的shape是[?, ?,1],也就是Graph不知道圖像的大小,需要set_shape
    image.set_shape([180,180,1])   #set到原本已知圖像的大小。或者直接通過tf.image.resize_images,tf.reshape()
    image=tf.image.convert_image_dtype(image,tf.float32)
    #預處理  下面的一句代碼可以換成自己想使用的預處理方式
    #image=tf.divide(image,255.0)   
    return tf.train.batch([image],batch_size) 

  這裏的date_file是指文件夾所在的路徑,不包括文件名。第一句是遍歷指定目錄下的文件名稱,存放到一個list中。當然這個做法有很多種方法,比如glob.glob,或者tf.train.match_filename_once

全部代碼如下:

import tensorflow as tf
import os
def read_image(data_file,batch_size):
    data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
    filenames_queue=tf.train.string_input_producer(data_names,num_epochs=5,shuffle=True,capacity=30)
    reader=tf.WholeFileReader()
    _,img_bytes=reader.read(filenames_queue)
    image=tf.image.decode_jpeg(img_bytes,channels=1)
    image=tf.image.resize_images(image,(180,180))

    image=tf.image.convert_image_dtype(image,tf.float32)
    return tf.train.batch([image],batch_size)

def main( ):
    img_path=r'F:\dataSet\WIDER\WIDER_train\images\6--Funeral'  #本地的一個數據集目錄,有足夠的圖像
    img=read_image(img_path,batch_size=10)
    image=img[0]  #取出每個batch的第一個數據
    print(image)
    init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)
        try:
            while not coord.should_stop():
                print(image.shape)
        except tf.errors.OutOfRangeError:
            print('read done')
        finally:
            coord.request_stop()
        coord.join(threads)


if __name__=="__main__":
    main()

"""
輸出如下:
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
(180, 180, 1)
"""

  這段代碼可以說寫的很是規整了。注意到init裏面有對local變量的初始化,並且因為用到了隊列,當然要告訴電腦什麼時候隊列開始, tf.train.Coordinator 和 tf.train.start_queue_runners 就是兩個管理隊列的類,用法如程序所示。

  與 tf.train.string_input_producer相似的函數是 tf.train.slice_input_producer。 tf.train.slice_input_producer和tf.train.string_input_producer的第一個參數形式不一樣。等有時間再做一個二者比較的博客

 3、對TFRecorder解碼獲得圖像數據

  其實這塊和上一種方式差不多的,更重要的是怎麼生成TFRecorder文件,這一部分我會補充到另一篇博客上。

  仍然使用 tf.train.string_input_producer。

import tensorflow as tf
import matplotlib.pyplot as plt
import os
import cv2
import  numpy as np
import glob

def read_image(data_file,batch_size):
    files_path=glob.glob(data_file)
    queue=tf.train.string_input_producer(files_path,num_epochs=None)
    reader = tf.TFRecordReader()
    print(queue)
    _, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.cast(image, tf.float32)
    image.set_shape((12*12*3))
    label = tf.decode_raw(features['label_raw'], tf.float32)
    label.set_shape((2))
    # 預處理部分省略,大家可以自己根據需要添加
    return tf.train.batch([image,label],batch_size=batch_size,num_threads=4,capacity=5*batch_size)

def main( ):
    img_path=r'F:\python\MTCNN_by_myself\prepare_data\pnet*.tfrecords'  #本地的幾個tf文件
    img,label=read_image(img_path,batch_size=10)
    image=img[0]
    init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)
        try:
            while not coord.should_stop():
                print(image.shape)
        except tf.errors.OutOfRangeError:
            print('read done')
        finally:
            coord.request_stop()
        coord.join(threads)


if __name__=="__main__":
    main()

  在read_image函數中,先使用glob函數獲得了存放tfrecord文件的列表,然後根據TFRecord文件是如何存的就如何parse,再set_shape;這裡有必要提醒下parse的方式。我們看到這裏用的是tf.decode_raw ,因為做TFRecord是將圖像數據string化了,數據是串行的,丟失了空間結果。從features中取出image和label的數據,這時就要用 tf.decode_raw  解碼,得到的結果當然也是串行的了,所以set_shape 成一個串行的,再reshape。這種方式是取決於你的編碼TFRecord方式的。

再舉一種例子:

reader=tf.TFRecordReader()
_,serialized_example=reader.read(file_name_queue)
features = tf.parse_single_example(serialized_example, features={
    'data': tf.FixedLenFeature([256,256], tf.float32), ###
    'label': tf.FixedLenFeature([], tf.int64),
    'id': tf.FixedLenFeature([], tf.int64)
})
img = features['data']
label =features['label']
id = features['id']

  這個時候就不需要任何解碼了。因為做TFRecord的方式就是直接把圖像數據append進去了。

參考鏈接:

  https://blog.csdn.net/qq_34914551/article/details/86286184

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理

【其他文章推薦】

※自行創業缺乏曝光? 網頁設計幫您第一時間規劃公司的形象門面

網頁設計一頭霧水該從何著手呢? 台北網頁設計公司幫您輕鬆架站!

※想知道最厲害的網頁設計公司"嚨底家"!

※幫你省時又省力,新北清潔一流服務好口碑

※別再煩惱如何寫文案,掌握八大原則!

※產品缺大量曝光嗎?你需要的是一流包裝設計!