Example #1
0
    def test_cycle_iterator_function(self):
        """`cycle_iterator_function` re-invokes its factory each time the
        produced iterator is exhausted, unlike `itertools.cycle`, which only
        ever calls (and caches) the underlying iterable once."""
        global cycle_iterator_function_calls
        cycle_iterator_function_calls = 0

        def one_and_two():
            global cycle_iterator_function_calls
            cycle_iterator_function_calls += 1
            yield from [1, 2]

        it = iter(util.cycle_iterator_function(one_and_two))

        # Nothing runs until the iterator is first advanced.
        assert cycle_iterator_function_calls == 0

        observed = [next(it) for _ in range(5)]
        assert observed == [1, 2, 1, 2, 1]
        # itertools.cycle would have invoked the factory exactly once; here the
        # factory runs again each time its two-element output is used up:
        # (1, 2), (1, 2), (1, ...) -> three invocations.
        assert cycle_iterator_function_calls == 3
Example #2
0
    def __init__(
        self,
        reader: MultiTaskDatasetReader,
        data_path: Dict[str, str],
        scheduler: MultiTaskScheduler,
        *,
        sampler: Optional[MultiTaskEpochSampler] = None,
        instances_per_epoch: Optional[int] = None,
        num_workers: Optional[Dict[str, int]] = None,
        max_instances_in_memory: Optional[Dict[str, int]] = None,
        start_method: Optional[Dict[str, str]] = None,
        instance_queue_size: Optional[Dict[str, int]] = None,
        instance_chunk_size: Optional[Dict[str, int]] = None,
        shuffle: bool = True,
        cuda_device: Optional[Union[int, str, torch.device]] = None,
    ) -> None:
        """Build one data loader per dataset named in `reader.readers`.

        `data_path` must have exactly the same keys as `reader.readers`; a
        `ValueError` is raised otherwise.  The per-dataset dict arguments
        (`num_workers`, `max_instances_in_memory`, etc.) are keyed by the same
        dataset names and default to empty dicts.  A `ValueError` is also
        raised when `instances_per_epoch` is given without a `sampler`, since
        sampling is what decides which instances fill each epoch.
        """
        self.readers = reader.readers
        self.data_paths = data_path
        self.scheduler = scheduler
        self.sampler = sampler
        # Normalize cuda_device to a torch.device (or None for CPU-default).
        self.cuda_device: Optional[torch.device] = None
        if cuda_device is not None:
            if not isinstance(cuda_device, torch.device):
                self.cuda_device = torch.device(cuda_device)
            else:
                self.cuda_device = cuda_device

        self._instances_per_epoch = instances_per_epoch
        self._shuffle = shuffle

        if instances_per_epoch is not None and sampler is None:
            raise ValueError(
                "You must provide an EpochSampler if you want to not use all instances every epoch."
            )

        # Per-dataset loader configuration; missing entries fall back to each
        # loader's own defaults.
        self._num_workers = num_workers or {}
        self._max_instances_in_memory = max_instances_in_memory or {}
        self._start_method = start_method or {}
        self._instance_queue_size = instance_queue_size or {}
        self._instance_chunk_size = instance_chunk_size or {}

        if self.readers.keys() != self.data_paths.keys():
            raise ValueError(
                f"Mismatch between readers ({self.readers.keys()}) and data paths "
                f"({self.data_paths.keys()})"
            )
        self._loaders = {key: self._make_data_loader(key) for key in self.readers}

        # This stores our current iterator with each dataset, so we don't just iterate over the
        # first k instances every epoch if we're using instances_per_epoch.  We'll grab instances
        # from here each epoch, and refresh it when it runs out.  We only use this in the case that
        # instances_per_epoch is not None, but these iterators are lazy, so always creating them
        # doesn't hurt anything.
        self._iterators: Dict[str, Iterator[Instance]] = {
            # NOTE: The order in which we're calling these iterator functions is important.  We want
            # an infinite iterator over the data, but we want the order in which we iterate over the
            # data to be different at every epoch.  The cycle function will give us an infinite
            # iterator, and it will call the lambda function each time it runs out of instances,
            # which will produce a new shuffling of the dataset.
            key: util.cycle_iterator_function(
                # This default argument to the lambda function is necessary to create a new scope
                # for the loader variable, so a _different_ loader gets saved for every iterator.
                # Dictionary comprehensions don't create new scopes in python.  If you don't have
                # this loader, you end up with `loader` always referring to the last loader in the
                # iteration... mypy also doesn't know what to do with this, for some reason I can't
                # figure out.
                lambda l=loader: maybe_shuffle_instances(l, self._shuffle)  # type: ignore
            )
            for key, loader in self._loaders.items()
        }
 def _read(self, file_path: str):
     """Return an iterator of exactly 100 instances, produced by endlessly
     cycling a factory that yields a single 'B'-labeled instance."""
     def make_instances():
         return [Instance({"label": LabelField("B")})]

     endless = cycle_iterator_function(make_instances)
     return itertools.islice(endless, 100)