tensorflow学习,string_input_producer和FixedLengthRecordReader的使用,将CIFAR-10数据集转成jpg图片(15)

上一节简要介绍了 CIFAR-10 数据集,此外,我们讨论了常规读大数据到内存方法的一些问题,并且介绍了 tensorflow 是如何解决的。本节,我们将利用 tensorflow 将 CIFAR-10 数据集导入内存,并且保存为 jpg 图片。

CIFAR-10 数据集的数据格式详解


CIFAR-10 数据集一共有 10 类物体(上节已介绍),每类 6000 张图片,一共 60000 张。其中 50000 张是训练集,10000 张是测试集。

数据集被分为 5 个训练 batch,和 1 个测试 batch,每个 batch 10000 张图片。测试 batch 是从每一类里随机挑选的 1000 张图片组成的,顺序是被打乱的。

CIFAR-10 数据集解压后,有如下文件:

打开官网后,CIFAR-10 数据集有 3 种格式

计划使用 Binary 版的,它的每一个 batch 的数据按照如下排列:

<1 x label><3072 x pixel>
...
<1 x label><3072 x pixel>

第一字节是第一张图片的标签值,数值范围是 0-9,表示 10 类。接下来的 3072 字节就是图片的像素值了。这样的每一个 3072 字节的都是 3 通道的 RGB 图片,前 1024 字节是 R 数据,中间的 1024 字节是 G 数据,最后 1024 字节是 B 数据。数据是按照行排列的,所以前 32 字节是图片第一行的 R 通道的数据。

python tensorflow 实战代码,将 CIFAR-10 转成 JPG


我们在工作目录新建文件夹 cifar10_data, 将 CIFAR-10 数据集解压到此,然后就可以创建文件名 list,然后使用 tf.train.string_input_producer 方法将其转换为 文件名队列(含义可参照上一节)。

# 创建文件名 list
for i in range(1,6):
    filenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'data_batch_%d.bin' % i)]
# 使用 tensorflow 将文件名 list 转成队列(queue)
filename_queue = tf.train.string_input_producer(filenames)

将 CIFAR-10 数据集的数据格式填入变量中,方便我们接下来的使用。

# 标签占一个字节
label_bytes = 1
# 图片尺寸 32x32x3
height = 32
width = 32
depth = 3   
# 一张图片字节数
image_bytes = height * width * depth
# 一帧数据包含一字节标签和 image_bytes 图片
record_bytes = label_bytes + image_bytes

因为每一张图片在 batch 中占有的都是固定长度,所以我们使用 tensorflow 固定长度的 reader

# 创建固定长度的数据 reader
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)

然后就可以利用 reader 取数据了,按照上一部分介绍的数据格式转换成我们需要的数据。

# 读出的 value 是 string,现在转换为 uint8 型的向量
record_bytes = tf.decode_raw(value, tf.uint8)
# 第一字节表示 标签值,我们把它从 uint8 型转成 int32 型
label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
# 剩下的就是图片的数据了,我们把它的形状由 [深度*高(长)*宽] 转换成 [深度,高(长),宽] 
depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
                        [depth, height, width])
# 将图片从 [深度,高(长),宽] 转成 [高(长),宽, 深度] 形状
uint8image = tf.transpose(depth_major, [1, 2, 0])
reshaped_image = tf.cast(uint8image, tf.float32)

这样我们就得到了 float32 数据类型的图片。接下来我们只需要创建一个 session,初始化变量后,就可以通过 run 方法,从 reshaped_image 读取图片了。每一次 run(reshaped_image) 都可以读取一张图片。

with tf.Session() as sess: 
    threads = tf.train.start_queue_runners(sess=sess)
    sess.run(tf.global_variables_initializer())
    for i in range(10):     # 这里仅提取 10 张图片作为示范
        image_array = sess.run(reshaped_image)
        scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)

我们将图片保存到工作目录的 cifar10_data/raw/ 文件夹里。

全部代码如下:

#coding: utf-8
import tensorflow as tf
import os
import scipy.misc

# 创建文件名 list
for i in range(1,6):
    filenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'data_batch_%d.bin' % i)]

# 使用 tensorflow 将文件名 list 转成队列(queue)
filename_queue = tf.train.string_input_producer(filenames)

# 标签占一个字节
label_bytes = 1
# 图片尺寸 32x32x3
height = 32
width = 32
depth = 3   

# 一张图片字节数
image_bytes = height * width * depth

# 一帧数据包含一字节标签和 image_bytes 图片
record_bytes = label_bytes + image_bytes

# 创建固定长度的数据 reader
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)

# 读出的 value 是 string,现在转换为 uint8 型的向量
record_bytes = tf.decode_raw(value, tf.uint8)

# 第一字节表示 标签值,我们把它从 uint8 型转成 int32 型
label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

# 剩下的就是图片的数据了,我们把它的形状由 [深度*高(长)*宽] 转换成 [深度,高(长),宽] 
depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
                        [depth, height, width])

# 将图片从 [深度,高(长),宽] 转成 [高(长),宽, 深度] 形状
uint8image = tf.transpose(depth_major, [1, 2, 0])

reshaped_image = tf.cast(uint8image, tf.float32)
with tf.Session() as sess: 
    threads = tf.train.start_queue_runners(sess=sess)
    sess.run(tf.global_variables_initializer())
    for i in range(10):
        image_array = sess.run(reshaped_image)
        scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)

运行脚本,发现成功了

阅读更多:   tensorflow
暂无评论
  1. […] 这部分的代码上一节说的非常清楚,其实这里就是将上一节的代码封装成一个函数,方便后面的复用。 […]

添加新评论

icon_redface.gificon_idea.gificon_cool.gif2016kuk.gificon_mrgreen.gif2016shuai.gif2016tp.gif2016db.gif2016ch.gificon_razz.gif2016zj.gificon_sad.gificon_cry.gif2016zhh.gificon_question.gif2016jk.gif2016bs.gificon_lol.gif2016qiao.gificon_surprised.gif2016fendou.gif2016ll.gif