def _instances_to_batches(
    self, instance_iterator: Iterable[Instance], move_to_device
) -> Iterator[TensorDict]:
    """
    Convert a stream of un-indexed `Instance`s into a stream of tensorized batches.

    Instances are indexed lazily as they are consumed. When `move_to_device` is
    true and a CUDA device is configured, each collated batch is moved onto that
    device before being yielded.
    """
    indexed = (self._index_instance(inst) for inst in instance_iterator)

    if move_to_device and self.cuda_device is not None:

        def tensorize(batch):
            # Collate first, then ship the whole tensor dict to the GPU.
            return nn_util.move_to_device(self.collate_fn(batch), self.cuda_device)

    else:
        tensorize = self.collate_fn

    if self.batch_sampler is not None:
        # The sampler needs a materialized chunk of instances so it can hand
        # back index lists into that chunk.
        chunks: Iterable[List[Instance]]
        if self.max_instances_in_memory is None:
            chunks = [list(indexed)]
        else:
            chunks = lazy_groups_of(indexed, self.max_instances_in_memory)

        for chunk in chunks:
            for indices in self.batch_sampler.get_batch_indices(chunk):
                yield tensorize([chunk[i] for i in indices])
    else:
        # Safe to assume this is not `None` when `self.batch_sampler` is `None`.
        assert self.batch_size is not None

        stream: Iterable[Instance] = indexed
        if self.shuffle:
            if self.max_instances_in_memory is not None:
                stream = shuffle_iterable(stream, self.max_instances_in_memory)
            else:
                # At this point we've already loaded the instances in memory and
                # indexed them, so this won't take long.
                stream = list(stream)
                random.shuffle(stream)

        for group in lazy_groups_of(stream, self.batch_size):
            if self.drop_last and len(group) < self.batch_size:
                # Only the final group from `lazy_groups_of` can be short;
                # drop it and stop.
                break
            yield tensorize(group)
def maybe_shuffle_instances(loader: DataLoader, shuffle: bool) -> Iterable[Instance]:
    """Return the loader's instance stream, shuffled when `shuffle` is true."""
    instances = loader.iter_instances()
    return util.shuffle_iterable(instances) if shuffle else instances