def get_batches(self, batch_size, num_batches=None, shuffle=False, cluster=False):
    """Yield mini-batches of this data set as ``(batch_idxs, batch_ds)`` pairs.

    Indices from ``self.valid_idxs`` are grouped ``batch_size`` at a time. If
    ``num_batches`` exceeds the number of batches in one pass over the data,
    further passes (epochs) are chained on until ``num_batches`` batches have
    been yielded.

    :param batch_size: maximum number of examples per batch; must be >= 1.
    :param num_batches: total number of batches to yield. Defaults to one
        pass, i.e. ``ceil(self.num_examples / batch_size)``.
    :param shuffle: randomize the example order. A fresh random order is
        drawn for every epoch.
    :param cluster: only has an effect together with ``shuffle``. Cluster
        examples by their lengths (sorting by ``self._sort_key``) so each
        batch holds similar examples; this might give performance boost
        (i.e. faster training). The batches are formed once per call and only
        their order is reshuffled every epoch.
    :return: generator of ``(batch_idxs, batch_ds)``. ``batch_idxs`` is a
        tuple of indices; ``batch_ds`` is a ``DataSet`` built from
        ``self.get_by_idxs(batch_idxs)``, where every ``'*key'`` entry is also
        resolved against ``self.shared`` and stored under ``'key'``.
        Nothing is yielded for an empty data set.
    :raises ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {!r}'.format(batch_size))
    num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
    if num_batches is None:
        num_batches = num_batches_per_epoch
    if num_batches_per_epoch == 0 or num_batches <= 0:
        # Empty data set (or nothing requested): stop before dividing by zero.
        return
    num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))

    if shuffle and cluster:
        # sorted() is stable, so sorting a random permutation breaks ties
        # between equal sort keys randomly.
        permuted = random.sample(self.valid_idxs, len(self.valid_idxs))
        clustered_groups = list(grouper(sorted(permuted, key=self._sort_key), batch_size))

        def epoch_groups():
            # Same batches every epoch, visited in a new random order.
            return random.sample(clustered_groups, len(clustered_groups))
    elif shuffle:
        def epoch_groups():
            # Draw a new permutation on every call so consecutive epochs do
            # not repeat the same example order.
            permuted = random.sample(self.valid_idxs, len(self.valid_idxs))
            return grouper(permuted, batch_size)
    else:
        def epoch_groups():
            return grouper(self.valid_idxs, batch_size)

    # Epochs are produced lazily; islice stops after exactly num_batches groups.
    all_groups = itertools.chain.from_iterable(epoch_groups() for _ in range(num_epochs))
    for group in itertools.islice(all_groups, num_batches):
        # Groups may contain None entries (padding of short groups); drop them.
        batch_idxs = tuple(i for i in group if i is not None)
        batch_data = self.get_by_idxs(batch_idxs)
        shared_batch_data = {}
        for key, val in batch_data.items():
            if key.startswith('*'):
                # '*'-prefixed entries hold references into self.shared.
                assert self.shared is not None
                shared_key = key[1:]
                shared_batch_data[shared_key] = [
                    index(self.shared[shared_key], each) for each in val
                ]
        batch_data.update(shared_batch_data)
        batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
        yield batch_idxs, batch_ds
def divide(self, integer):
    """Split this data set into ``integer`` smaller ``DataSet`` objects.

    ``self.valid_idxs`` is cut by ``grouper`` into ``integer`` groups of
    ``ceil(self.num_examples / integer)`` indices. Because of ``shorten=True``,
    trailing groups may hold fewer indices (presumably; confirm against
    ``grouper``). Each group's data is fetched with ``self.get_by_idxs`` and
    wrapped in a ``DataSet`` that reuses ``self.data_type`` and ``self.shared``.

    :param integer: number of parts to produce.
    :return: tuple of ``DataSet`` objects, one per index group, in order.
    """
    per_part = int(math.ceil(self.num_examples / integer))
    parts = []
    for idx_group in grouper(self.valid_idxs, per_part, shorten=True, num_groups=integer):
        part_data = self.get_by_idxs(idx_group)
        parts.append(DataSet(part_data, self.data_type, shared=self.shared))
    return tuple(parts)
def get_multi_batches(self, batch_size, num_batches_per_step, num_steps=None, shuffle=False, cluster=False):
    """Yield, per step, ``num_batches_per_step`` sub-batches as one tuple.

    Each step takes one batch of ``batch_size * num_batches_per_step``
    examples from ``get_batches`` and splits it with ``DataSet.divide``.

    :param batch_size: size of each sub-batch (for full step batches).
    :param num_batches_per_step: number of sub-batches per step.
    :param num_steps: number of steps; forwarded to ``get_batches`` as
        ``num_batches``.
    :param shuffle: forwarded to ``get_batches``.
    :param cluster: forwarded to ``get_batches``.
    :return: generator; each item is a tuple of ``(sub_idxs, sub_ds)`` pairs,
        where ``sub_idxs`` are the indices of the examples in ``sub_ds``.
    """
    batch_size_per_step = batch_size * num_batches_per_step
    batches = self.get_batches(batch_size_per_step, num_batches=num_steps,
                               shuffle=shuffle, cluster=cluster)

    def split(idxs, data_set):
        # Chunk the indices exactly as ``divide`` chunks the data: the same
        # grouper call with the same chunk size ceil(num_examples / k).
        # Chunking by ``batch_size`` instead gives different boundaries when
        # the step batch is short (e.g. the last batch of an epoch), which
        # would pair index groups with the wrong examples. Assumes
        # data_set.valid_idxs lines up positionally with idxs — TODO confirm.
        chunk = int(math.ceil(data_set.num_examples / num_batches_per_step))
        idx_groups = grouper(idxs, chunk, shorten=True, num_groups=num_batches_per_step)
        return tuple(zip(idx_groups, data_set.divide(num_batches_per_step)))

    return (split(idxs, data_set) for idxs, data_set in batches)