-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
320 lines (270 loc) · 10.3 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""
Script/function for converting hdf5 b/p data to tensorflow tfrecords.
Creates a tfrecords file for each category in shapenet_selected.
"""
import os
import tensorflow as tf
import numpy as np
from ffd import get_data_dir
_string_dtypes = {np.uint8, np.bool}
_int_dtypes = {np.uint8, np.int8, np.int16, np.int32, np.int64}
_float_dtypes = {np.float16, np.float32, np.float64}
_dtypes = _int_dtypes.union(_float_dtypes).union(_string_dtypes)
# Root directory under which the tfrecords files are written; resolved by the
# project-local `ffd.get_data_dir` helper (location not visible from here).
tfrecords_dir = get_data_dir('tfrecords')
class LengthedGenerator(object):
    """Wrap an iterable so `len` is available without consuming it.

    Useful for passing generators to consumers (e.g. progress bars) that
    want a known element count up front.
    """

    def __init__(self, gen, gen_len):
        # The wrapped iterable and its externally-supplied length.
        self._iterable = gen
        self._length = gen_len

    def __iter__(self):
        # Delegate iteration to the wrapped iterable.
        return iter(self._iterable)

    def __len__(self):
        return self._length
class ProxyBar(object):
    """No-op stand-in for a progress bar.

    Lets callers drive the progress-bar API unconditionally when no real
    bar is available:
    ```
    bar = ProxyBar() if n is None else IncrementalBar(max=n)
    bar.next()
    bar.next()
    bar.finish()
    ```
    """

    def next(self):
        """Ignore a progress tick."""
        pass

    def finish(self):
        """Ignore completion."""
        pass
class FeatureSpec(object):
    """Specification for a tfrecords feature.

    Attributes:
        key: string name of the feature in the example proto.
        shape: list/tuple of ints; -1 marks a variable-length dimension.
        dtype: numpy dtype class, restricted to the module-level `_dtypes`.
    """

    def __init__(self, key, shape, dtype):
        """
        Args:
            key: feature name; must be a string.
            shape: list/tuple of ints (-1 for variable-length dims).
            dtype: numpy dtype class, or a tf dtype (converted via its
                `as_numpy_dtype` attribute).

        Raises:
            ValueError: if key, shape or dtype fail validation.
        """
        if not isinstance(shape, (list, tuple)) or not all(
                isinstance(s, int) for s in shape):
            raise ValueError('shape must be a list/tuple of ints')
        # Accept tf dtypes by converting them to their numpy equivalent.
        if hasattr(dtype, 'as_numpy_dtype'):
            dtype = dtype.as_numpy_dtype
        # Fix: `unicode` only exists on Python 2 and raised NameError on
        # Python 3; `str` is the sole text type there.
        if not isinstance(key, str):
            raise ValueError('key must be a string')
        self.key = key
        self.shape = shape
        if dtype not in _dtypes:
            raise ValueError('dtype %s not in allowable dtypes' % str(dtype))
        self.dtype = dtype
def _bytes_feature(value):
    """Wrap raw bytes (or an ndarray's buffer) as a tf bytes Feature.

    Args:
        value: a bytes object, or an np.ndarray whose raw buffer is taken.

    Returns:
        tf.train.Feature holding a single-element BytesList.
    """
    if isinstance(value, np.ndarray):
        # Fix: `tostring` is deprecated and removed in NumPy 2.0;
        # `tobytes` is its byte-identical replacement.
        value = value.tobytes()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
    """Wrap a python int or an ndarray as a tf Int64List Feature.

    Multi-dimensional arrays are flattened before conversion.
    """
    if isinstance(value, int):
        value = [value]
    elif isinstance(value, np.ndarray):
        # Flatten anything that isn't already rank-1, then convert each
        # element to a python int as required by Int64List.
        flat = value if len(value.shape) == 1 else np.reshape(value, (-1,))
        value = [int(v) for v in flat]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _float32_feature(value):
    """Wrap a python float or an ndarray as a tf FloatList Feature.

    Multi-dimensional arrays are flattened before conversion.
    """
    if isinstance(value, float):
        value = [value]
    elif isinstance(value, np.ndarray):
        # Flatten anything that isn't already rank-1, then convert each
        # element to a python float as required by FloatList.
        flat = value if len(value.shape) == 1 else np.reshape(value, (-1,))
        value = [float(v) for v in flat]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _get_proto_dtype(dtype):
    """Map a numpy dtype to the tf dtype used in the serialized proto.

    NOTE: `np.uint8` belongs to both `_string_dtypes` and `_int_dtypes`;
    the string mapping deliberately wins, matching the original check
    order.

    Raises:
        RuntimeError: for dtypes outside the supported sets.
    """
    for dtype_set, proto_dtype in (
            (_string_dtypes, tf.string),
            (_int_dtypes, tf.int64),
            (_float_dtypes, tf.float32)):
        if dtype in dtype_set:
            return proto_dtype
    raise RuntimeError('spec dtype not supported')
def get_parse_feature(spec):
    """Build the tf parsing feature (VarLen or FixedLen) for `spec`."""
    proto_dtype = _get_proto_dtype(spec.dtype)
    # Any -1 dimension means the serialized length is not fixed.
    if -1 in spec.shape:
        return tf.VarLenFeature(proto_dtype)
    # Byte-string-encoded dtypes are stored as one scalar blob, so the
    # parsed shape is [] regardless of the logical shape.
    fixed_shape = [] if spec.dtype in _string_dtypes else spec.shape
    return tf.FixedLenFeature(shape=fixed_shape, dtype=proto_dtype)
def _feature_specs(feature_specs):
    """Normalize input to a sized collection of `FeatureSpec`s.

    A lone `FeatureSpec` becomes a 1-tuple; an iterable is materialized
    (if needed) and validated.

    Raises:
        ValueError: if the argument is neither a `FeatureSpec` nor an
            iterable of `FeatureSpec`s.
    """
    if isinstance(feature_specs, FeatureSpec):
        return (feature_specs,)
    if not hasattr(feature_specs, '__iter__'):
        raise ValueError(
            'feature_specs must be a `FeatureSpec` or an iterable of '
            '`FeatureSpec`s')
    # Materialize one-shot iterables so they can be sized and re-iterated.
    if not hasattr(feature_specs, '__len__'):
        feature_specs = tuple(feature_specs)
    if not all(isinstance(spec, FeatureSpec) for spec in feature_specs):
        raise ValueError(
            'feature_specs must be a `FeatureSpec` or an iterable of '
            '`FeatureSpec`s')
    return feature_specs
def _check_shape(instance_shape, spec_shape):
if len(instance_shape) != len(spec_shape):
raise ValueError(
'Shapes inconsistent: %s, %s'
% (str(instance_shape), str(spec_shape)))
for i, s in zip(instance_shape, spec_shape):
if s != -1 and i != s:
raise ValueError(
'Shapes inconsistent: %s, %s'
% (str(instance_shape), str(spec_shape)))
def _same_shape(instance_shape, spec_shape):
return len(instance_shape) == len(spec_shape) and all(
[i == s or i == -1 and s is None
for i, s in zip(instance_shape, spec_shape)])
def write_records(path, feature_specs, examples_fn):
    """
    Write data generated by examples_fn to a tfrecords file at path.

    The destination directory is created if necessary. If writing fails
    part-way, the partially-written file is removed before the exception
    propagates.

    Args:
        path: path to save file to.
        feature_specs: a `FeatureSpec` or iterable of `FeatureSpec`s.
        examples_fn: function producing an iterable of examples, where each
            example should be an iterable of ndarrays, or a single ndarray
            when exactly one feature spec is given.
    """
    from progress.bar import IncrementalBar
    feature_specs = _feature_specs(feature_specs)
    examples = examples_fn()
    # Only show a real progress bar when the example count is known.
    if hasattr(examples, '__len__'):
        bar = IncrementalBar(max=len(examples))
    else:
        bar = ProxyBar()
    # Select the serialization function matching each spec's dtype.
    # Check order matters: uint8 is in both the string and int sets and is
    # serialized as bytes.
    feature_fns = []
    keys = []
    for spec in feature_specs:
        if spec.dtype in _string_dtypes:
            feature_fns.append(_bytes_feature)
        elif spec.dtype in _float_dtypes:
            feature_fns.append(_float32_feature)
        elif spec.dtype in _int_dtypes:
            feature_fns.append(_int64_feature)
        else:
            raise RuntimeError('Invalid dtype: %s' % str(spec.dtype))
        keys.append(spec.key)
    try:
        folder = os.path.dirname(path)
        if not os.path.isdir(folder):
            os.makedirs(folder)
        with tf.python_io.TFRecordWriter(path) as writer:
            print('Creating tf_records: %s' % path)
            for example in examples:
                if isinstance(example, np.ndarray):
                    # A bare ndarray is only valid with a single spec.
                    # Fix: raise rather than `assert` so the check
                    # survives `python -O`.
                    if len(feature_specs) != 1:
                        raise ValueError(
                            'single ndarray example requires exactly one '
                            'feature spec')
                    example = example,
                # Fix: removed a dead `ex = tf.train.Example()` that was
                # immediately overwritten below.
                features = tf.train.Features(
                    feature={
                        k: f(e)
                        for k, f, e in zip(keys, feature_fns, example)
                    })
                ex = tf.train.Example(features=features)
                writer.write(ex.SerializeToString())
                bar.next()
        bar.finish()
    except Exception:
        # Remove the partial file so later runs don't mistake it for a
        # complete dataset, then re-raise.
        print('Writing dataset failed. Deleting...')
        if os.path.isfile(path):
            os.remove(path)
        raise
def load_dataset(paths, feature_specs):
    """
    Create a raw (unparsed) dataset from tfrecord file(s).

    Args:
        paths: string or iterable of strings of addresses of tfrecords
            files.
        feature_specs: `FeatureSpec` or iterable of `FeatureSpec`s. Should
            be the same as those used in `write_records` (order may vary).
            NOTE(review): not consumed here — apply
            `get_parse_fn(feature_specs)` to the returned dataset to parse.

    Returns:
        tf.data.TFRecordDataset over `paths`.

    Raises:
        ValueError: if any path does not point to an existing file.
    """
    # Fix: `unicode` only exists on Python 2 and raised NameError on
    # Python 3; `str` is the sole text type there.
    if isinstance(paths, str):
        paths = [paths]
    for path in paths:
        if not os.path.isfile(path):
            raise ValueError('No tfrecords file at path: %s' % path)
    return tf.data.TFRecordDataset(paths)
def get_parse_fn(feature_specs):
    """
    Get a `parse_fn` to map raw datasets to meaningful data.

    The returned function parses a single serialized example proto into a
    list of tensors, one per spec in `feature_specs` (same order).

    Example usage:
    ```
    dataset = get_dataset(path, feature_specs, example_fn)
    map_fn = get_parse_fn(feature_specs)
    dataset = dataset.map(map_fn, num_parallel_calls=8)
    ```
    """
    def parse_fn(example_proto):
        # Parsing spec per feature: VarLenFeature when the spec contains a
        # -1 dimension, FixedLenFeature otherwise (see get_parse_feature).
        features = {
            spec.key: get_parse_feature(spec) for spec in feature_specs}
        ret_features = []
        parsed_features = tf.parse_single_example(example_proto, features)
        for spec in feature_specs:
            feature = parsed_features[spec.key]
            if isinstance(feature, tf.SparseTensor):
                # VarLenFeature?
                # Variable-length features come back sparse; densify with a
                # dtype-appropriate padding value.
                default_value = np.nan if spec.dtype == np.float32 else 0
                feature = tf.sparse_tensor_to_dense(
                    feature, default_value=default_value)
            if spec.dtype in _string_dtypes:
                # Byte-string-encoded features (written via _bytes_feature)
                # are decoded back into their element dtype.
                feature = tf.decode_raw(feature, spec.dtype)
            # Protos only store string/int64/float32; cast back to the
            # spec's dtype when they differ.
            if feature.dtype.as_numpy_dtype != spec.dtype:
                feature = tf.cast(feature, spec.dtype)
            # Restore the logical shape (decoded strings are flat, and
            # dense shapes may not match the spec).
            if not _same_shape(feature.shape, spec.shape):
                feature = tf.reshape(feature, spec.shape)
            ret_features.append(feature)
        return ret_features
    return parse_fn
class DatasetManager(object):
    """Class interface for writing and reading tfrecords.

    Args:
        feature_specs: `FeatureSpec` or iterable of `FeatureSpec`s.
        path_fn: either a fixed path string, or a function mapping example
            args to the tfrecords path for those args.
        example_fn: function mapping the same args to an iterable of
            examples (see `write_records`).
    """

    def __init__(self, feature_specs, path_fn, example_fn):
        self._specs = feature_specs
        if isinstance(path_fn, str):
            # Fixed path: ignore whatever args callers pass. Fix: the
            # previous zero-argument lambda raised TypeError whenever a
            # caller supplied args, as write_records / write_dataset /
            # load_dataset all do via `self._path_fn(*args)`.
            self._path_fn = lambda *unused_args: path_fn
        else:
            self._path_fn = path_fn
        self._example_fn = example_fn

    @property
    def feature_specs(self):
        return self._specs

    def write_records(self, *args):
        """Write the tfrecords file for `args` (unconditionally)."""
        def example_fn():
            return self._example_fn(*args)
        write_records(
            self._path_fn(*args), self._specs, example_fn)

    def write_dataset(self, args_list, overwrite=False):
        """Ensure a tfrecords file exists for each args tuple.

        Args:
            args_list: iterable of argument tuples for path_fn/example_fn.
            overwrite: if True, rewrite files even when they exist.

        Returns:
            list of paths, one per entry of args_list.
        """
        paths = []
        for args in args_list:
            path = self._path_fn(*args)
            if not os.path.isfile(path) or overwrite:
                # Safe despite late-binding: write_records consumes
                # example_fn within this same iteration.
                def example_fn():
                    return self._example_fn(*args)
                write_records(path, self._specs, example_fn)
            paths.append(path)
        return paths

    def load_dataset(self, args_list):
        """Dataset over existing files only.

        Raises:
            OSError: if any required file is missing.
        """
        paths = []
        for args in args_list:
            path = self._path_fn(*args)
            if not os.path.isfile(path):
                raise OSError('No file at path %s' % path)
            paths.append(path)
        return tf.data.TFRecordDataset(paths)

    def get_dataset(self, args_list, overwrite=False):
        """Write any missing files, then return a dataset over all paths."""
        paths = self.write_dataset(args_list, overwrite=overwrite)
        return tf.data.TFRecordDataset(paths)

    def _parse(self, specs, example_proto):
        # Build and apply a parse function for the given specs.
        fn = get_parse_fn(specs)
        return fn(example_proto)

    def parse(self, example_proto):
        """Parse a serialized example according to this manager's specs.

        Returns a single tensor when constructed with a lone FeatureSpec,
        otherwise a list of tensors.
        """
        if isinstance(self._specs, FeatureSpec):
            ret, = self._parse([self._specs], example_proto)
        else:
            ret = self._parse(self._specs, example_proto)
        return ret