# utils.py — EMNIST/MNIST dataset helpers and a random mini-batch iterator.
import numpy as np
from chainer.dataset import download
import os
import gzip
import numpy
import struct
import six
import chainer
from chainer.datasets import TupleDataset
import chainer.datasets.mnist as mnist
def get_emnist(withlabel=True, ndim=1, scale=1., dtype=numpy.float32,
               label_dtype=numpy.int32, rgb_format=False):
    """Fetch the EMNIST-letters train/test pair, preprocessed like MNIST.

    All keyword arguments are forwarded to chainer's MNIST preprocessing
    helper, so the returned datasets have the same layout as those from
    ``chainer.datasets.get_mnist``.

    Returns:
        tuple: ``(train, test)`` preprocessed datasets.
    """
    def _prep(raw):
        # Reuse chainer's MNIST pipeline: scaling, reshaping, label typing.
        return mnist._preprocess_mnist(raw, withlabel, ndim, scale, dtype,
                                       label_dtype, rgb_format)

    return (_prep(_retrieve_emnist_training()),
            _prep(_retrieve_emnist_test()))
def _retrieve_emnist_training():
    """Load (building and caching on first use) the EMNIST-letters training split."""
    return _retrieve_emnist(
        'em_train.npz',
        ['bin/emnist-letters-train-images-idx3-ubyte.gz',
         'bin/emnist-letters-train-labels-idx1-ubyte.gz'])
def _retrieve_emnist_test():
    """Load (building and caching on first use) the EMNIST-letters test split."""
    return _retrieve_emnist(
        'em_test.npz',
        ['bin/emnist-letters-test-images-idx3-ubyte.gz',
         'bin/emnist-letters-test-labels-idx1-ubyte.gz'])
def _retrieve_emnist(name, archives):
    """Return the cached npz named *name*, creating it from *archives* on a miss.

    Args:
        name (str): File name of the npz cache inside the chainer dataset dir.
        archives (sequence of str): Gzipped idx archive paths passed through
            to ``_make_npz`` when the cache has to be built.
    """
    cache_root = download.get_dataset_directory('pfnet/chainer/emnist')
    cache_path = os.path.join(cache_root, name)

    def _creator(p):
        return _make_npz(p, archives)

    return download.cache_or_load_file(cache_path, _creator, numpy.load)
def _make_npz(path,archives):
x_url, y_url = archives
with gzip.open(x_url, 'rb') as fx, gzip.open(y_url, 'rb') as fy:
fx.read(4)
fy.read(4)
N, = struct.unpack('>i', fx.read(4))
if N != struct.unpack('>i', fy.read(4))[0]:
raise RuntimeError('wrong pair of EMNIST images and labels')
fx.read(8)
x = numpy.empty((N, 784), dtype=numpy.uint8)
y = numpy.empty(N, dtype=numpy.uint8)
for i in six.moves.range(N):
y[i] = ord(fy.read(1))
for j in six.moves.range(784):
x[i, j] = ord(fx.read(1))
numpy.savez_compressed(path, x=x, y=y)
return {'x': x, 'y': y}
def get_mnist(n_train=100, n_test=100, n_dim=1, with_label=True, classes = None):
    """
    Return class-balanced MNIST subsets, optionally restricted to some classes.

    :param n_train: nr of training examples per class
    :param n_test: nr of test examples per class
    :param n_dim: 1 or 3 (for convolutional input)
    :param with_label: whether or not to also provide labels
    :param classes: if not None, then it selects only those classes, e.g. [0, 1]
    :return: (train, test) — TupleDatasets of (data, remapped labels) when
        with_label, otherwise bare data arrays
    """
    train_data, test_data = chainer.datasets.get_mnist(ndim=n_dim, withlabel=with_label)
    # `is None` (not truthiness): the docstring promises "if not None", and a
    # truthiness test raises on a numpy array and silently ignores [] .
    if classes is None:
        classes = np.arange(10)

    def _select(labels, n):
        # Indices of up to n examples per requested class, plus labels
        # remapped to 0..len(classes)-1. Label chunks are sized by what was
        # actually found; the original hard-coded n per class, which produced
        # a data/label length mismatch when a class had fewer than n examples.
        idx_parts, label_parts = [], []
        for new_label, c in enumerate(classes):
            lidx = np.where(labels == c)[0][:n]
            idx_parts.append(lidx)
            label_parts.append(np.full(len(lidx), new_label, dtype='int32'))
        return np.concatenate(idx_parts), np.concatenate(label_parts)

    if with_label:
        tr_idx, tr_lab = _select(train_data._datasets[1], n_train)
        te_idx, te_lab = _select(test_data._datasets[1], n_test)
        train_data = TupleDataset(train_data._datasets[0][tr_idx], tr_lab)
        test_data = TupleDataset(test_data._datasets[0][te_idx], te_lab)
    else:
        # Unlabeled request: fetch a labeled copy solely to locate the
        # per-class indices, then subset the unlabeled arrays.
        tmp1, tmp2 = chainer.datasets.get_mnist(ndim=n_dim, withlabel=True)
        tr_idx, _ = _select(tmp1._datasets[1], n_train)
        te_idx, _ = _select(tmp2._datasets[1], n_test)
        train_data = train_data[tr_idx]
        test_data = test_data[te_idx]
    return train_data, test_data
# Custom iterator
class RandomIterator(object):
    """
    Generates random subsets of data.

    Each full pass (``__iter__``) draws a fresh random permutation and yields
    ``len(data) // batch_size`` batches; trailing examples that do not fill a
    complete batch are dropped for that pass.
    """
    def __init__(self, data, batch_size=1):
        """
        Args:
            data (TupleDataset or numpy.ndarray): dataset to sample from
            batch_size (int): number of examples per batch
        Returns:
            iterator over batches consisting of (input, output) pairs
        """
        self.data = data
        self.idx = 0
        self.batch_size = batch_size
        self.n_batches = len(self.data) // batch_size
        # Draw an initial order so next() is safe even before __iter__ is
        # called (the original raised AttributeError on self._order).
        self._order = np.random.permutation(
            len(self.data))[:(self.n_batches * self.batch_size)]

    def __iter__(self):
        self.idx = -1
        self._order = np.random.permutation(
            len(self.data))[:(self.n_batches * self.batch_size)]
        return self

    def __next__(self):
        # The original only defined next(), which Python 3's iterator
        # protocol never calls — `for batch in it` failed under Python 3.
        self.idx += 1
        if self.idx == self.n_batches:
            raise StopIteration
        i = self.idx * self.batch_size
        selection = self._order[i:(i + self.batch_size)]
        # handles unlabeled (ndarray) and labeled (TupleDataset) data
        if isinstance(self.data, np.ndarray):
            return self.data[selection]
        else:
            return list(self.data[selection])

    # Preserve the Python 2 spelling for any caller using it.next() directly.
    next = __next__