Example #1
def test_fast_forward_sampler_over_iterative_dataset(num_workers):
    """
    This test ensures that ``FastForwardSampler`` and ``CaptureIterableDataset`` are used
    properly to capture the worker states.
    """
    batch_size = 3
    initial_seed = seed_everything(42)
    generator = torch.Generator()
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20), num_workers, batch_size, True)
    dataset = CaptureIterableDataset(dataset)

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches = []
    for _ in range(5):
        batches.append(next(iter_dataloader))

    # restarting on batch_1 and getting 3 extra batches

    state_dict = {"iter_sampler": {}}
    for batch in batches[:2]:
        batch, _state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(
            batch)
        for k, v in _state_dict[0].items():
            state_dict[k].update(v)

    assert len(
        state_dict["iter_sampler"]) == (num_workers if num_workers > 1 else 1)

    initial_seed = seed_everything(42)
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20),
                                   num_workers,
                                   batch_size,
                                   state_dict=state_dict)
    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches_restart = []
    for _ in range(3):
        batches_restart.append(next(iter_dataloader))

    assert torch.equal(batches_restart[0]["data"], batches[2]["data"])
    assert torch.equal(batches_restart[1]["data"], batches[3]["data"])
    assert torch.equal(batches_restart[2]["data"], batches[4]["data"])
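``RangeIterableDataset`` used above is a test helper that is not reproduced on this page. As a rough, hypothetical sketch (the signature and sharding logic are assumptions chosen to match the call sites above, not the actual helper), such a dataset shards the range across workers and exposes its sampler iterator under ``attr_name`` so the capture machinery can find it:

from torch.utils.data import IterableDataset, RandomSampler, get_worker_info


class RangeIterableDataset(IterableDataset):
    # hypothetical re-implementation for illustration only
    def __init__(self, data, num_workers, batch_size, state_dict=None, attr_name="iter_sampler"):
        self.data = list(data)
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.state_dict = state_dict
        self.attr_name = attr_name

    def __iter__(self):
        worker_info = get_worker_info()
        self.shard = list(self.data)
        if worker_info is not None and self.num_workers > 1:
            # give every worker its own contiguous slice of the data
            chunk = len(self.shard) // self.num_workers
            self.shard = self.shard[worker_info.id * chunk:(worker_info.id + 1) * chunk]
        # expose the sampler iterator under `attr_name` so `CaptureIterableDataset`
        # can discover it and wrap it with a `FastForwardSampler`
        setattr(self, self.attr_name, iter(RandomSampler(self.shard)))
        return self

    def __next__(self):
        # StopIteration from the exhausted sampler iterator ends the epoch
        return self.shard[next(getattr(self, self.attr_name))]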
Example #2
def next_fn(iterator: Iterator):
    batch = next(iterator)
    if not _fault_tolerant_enabled():
        return batch
    # when fault tolerance is enabled, the iterator returns the `FastForwardSampler`
    # state_dict metadata alongside the user data. The metadata is extracted and stored
    # directly on the iterator to simplify its collection on the `state_dict` call.
    batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
    # store the `samplers_state_dict` on the iterator
    CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
    return batch
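A rough sketch of how a ``next_fn``-style wrapper could be driven outside of Lightning's internals, assuming the dataloader was prepared with ``_add_capture_metadata_collate`` as in Example #1; the ``iterate_with_capture`` name and the generator approach are illustrative assumptions:

# import path may vary across Lightning versions
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset


def iterate_with_capture(dataloader):
    # iterate the dataloader while stripping the capture metadata off every batch,
    # mirroring what `next_fn` above does for Lightning's own iterators
    iterator = iter(dataloader)
    while True:
        try:
            batch = next(iterator)
        except StopIteration:
            return
        batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
        # keep the sampler states on the iterator for a later `state_dict` collection
        CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
        yield batch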
Example #3
def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
    if isinstance(dataset, IterableDataset):
        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
        return CaptureIterableDataset(dataset=dataset)
    if get_len(dataset) != float("inf"):
        return CaptureMapDataset(dataset=dataset)
    raise RuntimeError(
        "This shouldn't happen, please open an issue on Lightning Github repository."
    )
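A minimal usage sketch for the wrapper above, assuming a plain finite map-style dataset; ``TensorDataset`` is used only for illustration and the ``CaptureMapDataset`` import path may differ between Lightning versions:

import torch
from torch.utils.data import DataLoader, TensorDataset

# import path as of the Lightning releases these snippets come from
from pytorch_lightning.utilities.auto_restart import CaptureMapDataset

dataset = TensorDataset(torch.arange(16))
# a finite map-style dataset is wrapped into `CaptureMapDataset`;
# an `IterableDataset` would be wrapped into `CaptureIterableDataset` instead
wrapped = _wrap_with_capture_dataset(dataset)
assert isinstance(wrapped, CaptureMapDataset)

loader = DataLoader(wrapped, batch_size=4)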
Example #4
def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
    dataset = dl_kwargs["dataset"]
    if isinstance(dataset, IterableDataset):
        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
    elif get_len(dataset) != float("inf"):
        dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
    else:
        raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
    return dl_kwargs
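The same idea applied to the kwargs-based variant; the dictionary below is a hypothetical stand-in for the kwargs reconstructed from an existing ``DataLoader``:

import torch
from torch.utils.data import DataLoader, TensorDataset

dl_kwargs = {"dataset": TensorDataset(torch.arange(16)), "batch_size": 4, "num_workers": 0}
dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)
# only the "dataset" entry is replaced (here with a `CaptureMapDataset`); everything else is untouched
dataloader = DataLoader(**dl_kwargs)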
Example #5
def create_iterable_dataset(batch_size,
                            num_workers,
                            attr_name="iter_sampler",
                            wrap: bool = True):
    dataset = RangeIterableDataset(range(50),
                                   num_workers=num_workers,
                                   batch_size=batch_size,
                                   attr_name=attr_name)
    if wrap:
        dataset = CaptureIterableDataset(dataset)
    return dataset
def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler"):
    dataset = RangeIterableDataset(range(50), num_workers=num_workers, batch_size=batch_size, attr_name=attr_name)
    return CaptureIterableDataset(dataset)
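A sketch of how such a factory could be combined with the capture collate from Example #1; the batch size and worker count are arbitrary:

from torch.utils.data import DataLoader

# import path may vary across Lightning versions
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate

dataset = create_iterable_dataset(batch_size=3, num_workers=2)
dataloader = DataLoader(dataset, batch_size=3, num_workers=2)
_add_capture_metadata_collate(dataloader)

# with the capture collate in place, each batch also carries the sampler state metadata
first_batch = next(iter(dataloader))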
def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(rank, worldsize):
    if worldsize > 1:
        _setup_ddp(rank, worldsize)

    def all_gather(tensor, world_size):
        tensor_list = [torch.zeros_like(tensor, dtype=torch.int64) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, tensor)
        return tensor_list

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4
    dataset_length = 60
    num_classes = 10

    labels = np.random.randint(0, num_classes, dataset_length)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )
    dataset = CaptureIterableDataset(dataset, initial_seed=initial_seed)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    Trainer._add_sampler_metadata_collate(dataloader)

    epoch_results = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results.append(batches)
        dataloader.dataset.dataset.current_task_iteration += 1

    assert len(epoch_results) == 2

    assert len(epoch_results[0]) == math.ceil((dataset_length / (num_workers * worldsize)) / batch_size) + 2

    if worldsize == 1:
        assert epoch_results[0][0]["data"]["task_length"] == epoch_results[0][1]["data"]["task_length"]
        assert torch.equal(
            epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"]
        )
        assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"]  # worker id 0
        assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"]  # worker id 1
        assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0])
    else:
        first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize)
        second_task_metadata = all_gather(epoch_results[0][1]["data"]["task_length"], worldsize)
        assert torch.equal(first_task_metadata[0], first_task_metadata[1])
        assert torch.equal(second_task_metadata[0], second_task_metadata[1])
        assert torch.equal(first_task_metadata[0], second_task_metadata[1])

        first_batch_list = all_gather(epoch_results[0][2]["data"][0], worldsize)
        assert not torch.equal(first_batch_list[0], first_batch_list[1])
        second_batch_list = all_gather(epoch_results[0][3]["data"][0], worldsize)
        assert not torch.equal(second_batch_list[0], second_batch_list[1])

    # restarting on epoch 0 / real batch 2
    state_dict = {"iter_sampler": {}}
    for batch in epoch_results[0][2:4]:
        batch, _state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
        for k, v in _state_dict[0].items():
            state_dict[k].update(v)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )

    dataset = CaptureIterableDataset(dataset, initial_seed=initial_seed)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    Trainer._add_sampler_metadata_collate(dataloader)

    epoch_results_restart = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results_restart.append(batches)
        dataloader.dataset.dataset.increment_iteration()
        dataloader.dataset.reset_on_epoch()

    assert len(epoch_results_restart[0]) + 2 == len(epoch_results[0])
    epoch_tensors = [e["data"][0] for e in epoch_results[0][4:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[0][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)

    epoch_tensors = [e["data"][0] for e in epoch_results[1][2:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[1][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)
    def _get_dataloader_init_kwargs(
            dataloader: DataLoader,
            sampler: Optional[Sampler],
            mode: Optional[RunningStage] = None) -> Dict[str, Any]:
        if not isinstance(dataloader, DataLoader):
            raise ValueError(
                f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`"
            )

        # get the dataloader instance attributes
        attrs = {
            k: v
            for k, v in vars(dataloader).items() if not k.startswith("_")
        }
        # not part of `vars`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context

        # get the dataloader instance `__init__` parameters
        params = dict(inspect.signature(dataloader.__init__).parameters)

        # keep only the params whose default is different to the current attr value
        non_defaults = {
            name
            for name, p in params.items()
            if name in attrs and p.default != attrs[name]
        }
        # add `dataset` as it might have been replaced with `*args`
        non_defaults.add("dataset")

        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_kwargs.update(
            TrainerDataLoadingMixin._resolve_batch_sampler(dataloader,
                                                           sampler,
                                                           mode=mode))

        required_args = {
            p.name
            for p in params.values()
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            and p.default is p.empty and p.name not in dl_kwargs
        }
        # the dataloader has required args which we could not extract from the existing attributes
        if required_args:
            required_args = sorted(required_args)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as some of the `__init__` arguments are not available as instance attributes. "
                f"The missing attributes are {required_args}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD
                                  for p in params.values())
        if not has_variadic_kwargs:
            # the dataloader signature does not allow keyword arguments that need to be passed
            missing_kwargs = dl_kwargs.keys() - params.keys()
            if missing_kwargs:
                missing_kwargs = sorted(missing_kwargs)
                dataloader_cls_name = dataloader.__class__.__name__
                raise MisconfigurationException(
                    f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                    "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                    f"The missing arguments are {missing_kwargs}. "
                    f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                    "manually add the `DistributedSampler` as: "
                    f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
                )

        if isinstance(dl_kwargs["dataset"], IterableDataset):
            dl_kwargs["batch_sampler"] = None
            dl_kwargs["sampler"] = None

        if _fault_tolerant_training():
            if isinstance(dl_kwargs["dataset"], IterableDataset):
                # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
                dl_kwargs["dataset"] = CaptureIterableDataset(
                    dataset=dl_kwargs["dataset"])
            elif len(dl_kwargs["dataset"]):
                dl_kwargs["dataset"] = CaptureMapDataset(
                    dataset=dl_kwargs["dataset"])
            else:
                raise MisconfigurationException(
                    "This shouldn't happen, please open an issue on Lightning Github repository."
                )

        return dl_kwargs
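To close the loop, a hedged sketch of how the returned kwargs are typically consumed: re-instantiating the same dataloader class with a ``DistributedSampler`` injected. It assumes ``_get_dataloader_init_kwargs`` is reachable as a static method on ``TrainerDataLoadingMixin``, as the indentation above suggests; the dataset and sampler arguments are purely illustrative.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))
dataloader = DataLoader(dataset, batch_size=10)

# num_replicas/rank are passed explicitly so no process group needs to be initialized
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)

dl_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
# rebuild the same dataloader class with the injected sampler (and, under fault-tolerant
# training, with the dataset wrapped into a capture dataset)
new_dataloader = type(dataloader)(**dl_kwargs)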