最近学习 tensroflow,用到了入门级的经典数据集 MNIST,MNIST 包含几万张 28x28 像素大小的手写数字。但是它的存储是以字节流的形式存储的,几万张图片存储在一个文件里。一直对其很好奇,本节即用 python 的 struct 模块处理字节流信息,结合 python 的 Image 模块,将 MNIST 中的手写数字图片提取出来。
MNIST 图片集的格式
要想从 MNIST 中提取出图片数据,首先要了解它的格式。这点其实官网介绍的非常清楚。这里把介绍图片格式的部分截取下来,因为测试集的数据量稍小(格式与训练集相同),所以我们以测试集为例:
可以看出
- 起始 32bit 是数据集的魔法数,校验用的,可以不管。
- 4 字节偏移处是图片的总数目 10000。
- 8 字节偏移处是图片的行,为 28。
- 12 字节偏移处是图片的列,也为 28。
- 从 16 字节开始,每一字节都是像素值。
像素值按行排列,从 0-255,0 表示白色,255 表示黑色,中间值是灰色。所以第一张图片的数据,是从数据集的第 16 字节开始的,到 16+28x28 结束。
python struct 模块
python 的 struct 模块主要用于处理字节流信息,最重要的三个函数是pack(), unpack(), calcsize()。
- pack(fmt, v1, v2, ...) 把数据封装成字符串
- unpack(fmt, string) 按照 fmt 解析字节流string,返回解析出来的数据
- calcsize(fmt) 返回(fmt)占用多少字节数
fmt 可为:
使用方法是放在fmt的第一个位置,就像'@5s6sif'。图出自:嘎啦。
python 的 struct 模块弥补了 python 处理底层字节流信息不方便的不足。什么是字节流信息呢?比如 TCP 传来一段 C 结构体形式的数据:
struct Header{
unsigned short id;
char[4] tag;
unsigned int version;
unsigned int count;
}
用 python 程序接收,接收到的其实就是一段二进制数字,python 没有办法很方便的使用 c 的结构体来解析这些数据,这时 struct 模块就派上用场了。
import struct
id, tag, version, count = struct.unpack("!H4s2I", s)
# !表示网络字节顺序,H 表示unsigned short的id,4s表示4字节长的字符串,2I 表示有两个# unsigned int类型的数据.
使用 python 的 struct 模块,提取出 MNIST 中的图片,并保存为 bmp 格式
这里还是以数据量较小的测试集为例。首先,我们把数据集读入内存:
filename = 'MNIST_data/t10k-images.idx3-ubyte'
fd = open(filename , 'rb')
buf = fd.read()
fd.close()
然后,读出数据集的头信息,
index = 0
magic, numImages , numRows , numColumns = struct.unpack_from('>IIII' , buf , index)
index += struct.calcsize('>IIII') # 计算下一次读的偏移
接着,就可以读出 28x28=784 字节的图片数据了。
im = struct.unpack_from('>784B' ,buf, index)
index += struct.calcsize('>784B')
得到数据后,使用 python 的 Image 模块,就可以很轻松的把数据保存为 bmp 图片了。
im = np.array(im,dtype='uint8')
im = im.reshape(28,28)
im=Image.fromarray(im)
im.save('images/1.bmp') # 这里保存到 image 文件夹里,名字暂取 1.bmp
运行之,在 images 文件夹里得到了 1.bmp 文件,打开,发现是期望的手写数字。
接下来,加个 for 循环,就可以把全部图片提取出来了,完整代码如下:
#encoding=utf8
import numpy as np
import struct
import Image
filename = 'MNIST_data/t10k-images.idx3-ubyte'
fd = open(filename , 'rb')
buf = fd.read()
fd.close()
index = 0
magic, numImages , numRows , numColumns = struct.unpack_from('>IIII' , buf , index)
index += struct.calcsize('>IIII')
for i in range(0, numImages):
im = struct.unpack_from('>784B' ,buf, index)
index += struct.calcsize('>784B')
im = np.array(im,dtype='uint8')
im = im.reshape(28,28)
im=Image.fromarray(im)
im.save('images/%d.bmp' % i)
运行之,发现 10000 张测试集的图片已经全部被我们提取并保存为 bmp 图片了。