def __iter__(self):
    """
    Build the index batches for one epoch and return an iterator over them.

    ``stat_utils.permutation`` is called with ``self._num_samples`` and a
    value from ``stat_utils.random_int32()`` (presumably a seed for a random
    permutation of the sample indices -- confirm against ``stat_utils``).
    The result is cut into ``self._num_batches`` consecutive slices of at
    most ``self._batch_size`` entries each.

    Returns:
        iterator, yields one slice of the permuted indices per batch.
    """
    seed = stat_utils.random_int32()
    shuffled = stat_utils.permutation((self._num_samples, seed))
    size = self._batch_size
    batches = []
    for batch_no in range(self._num_batches):
        begin = batch_no * size
        batches.append(shuffled[begin:begin + size])
    return iter(batches)
def __iter__(self):
    """
    Return an iterator over the index batches assigned to this rank.

    The sequence obtained from ``stat_utils.permutation`` is padded by
    repeating its leading entries up to ``self._total_num_samples``, strided
    so that this rank keeps every ``self._rank_size``-th entry starting at
    ``self._rank_id``, and then cut into ``self._batchs_per_rank`` batches of
    at most ``self._batch_size`` entries each.

    Returns:
        iterator, yields one 1-D ``numpy.ndarray`` of indices per batch. The
        final batch may be shorter than the others.
    """
    indices = stat_utils.permutation(
        (self._num_samples, stat_utils.random_int32())).tolist()

    # Pad by repetition until the target length is reached. A single
    # ``extend`` falls short when the shortfall is larger than the number of
    # available indices (e.g. fewer samples than ranks), which would leave
    # some ranks with too few or no indices. The ``indices`` guard avoids
    # looping forever when there is nothing to repeat.
    while indices and len(indices) < self._total_num_samples:
        indices.extend(indices[:self._total_num_samples - len(indices)])

    # Strided slice: this rank's share of the padded sequence.
    indices = indices[self._rank_id:self._total_num_samples:self._rank_size]

    batch_indices = [
        indices[x * self._batch_size:(x + 1) * self._batch_size]
        for x in range(self._batchs_per_rank)
    ]
    # Convert batch by batch. Building one array from the whole list fails
    # (ValueError on NumPy >= 1.24, object array of lists before that) when
    # the final batch is shorter than the others.
    return iter([np.array(batch) for batch in batch_indices])
def __iter__(self):
    """
    Return an iterator over the index batches of one epoch, padding the last
    batch so that it has as many entries as the batch before it.

    The sequence obtained from ``stat_utils.permutation`` is cut into
    ``self._num_batches`` consecutive slices of at most ``self._batch_size``
    entries. If the final slice is shorter than the one before it, the
    missing entries are filled with indices drawn by ``np.random.randint``
    from ``[0, self._num_samples)``; these may repeat indices already present
    in the epoch.

    Returns:
        iterator, yields one batch of indices per step.
    """
    indices = stat_utils.permutation(
        (self._num_samples, stat_utils.random_int32()))
    batch_indices = [
        indices[x * self._batch_size:(x + 1) * self._batch_size]
        for x in range(self._num_batches)
    ]
    # Pad the last batch if necessary. Two batches are enough to have a
    # reference length, so the check is ``>= 2``; a stricter ``> 2`` would
    # leave a short last batch unpadded when there are exactly two batches.
    if len(batch_indices) >= 2:
        pad_nums = len(batch_indices[-2]) - len(batch_indices[-1])
        # Only pad on a real shortfall; a negative size would make
        # ``np.random.randint`` raise.
        if pad_nums > 0:
            pad_indices = np.random.randint(0, self._num_samples, pad_nums)
            batch_indices[-1] = np.hstack((batch_indices[-1], pad_indices))
    return iter(batch_indices)