# iteration.py (forked from davidtob/timitdataset)
import numpy
import sys
from theano import config
from pylearn2.space import CompositeSpace
from pylearn2.utils import safe_zip
from pylearn2.utils.data_specs import is_flat_specs
from pylearn2.utils.iteration import *


class FiniteDatasetIterator(object):
    """
    A thin wrapper around one of the mode iterators.
    """

    def __init__(self, dataset, subset_iterator,
                 data_specs=None, return_tuple=False, convert=None):
"""
.. todo::
WRITEME
"""
        self._data_specs = data_specs
        self._dataset = dataset
        self._subset_iterator = subset_iterator
        self._return_tuple = return_tuple

        assert is_flat_specs(data_specs)

        dataset_space, dataset_source = self._dataset.get_data_specs()
        assert is_flat_specs((dataset_space, dataset_source))
        # the dataset's data spec is either a single (space, source) pair,
        # or a pair of (non-nested CompositeSpace, non-nested tuple).
        # We could build a mapping and call flatten(..., return_tuple=True)
        # but simply putting spaces, sources and data in tuples is simpler.
        if not isinstance(dataset_source, tuple):
            dataset_source = (dataset_source,)
        if not isinstance(dataset_space, CompositeSpace):
            dataset_sub_spaces = (dataset_space,)
        else:
            dataset_sub_spaces = dataset_space.components
        assert len(dataset_source) == len(dataset_sub_spaces)

        space, source = data_specs
        if not isinstance(source, tuple):
            source = (source,)
        if not isinstance(space, CompositeSpace):
            sub_spaces = (space,)
        else:
            sub_spaces = space.components
        assert len(source) == len(sub_spaces)

        self._source = source

        if convert is None:
            self._convert = [None for s in source]
        else:
            assert len(convert) == len(source)
            self._convert = convert

        for i, (so, sp) in enumerate(safe_zip(source, sub_spaces)):
            idx = dataset_source.index(so)
            dspace = dataset_sub_spaces[idx]

            init_fn = self._convert[i]
            fn = init_fn

            # If there is an init_fn, it is supposed to take
            # care of the formatting, and it should be an error
            # if it does not. If there was no init_fn, then
            # the iterator will try to format using the generic
            # space-formatting functions.
            if init_fn is None:
                # "dspace" and "sp" have to be passed as parameters
                # to lambda, in order to capture their current value,
                # otherwise they would change in the next iteration
                # of the loop.
                if fn is None:
                    fn = (lambda batch, dspace=dspace, sp=sp:
                          dspace.np_format_as(batch, sp))
                else:
                    fn = (lambda batch, dspace=dspace, sp=sp, fn_=fn:
                          dspace.np_format_as(fn_(batch), sp))

            self._convert[i] = fn

    def __iter__(self):
        """
        The iterator protocol: returns itself.
        """
        return self

    def next(self):
        """
        Fetches the next batch (or tuple of batches) from the dataset.
        """
        next_index = self._subset_iterator.next()
        # TODO: handle fancy-index copies by allocating a buffer and
        # using numpy.take()
        rval = tuple(
            fn(batch) if fn else batch for batch, fn in
            safe_zip(self._dataset.get(self._source, next_index),
                     self._convert)
        )
        if not self._return_tuple and len(rval) == 1:
            rval, = rval
        return rval

    @property
    def batch_size(self):
        """
        The batch size requested from the underlying subset iterator.
        """
        return self._subset_iterator.batch_size

    @property
    def num_batches(self):
        """
        The number of batches the subset iterator will produce.
        """
        return self._subset_iterator.num_batches

    @property
    def num_examples(self):
        """
        The total number of examples visited by the iterator.
        """
        return self._subset_iterator.num_examples

    @property
    def uneven(self):
        """
        Whether the last batch may contain fewer examples than the rest.
        """
        return self._subset_iterator.uneven

    @property
    def stochastic(self):
        """
        Whether batches are drawn in a random (stochastic) order.
        """
        return self._subset_iterator.stochastic
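

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, illustrative
# example of driving FiniteDatasetIterator by hand. `_ToyDataset` is a
# hypothetical stand-in for a dataset exposing the two methods the iterator
# relies on, get_data_specs() and get(source, indexes); SequentialSubsetIterator
# is defined in pylearn2.utils.iteration and is available here through the
# wildcard import above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pylearn2.space import VectorSpace

    class _ToyDataset(object):
        """Hypothetical minimal dataset used only for this demonstration."""

        def __init__(self, X):
            self.X = X
            self.space = VectorSpace(dim=X.shape[1])

        def get_data_specs(self):
            # A single (space, source) pair, i.e. already flat.
            return (self.space, 'features')

        def get(self, source, indexes):
            # `source` is the tuple of requested source names and `indexes`
            # is the slice produced by the subset iterator; return one batch
            # per requested source.
            return tuple(self.X[indexes] for _ in source)

    X = numpy.random.rand(100, 10).astype(config.floatX)
    dataset = _ToyDataset(X)

    # Request batches of 20 'features' examples, formatted as VectorSpace(10).
    subset_it = SequentialSubsetIterator(100, 20, None)
    it = FiniteDatasetIterator(dataset, subset_it,
                               data_specs=(VectorSpace(dim=10), 'features'))
    for batch in it:
        print batch.shape  # (20, 10) for each of the five batches

    # With a per-source convert function, the provided callable takes over
    # formatting for that source (see the comments in __init__).
    it2 = FiniteDatasetIterator(dataset,
                                SequentialSubsetIterator(100, 20, None),
                                data_specs=(VectorSpace(dim=10), 'features'),
                                convert=[lambda b: b.mean(axis=1)])
    print it2.next().shape  # (20,)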