Example #1
    def get_batches(self, batch_size, num_batches=None, shuffle=False, cluster=False):
        """

        :param batch_size:
        :param num_batches:
        :param shuffle:
        :param cluster: cluster examples by their lengths; this might give performance boost (i.e. faster training).
        :return:
        """
        num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
        if num_batches is None:
            num_batches = num_batches_per_epoch
        num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))

        if shuffle:
            # Shuffle by sampling every valid index exactly once.
            random_idxs = random.sample(self.valid_idxs, len(self.valid_idxs))
            if cluster:
                # Sort the shuffled indices by length so each batch holds
                # similarly-sized examples, then shuffle the batch order.
                sorted_idxs = sorted(random_idxs, key=self._sort_key)
                sorted_grouped = lambda: list(grouper(sorted_idxs, batch_size))
                grouped = lambda: random.sample(sorted_grouped(), num_batches_per_epoch)
            else:
                random_grouped = lambda: list(grouper(random_idxs, batch_size))
                grouped = random_grouped
        else:
            # Deterministic order: group the valid indices as they are.
            raw_grouped = lambda: list(grouper(self.valid_idxs, batch_size))
            grouped = raw_grouped

        # Chain together enough epochs of batches to cover num_batches.
        batch_idx_tuples = itertools.chain.from_iterable(grouped() for _ in range(num_epochs))
        for _ in range(num_batches):
            # grouper pads the final group with None; drop the padding.
            batch_idxs = tuple(i for i in next(batch_idx_tuples) if i is not None)
            batch_data = self.get_by_idxs(batch_idxs)
            shared_batch_data = {}
            for key, val in batch_data.items():
                if key.startswith('*'):
                    # Keys prefixed with '*' are references into self.shared;
                    # resolve them to the actual shared objects for this batch.
                    assert self.shared is not None
                    shared_key = key[1:]
                    shared_batch_data[shared_key] = [index(self.shared[shared_key], each) for each in val]
            batch_data.update(shared_batch_data)

            batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
            yield batch_idxs, batch_ds
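
A minimal consumption sketch (not from the original repo): train_data is assumed to be a DataSet instance, and the argument values are illustrative only.

# Sketch only; train_data and the argument values are assumptions.
for batch_idxs, batch_ds in train_data.get_batches(
        batch_size=60, num_batches=1000, shuffle=True, cluster=True):
    # batch_idxs holds the original example indices; batch_ds is a DataSet
    # restricted to this batch, with '*'-prefixed keys resolved via shared.
    print(len(batch_idxs), batch_ds.num_examples)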
Example #2
    def get_multi_batches(self, batch_size, num_batches_per_step, num_steps=None, shuffle=False, cluster=False):
        # Fetch one large batch per step, then split it into
        # num_batches_per_step sub-batches, each paired with its index group.
        batch_size_per_step = batch_size * num_batches_per_step
        batches = self.get_batches(batch_size_per_step, num_batches=num_steps, shuffle=shuffle, cluster=cluster)
        multi_batches = (tuple(zip(grouper(idxs, batch_size, shorten=True, num_groups=num_batches_per_step),
                                   data_set.divide(num_batches_per_step)))
                         for idxs, data_set in batches)
        return multi_batches
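
Each step then yields a tuple of (idxs, DataSet) pairs, one per sub-batch. A sketch, under the same assumed train_data as above:

# Sketch only; train_data is an assumed DataSet instance.
for multi_batch in train_data.get_multi_batches(batch_size=60, num_batches_per_step=2, num_steps=10):
    for sub_idxs, sub_ds in multi_batch:
        print(len(sub_idxs), sub_ds.num_examples)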
Example #3
    def divide(self, integer):
        # Split this DataSet into the given number of roughly equal parts.
        batch_size = int(math.ceil(self.num_examples / integer))
        idxs_gen = grouper(self.valid_idxs, batch_size, shorten=True, num_groups=integer)
        data_gen = (self.get_by_idxs(idxs) for idxs in idxs_gen)
        ds_tuple = tuple(DataSet(data, self.data_type, shared=self.shared) for data in data_gen)
        return ds_tuple
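
For example, a DataSet with 10 examples divided into 3 parts yields sizes 4, 4, and 2, since grouper is called with shorten=True (padding is dropped rather than kept). A sketch with a hypothetical instance ds:

# Sketch only; ds is an assumed DataSet with 10 examples.
parts = ds.divide(3)
print([p.num_examples for p in parts])  # [4, 4, 2]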
Example #4
from my.tensorflow import grouper
import random
import itertools

# Shuffle the indices 0..4, then group them into pairs; grouper pads the
# last (odd) group with None.
idxs = random.sample(range(5), 5)
print(idxs)
grouped = list(grouper(idxs, 2))
print(grouped)
# random_grouped samples 3 of the groups in random order; it must be
# called, otherwise print would only show the lambda's repr.
random_grouped = lambda: random.sample(grouped, 3)
print(random_grouped())
'''
batch_idx_tuples = itertools.chain.from_iterable(random_grouped() for _ in range(5))
for _ in range(5):
    batch_idxs = tuple(i for i in next(batch_idx_tuples) if i is not None)
    print("BATCH", batch_idxs)
'''
a = [[1, 2], [3, 4], [5, 6]]
grouped1 = list(grouper(a, 2))
print("NEW", grouped1)