def get_batches(self, batch_size, num_batches=None, shuffle=False, cluster=False):
    """
    :param batch_size: number of examples per batch
    :param num_batches: total number of batches to yield; defaults to one epoch's worth
    :param shuffle: sample the example indices in random order
    :param cluster: cluster examples by their lengths; this might give performance boost (i.e. faster training).
    :return: generator of (batch_idxs, DataSet) pairs
    """
    num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
    if num_batches is None:
        num_batches = num_batches_per_epoch
    num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))
    if shuffle:
        random_idxs = random.sample(self.valid_idxs, len(self.valid_idxs))
        if cluster:
            # sort by length so each batch holds similarly sized examples,
            # then shuffle the batches themselves so their order stays random
            sorted_idxs = sorted(random_idxs, key=self._sort_key)
            sorted_grouped = lambda: list(grouper(sorted_idxs, batch_size))
            grouped = lambda: random.sample(sorted_grouped(), num_batches_per_epoch)
        else:
            random_grouped = lambda: list(grouper(random_idxs, batch_size))
            grouped = random_grouped
    else:
        raw_grouped = lambda: list(grouper(self.valid_idxs, batch_size))
        grouped = raw_grouped

    # regroup once per epoch and chain the epochs into one stream of batches
    batch_idx_tuples = itertools.chain.from_iterable(grouped() for _ in range(num_epochs))
    for _ in range(num_batches):
        # drop the None padding that grouper adds to the last, short group
        batch_idxs = tuple(i for i in next(batch_idx_tuples) if i is not None)
        batch_data = self.get_by_idxs(batch_idxs)
        shared_batch_data = {}
        for key, val in batch_data.items():
            if key.startswith('*'):
                # keys prefixed with '*' hold references into the shared store;
                # resolve them to the actual shared values for this batch
                assert self.shared is not None
                shared_key = key[1:]
                shared_batch_data[shared_key] = [index(self.shared[shared_key], each) for each in val]
        batch_data.update(shared_batch_data)
        batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
        yield batch_idxs, batch_ds
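# A minimal usage sketch for get_batches, assuming `train_data` is a DataSet
# instance built by the surrounding pipeline (the name, batch size, and batch
# count here are illustrative, not taken from the source):
def demo_get_batches(train_data):
    # draw 100 shuffled, length-clustered batches of up to 60 examples each
    for batch_idxs, batch_ds in train_data.get_batches(60, num_batches=100, shuffle=True, cluster=True):
        print(len(batch_idxs), batch_ds.num_examples)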
def get_multi_batches(self, batch_size, num_batches_per_step, num_steps=None, shuffle=False, cluster=False):
    # fetch one large batch per step, then split it into num_batches_per_step
    # sub-batches of up to batch_size examples each
    batch_size_per_step = batch_size * num_batches_per_step
    batches = self.get_batches(batch_size_per_step, num_batches=num_steps, shuffle=shuffle, cluster=cluster)
    multi_batches = (tuple(zip(grouper(idxs, batch_size, shorten=True, num_groups=num_batches_per_step),
                               data_set.divide(num_batches_per_step)))
                     for idxs, data_set in batches)
    return multi_batches
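# A usage sketch for get_multi_batches, e.g. to feed several model replicas per
# training step (hypothetical driver; `train_data` and the sizes are illustrative):
def demo_get_multi_batches(train_data):
    # each of the 50 steps yields 2 (idxs, DataSet) pairs of up to 30 examples each
    for multi_batch in train_data.get_multi_batches(30, 2, num_steps=50, shuffle=True):
        for batch_idxs, batch_ds in multi_batch:
            print(len(batch_idxs), batch_ds.num_examples)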
def divide(self, integer):
    # split the dataset into `integer` sub-datasets of roughly equal size
    batch_size = int(math.ceil(self.num_examples / integer))
    idxs_gen = grouper(self.valid_idxs, batch_size, shorten=True, num_groups=integer)
    data_gen = (self.get_by_idxs(idxs) for idxs in idxs_gen)
    ds_tuple = tuple(DataSet(data, self.data_type, shared=self.shared) for data in data_gen)
    return ds_tuple
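# A usage sketch for divide (hypothetical; `dev_data` stands for an assumed DataSet):
def demo_divide(dev_data):
    # split the dataset into 4 roughly equal shards, e.g. one per evaluation worker
    for shard in dev_data.divide(4):
        print(shard.num_examples)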
from my.tensorflow import grouper
import random
import itertools

# shuffle the indices 0..4
idxs = random.sample(range(5), 5)
print(idxs)

# group them into pairs; grouper pads the last, short group with None
grouped = list(grouper(idxs, 2))
print(grouped)

# draw the 3 groups themselves in random order
random_grouped = lambda: random.sample(grouped, 3)
print(random_grouped())

'''
batch_idx_tuples = itertools.chain.from_iterable(random_grouped() for _ in range(5))
for _ in range(5):
    batch_idxs = tuple(i for i in next(batch_idx_tuples) if i is not None)
    print("BATCH", batch_idxs)
'''

# grouper also works over nested lists: the inner lists are grouped, not flattened
a = [[1, 2], [3, 4], [5, 6]]
grouped1 = list(grouper(a, 2))
print("NEW", grouped1)
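# For reference, a grouper compatible with the calls above could look like the
# sketch below. This is an assumption reconstructed from how grouper is used in
# this file (None padding, shorten=True, num_groups=...), not necessarily the
# actual implementation in my.tensorflow. It reuses the itertools import above.
def grouper_sketch(iterable, n, fillvalue=None, shorten=False, num_groups=None):
    # collect the iterable into groups of n, padding the last group with fillvalue
    args = [iter(iterable)] * n
    out = list(itertools.zip_longest(*args, fillvalue=fillvalue))
    if num_groups is not None:
        # pad with fill-only groups so at least num_groups groups come back
        default = (fillvalue,) * n
        out = [group for group, _ in itertools.zip_longest(out, range(num_groups), fillvalue=default)]
    if shorten:
        # drop the padding instead of keeping fixed-size groups
        out = [tuple(e for e in group if e is not None) for group in out]
    return out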