forked from mpezeshki/ladder_network
-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
129 lines (102 loc) · 4.22 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import logging
import numpy
from fuel.datasets import MNIST
from fuel.schemes import ShuffledScheme
from fuel.streams import DataStream
from fuel.transformers import Transformer
from picklable_itertools import cycle, imap
from utils import AttributeDict
logger = logging.getLogger('datasets')
def make_datastream(dataset, indices, batch_size,
                    n_labeled=None, n_unlabeled=None,
                    balanced_classes=True, whiten=None, cnorm=None,
                    scheme=ShuffledScheme):
    """Build a CombinedDataStream pairing labeled and unlabeled batches.

    Parameters
    ----------
    dataset : fuel dataset with a 'targets' source.
    indices : 1-D integer array of example indices to draw from.
    batch_size : int, batch size handed to the iteration scheme.
    n_labeled : int, number of labeled examples; must be divisible by the
        number of classes so each class is equally represented.
    n_unlabeled : int or None, number of unlabeled examples (prefix of
        ``indices``); None keeps them all.
    balanced_classes, whiten, cnorm : accepted for interface compatibility.
        NOTE(review): they are not consulted here — labels are always
        balanced; whitening/contrast-norm are presumably applied elsewhere.
    scheme : iteration scheme class, default ShuffledScheme.

    Returns
    -------
    CombinedDataStream yielding dicts with labeled and unlabeled batches.

    Raises
    ------
    ValueError if ``n_labeled`` is not a multiple of the class count.
    """
    # Ensure each label is equally represented
    logger.info('Balancing %d labels...' % n_labeled)
    all_data = dataset.data_sources[dataset.sources.index('targets')]
    y = all_data.flatten()[indices]
    n_classes = y.max() + 1
    # Explicit exception instead of `assert`: asserts vanish under -O.
    if n_labeled % n_classes != 0:
        raise ValueError('n_labeled (%d) must be divisible by the number '
                         'of classes (%d)' % (n_labeled, n_classes))
    # BUG FIX: integer division. True division returns a float in Python 3
    # and slicing with a float raises TypeError.
    n_from_each_class = n_labeled // n_classes
    i_labeled = []
    for c in range(n_classes):
        # First n_from_each_class indices whose target equals class c.
        i = (indices[y == c])[:n_from_each_class]
        i_labeled += list(i)

    # Get unlabeled indices
    i_unlabeled = indices[:n_unlabeled]

    ds = CombinedDataStream(
        data_stream_labeled=MyTransformer(
            DataStream(dataset),
            iteration_scheme=scheme(i_labeled, batch_size)),
        data_stream_unlabeled=MyTransformer(
            DataStream(dataset),
            iteration_scheme=scheme(i_unlabeled, batch_size))
    )
    return ds
class MyTransformer(Transformer):
    """Serve the wrapped dataset as in-memory, flattened numpy arrays.

    On construction the entire dataset is read once; features are
    reshaped to (num_examples, -1) and targets to a 1-D vector.
    ``get_data`` then answers index requests directly from memory.
    """

    def __init__(self, data_stream, iteration_scheme, **kwargs):
        super(MyTransformer, self).__init__(data_stream,
                                            iteration_scheme=iteration_scheme,
                                            **kwargs)
        # Materialize every example up front with a single full-range read.
        full = data_stream.get_data(slice(data_stream.dataset.num_examples))
        n_examples = full[0].shape[0]
        # Features as a 2-D matrix, targets as a flat vector.
        self.data = [full[0].reshape(n_examples, -1), full[1].flatten()]

    def get_data(self, request=None):
        # Lazily index each stored source with the requested example indices.
        return (source[request] for source in self.data)
class CombinedDataStream(Transformer):
    """Merge a labeled and an unlabeled data stream into one stream.

    Each epoch batch is a dict combining the labeled sources
    ('features_labeled', 'targets_labeled') with the unlabeled source
    ('features_unlabeled').  The unlabeled stream determines the epoch
    length; the labeled stream is cycled to keep pace with it.
    """

    def __init__(self, data_stream_labeled, data_stream_unlabeled, **kwargs):
        # Deliberately calls the *grandparent* initializer: Transformer's own
        # __init__ expects a single wrapped data_stream, which this class
        # does not have (it wraps two).
        super(Transformer, self).__init__(**kwargs)
        self.ds_labeled = data_stream_labeled
        self.ds_unlabeled = data_stream_unlabeled
        # Rename the sources for clarity
        self.ds_labeled.sources = ('features_labeled', 'targets_labeled')
        # Hide the labels.
        self.ds_unlabeled.sources = ('features_unlabeled',)

    @property
    def sources(self):
        # An explicitly assigned value (via the setter below) wins; otherwise
        # expose the concatenation of both sub-streams' renamed sources.
        if hasattr(self, '_sources'):
            return self._sources
        return self.ds_labeled.sources + self.ds_unlabeled.sources

    @sources.setter
    def sources(self, value):
        self._sources = value

    def close(self):
        # Forward lifecycle calls to both sub-streams.
        self.ds_labeled.close()
        self.ds_unlabeled.close()

    def reset(self):
        self.ds_labeled.reset()
        self.ds_unlabeled.reset()

    def next_epoch(self):
        self.ds_labeled.next_epoch()
        self.ds_unlabeled.next_epoch()

    def get_epoch_iterator(self, **kwargs):
        unlabeled = self.ds_unlabeled.get_epoch_iterator(**kwargs)
        labeled = self.ds_labeled.get_epoch_iterator(**kwargs)
        assert type(labeled) == type(unlabeled)
        # cycle() repeats labeled batches for as long as the unlabeled
        # iterator produces, so the unlabeled data ends the epoch.
        return imap(self.mergedicts, cycle(labeled), unlabeled)

    def mergedicts(self, x, y):
        # Combine one labeled and one unlabeled batch dict into a single dict.
        return dict(list(x.items()) + list(y.items()))
def get_mnist_data_dict(unlabeled_samples, valid_set_size, test_set=False):
    """Load MNIST and split it into train/validation (and optional test) parts.

    Parameters
    ----------
    unlabeled_samples : int, number of (shuffled) examples for training.
    valid_set_size : int, number of remaining examples for validation.
    test_set : bool, when True also attach the MNIST test split.

    Returns
    -------
    AttributeDict with ``train``/``train_ind``, ``valid``/``valid_ind``
    and, if requested, ``test``/``test_ind``.
    """
    train_set = MNIST(("train",))
    # Make sure the MNIST data is in right format: pixels scaled to
    # [0, 1] as float32, targets untouched.
    features, targets = train_set.data_sources
    train_set.data_sources = ((features / 255.).astype(numpy.float32),
                              targets)

    # Shuffle all example indices with a fixed seed for reproducibility.
    shuffled = numpy.arange(train_set.num_examples)
    numpy.random.RandomState(seed=1).shuffle(shuffled)

    data = AttributeDict()

    # Training set: the first unlabeled_samples shuffled indices.
    data.train = train_set
    data.train_ind = shuffled[:unlabeled_samples]

    # Validation set: drawn from the indices not used for training.
    remaining = numpy.setdiff1d(shuffled, data.train_ind)
    data.valid = train_set
    data.valid_ind = remaining[:valid_set_size]
    logger.info('Using %d examples for validation' % len(data.valid_ind))

    # Only touch test data if requested
    if test_set:
        data.test = MNIST(("test",))
        data.test_ind = numpy.arange(data.test.num_examples)

    return data