神经网络识别手写数字:数据集长什么样
介绍
当我们开始学习编程的时候,第一件事往往是学习打印Hello World!。而MNIST是一个入门级的计算机视觉数据集就是深度学习中的Hello World,MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:
![此处输入图片的描述][1]
它也包含每一张图片对应的标签,告诉我们这个是数字几。比如,上面这四张图片的标签分别是5,0,4,1。 数据集下载地址:[http://yann.lecun.com/exdb/mnist/][2] 数据集内容:
- train-images-idx3-ubyte 训练数据图像 (60,000)
- train-labels-idx1-ubyte 训练数据label
- t10k-images-idx3-ubyte 测试数据图像 (10,000)
- t10k-labels-idx1-ubyte 测试数据label
每一张图片包含28像素X28像素。我们可以用一个数字数组来表示这张图片:
![此处输入图片的描述][3]
我们的任务就是训练一个机器学习模型用于预测图片里面的数字。
问题
一切看起来很美好,但当我下载完数据集后一看,傻眼了:
![此处输入图片的描述][4]
说好的图像呢,说好的标签呢。原来,为了方便解析,数据集被保存为二进制的格式,并不是我们熟悉的图片格式和文本格式,因此我们不能直观的看到它到底长什么样。对于初学者来说,神经网络已经是一个很玄的东西了,现在连数据集都是个黑盒子,让人怎么入门呢。本着眼见为实的精神,我决定使用Python把数据集还原为原始的格式。
解决
主要使用struct读取二进制文件,然后保存为相应格式的文件。
使用到的包:
- struct
- numpy
- PIL
- os
- shutil
转换图片
1# 转换图片
2def read_image(file_name, output_name, num=-1):
3 if os.path.exists(output_name):
4 shutil.rmtree(output_name)
5 os.makedirs(output_name)
6 # 二进制的形式读入
7 filename = file_name
8 binfile = open(filename, 'rb')
9 buf = binfile.read()
10 # 大端法读入 4 个 unsigned int32
11 # struct 用法参见网站 http://www.cnblogs.com/gala/archive/2011/09/22/2184801.html
12 index = 0
13 magic, num_images, num_rows, num_columns = struct.unpack_from('>IIII', buf, index)
14 index += struct.calcsize('>IIII')
15 num_images = num_images if num == -1 else num
16 # 将每张图片按照格式存储到对应位置
17 for image in range(0, num_images):
18 im = struct.unpack_from('>784B', buf, index)
19 index += struct.calcsize('>784B')
20 # 这里注意 Image 对象的 dtype 是 uint8,需要转换
21 im = np.array(im, dtype='uint8')
22 im = im.reshape(28, 28)
23 im = Image.fromarray(im)
24 im.save('%s/%s_%s.bmp' % (output_name, output_name, image), 'bmp')
因为训练图片有60000张,于是加入一个参数指定转换多少张,并不一定要全部转换。
1read_image('data/train-images.idx3-ubyte', 'train', 10)
执行上面的代码,会在当前目录新建一个train目录,里面有前10张图片:
![此处输入图片的描述][5]
转换标签
1# 转换 label
2def read_label(filename, save_filename):
3 f = open(filename, 'rb')
4 index = 0
5 buf = f.read()
6 f.close()
7 magic, labels = struct.unpack_from('>II', buf, index)
8 index += struct.calcsize('>II')
9 label_arr = [0] * labels
10 for x in range(0, labels):
11 label_arr[x] = int(struct.unpack_from('>B', buf, index)[0])
12 index += struct.calcsize('>B')
13 save = open(save_filename, 'w')
14 save.write(','.join(map(lambda n: str(n), label_arr)))
15 save.write('\n')
16 save.close()
17 print('save labels success')
标签转换相对比较简单,只需读取二进制文件并转存为txt即可。
1read_label('data/train-labels.idx1-ubyte', 'train_labels.txt')
执行上面的代码,会在当前目录生成一个train_labels.txt,里面就是所有图片对应的数字:
![此处输入图片的描述][6]
看到图片和标签,可以直观的看到图片和数字是如何对应的,非常有助于理解后面机器学习的算法。
[全部代码][7]
1import struct
2import numpy as np
3from PIL import Image
4import os
5import shutil
6
7
8# 转换图片
9def read_image(file_name, output_name, num=-1):
10 if os.path.exists(output_name):
11 shutil.rmtree(output_name)
12 os.makedirs(output_name)
13 # 二进制的形式读入
14 filename = file_name
15 binfile = open(filename, 'rb')
16 buf = binfile.read()
17 # 大端法读入 4 个 unsigned int32
18 # struct 用法参见网站 http://www.cnblogs.com/gala/archive/2011/09/22/2184801.html
19 index = 0
20 magic, num_images, num_rows, num_columns = struct.unpack_from('>IIII', buf, index)
21 index += struct.calcsize('>IIII')
22 num_images = num_images if num == -1 else num
23 # 将每张图片按照格式存储到对应位置
24 for image in range(0, num_images):
25 im = struct.unpack_from('>784B', buf, index)
26 index += struct.calcsize('>784B')
27 # 这里注意 Image 对象的 dtype 是 uint8,需要转换
28 im = np.array(im, dtype='uint8')
29 im = im.reshape(28, 28)
30 im = Image.fromarray(im)
31 im.save('%s/%s_%s.bmp' % (output_name, output_name, image), 'bmp')
32
33
34# 转换 label
35def read_label(filename, save_filename):
36 f = open(filename, 'rb')
37 index = 0
38 buf = f.read()
39 f.close()
40 magic, labels = struct.unpack_from('>II', buf, index)
41 index += struct.calcsize('>II')
42 label_arr = [0] * labels
43 for x in range(0, labels):
44 label_arr[x] = int(struct.unpack_from('>B', buf, index)[0])
45 index += struct.calcsize('>B')
46 save = open(save_filename, 'w')
47 save.write(','.join(map(lambda n: str(n), label_arr)))
48 save.write('\n')
49 save.close()
50 print('save labels success')
51
52
53if __name__ == '__main__':
54 # read_image('data/t10k-images.idx3-ubyte', 'test', 10)
55 read_image('data/train-images.idx3-ubyte', 'train', 10)
56 # read_label('data/t10k-labels.idx1-ubyte', 'test_labels.txt')
57 read_label('data/train-labels.idx1-ubyte', 'train_labels.txt')
58 pass
59
60``
61
62
63 [1]: https://files.ciphermagic.cn/mnist1.png
64 [2]: http://yann.lecun.com/exdb/mnist/
65 [3]: https://files.ciphermagic.cn/mnist3.png
66 [4]: https://files.ciphermagic.cn/mnist4.png
67 [5]: https://files.ciphermagic.cn/mnist5.png
68 [6]: https://files.ciphermagic.cn/mnist6.png
69 [7]: https://github.com/ciphermagic/python-learn/blob/master/tensorflow_learn/mnist/transform.py
