Example #1
def test_fast_forward_sampler_over_iterative_dataset(num_workers):
    """
    This test ensures that ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly
    used to capture the workers' states.
    """
    batch_size = 3
    initial_seed = seed_everything(42)
    generator = torch.Generator()
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20), num_workers, batch_size, True)
    dataset = CaptureIterableDataset(dataset)

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            generator=generator)
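    # the wrapped collate_fn makes every batch also carry the ``FastForwardSampler``
    # state of the worker that produced it, so it can be extracted from the batches below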
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches = []
    for _ in range(5):
        batches.append(next(iter_dataloader))

    # restart after batch_1 and fetch the 3 remaining batches

    state_dict = {"iter_sampler": {}}
    for batch in batches[:2]:
        batch, _state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
        for k, v in _state_dict[0].items():
            state_dict[k].update(v)

    assert len(state_dict["iter_sampler"]) == (num_workers if num_workers > 1 else 1)

    initial_seed = seed_everything(42)
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20),
                                   num_workers,
                                   batch_size,
                                   state_dict=state_dict)
    dataset = CaptureIterableDataset(dataset)
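    # reload the captured worker states so the wrapped samplers fast-forward
    # to the position where the first run stopped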
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches_restart = []
    for _ in range(3):
        batches_restart.append(next(iter_dataloader))

    assert torch.equal(batches_restart[0]["data"], batches[2]["data"])
    assert torch.equal(batches_restart[1]["data"], batches[3]["data"])
    assert torch.equal(batches_restart[2]["data"], batches[4]["data"])
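
The test above receives ``num_workers`` as an argument; in a pytest suite it would presumably be driven by a parametrize decorator along these lines (the worker counts below are an assumption, not taken from the source):

import pytest

@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_fast_forward_sampler_over_iterative_dataset(num_workers):
    ...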
Example #2
def next_fn(iterator: Iterator):
    batch = next(iterator)
    if not _fault_tolerant_enabled():
        return batch
    # when fault tolerance is enabled, the iterator returns the `FastForwardSampler`
    # state_dict metadata alongside the user data.
    # the metadata is extracted and stored directly on the iterator
    # to simplify its collection when `state_dict` is called.
    batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
    # store the `samplers_state_dict` on the iterator
    CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
    return batch
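
A hedged sketch of where ``next_fn`` sits: it is the fetch step a loop would call instead of ``next(iterator)`` directly, so the sampler metadata is stripped from (and stored alongside) every user-visible batch. The loop below is illustrative only (it assumes a ``dataloader`` built as in Example #1), not the library's actual fetching code:

iterator = iter(dataloader)
while True:
    try:
        batch = next_fn(iterator)  # user code only ever sees the stripped batch
    except StopIteration:
        break
    # ... run the training step on `batch`
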
def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(rank, worldsize):
    if worldsize > 1:
        _setup_ddp(rank, worldsize)

    def all_gather(tensor, world_size):
        tensor_list = [torch.zeros_like(tensor, dtype=torch.int64) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, tensor)
        return tensor_list

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4
    dataset_length = 60
    num_classes = 10

    labels = np.random.randint(0, num_classes, dataset_length)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )
    dataset = CaptureIterableDataset(dataset, initial_seed=initial_seed)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    Trainer._add_sampler_metadata_collate(dataloader)

    epoch_results = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results.append(batches)
        dataloader.dataset.dataset.current_task_iteration += 1

    assert len(epoch_results) == 2

    assert len(epoch_results[0]) == math.ceil((dataset_length / (num_workers * worldsize)) / batch_size) + 2
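    # the 2 extra batches counted above are the task-metadata batches ("task_length",
    # "selected_indexes") emitted at the start of each epoch; real data batches start at index 2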

    if worldsize == 1:
        assert epoch_results[0][0]["data"]["task_length"] == epoch_results[0][1]["data"]["task_length"]
        assert torch.equal(
            epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"]
        )
        assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"]  # worker id 0
        assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"]  # worker id 1
        assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0])
    else:
        first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize)
        second_task_metadata = all_gather(epoch_results[0][1]["data"]["task_length"], worldsize)
        assert torch.equal(first_task_metadata[0], first_task_metadata[1])
        assert torch.equal(second_task_metadata[0], second_task_metadata[1])
        assert torch.equal(first_task_metadata[0], second_task_metadata[1])

        first_batch_list = all_gather(epoch_results[0][2]["data"][0], worldsize)
        assert not torch.equal(first_batch_list[0], first_batch_list[1])
        second_batch_list = all_gather(epoch_results[0][3]["data"][0], worldsize)
        assert not torch.equal(second_batch_list[0], second_batch_list[1])

    # restart within epoch 0, after the first 2 real data batches
    state_dict = {"iter_sampler": {}}
    for batch in epoch_results[0][2:4]:
        batch, _state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
        for k, v in _state_dict[0].items():
            state_dict[k].update(v)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )

    dataset = CaptureIterableDataset(dataset, initial_seed=initial_seed)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    Trainer._add_sampler_metadata_collate(dataloader)

    epoch_results_restart = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results_restart.append(batches)
        dataloader.dataset.dataset.increment_iteration()
        dataloader.dataset.reset_on_epoch()

    assert len(epoch_results_restart[0]) + 2 == len(epoch_results[0])
    epoch_tensors = [e["data"][0] for e in epoch_results[0][4:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[0][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)

    epoch_tensors = [e["data"][0] for e in epoch_results[1][2:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[1][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)
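
A hedged note on how a function with this ``(rank, worldsize)`` signature is typically launched: the single-process case can be called directly, while the distributed case would be spawned once per rank, e.g. with ``torch.multiprocessing.spawn`` (this launcher and the world size below are assumptions, not taken from the source):

import torch.multiprocessing as mp

# single-process variant
_test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(0, 1)

# DDP variant: spawn one process per rank; each process receives its rank as the first argument
worldsize = 2
mp.spawn(
    _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset,
    args=(worldsize,),
    nprocs=worldsize,
)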