def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
    batch_sampler = getattr(dataloader, "batch_sampler")
    is_predicting = mode == RunningStage.PREDICTING
    # check whether the batch sampler type differs from the PyTorch default.
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if _fault_tolerant_enabled():
            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if _fault_tolerant_enabled():
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
def create_dataloader():
    dataset = range(50)
    batch_size = 8
    sampler = FastForwardSampler(SequentialSampler(dataset))
    sampler.setup(batch_size)

    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
Example #3
def create_dataloader_sampler(dset, sampler):
    # ``batch_size`` and ``num_workers`` are expected to be defined in the enclosing scope.
    sampler = FastForwardSampler(sampler)
    sampler.setup(batch_size)
    dl = DataLoader(dset, num_workers=num_workers, sampler=sampler, batch_size=batch_size)
    _add_capture_metadata_collate(dl)
    return dl, sampler
def test_fast_forward_getattr():
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert index_batch_sampler.batch_size == 3
    assert index_batch_sampler.sampler == sampler
def create_dataloader():
    dataset = range(50)
    num_workers = 2
    batch_size = 8
    sampler = FastForwardSampler(SequentialSampler(dataset))
    sampler.setup(batch_size)

    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    dataloader.fast_forward_sampler = sampler

    loader_dict = {
        "a": [DataLoader(create_iterable_dataset(3, num_workers), num_workers=num_workers, batch_size=3), dataloader],
        "b": DataLoader(
            create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler"), num_workers=0, batch_size=2
        ),
    }
    apply_to_collection(loader_dict, DataLoader, Trainer._add_sampler_metadata_collate)
    return CombinedLoader(loader_dict)
Example #6
def _dataloader_init_kwargs_resolve_sampler(
        dataloader: DataLoader,
        sampler: Optional[Sampler],
        mode: Optional[RunningStage] = None) -> Dict[str, Any]:
    """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
    re-instantiation.

    If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
    Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
    `FastForwardSampler`.
    """
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
    batch_sampler = getattr(dataloader, "batch_sampler")
    is_predicting = mode == RunningStage.PREDICTING
    # checking the batch sampler type is different than PyTorch default.
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler
                                      or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if fault_tolerant_mode.is_automatic:
            fast_forward_sampler = batch_sampler = FastForwardSampler(
                batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if fault_tolerant_mode.is_automatic:
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
Example #7
def test_fast_forward_on_sequential_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``SequentialSampler`` correctly retrived the right next
    batch on restart."""
    dataset = range(15)
    sequential_sampler = SequentialSampler(dataset)
    sampler = FastForwardSampler(sequential_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == [0, 1, 2]
    assert next(batch_sampler_iter) == [3, 4, 5]

    state_dict = sampler.state_dict(2)
    assert state_dict[0]["current_iteration"] == 6

    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == [6, 7, 8]
Example #8
def test_fast_forward_on_batch_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``BatchSampler`` correctly retrived the right next batch
    on restart."""
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert isinstance(index_batch_sampler, Iterable)

    index_batch_sampler_iter = iter(index_batch_sampler)

    assert next(index_batch_sampler_iter) == [0, 1, 2]
    assert next(index_batch_sampler_iter) == [3, 4, 5]

    state_dict = index_batch_sampler.state_dict(2)

    index_batch_sampler = FastForwardSampler(batch_sampler)
    index_batch_sampler.load_state_dict(state_dict)

    index_batch_sampler_iter = iter(index_batch_sampler)
    assert next(index_batch_sampler_iter) == [6, 7, 8]
def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize):
    _setup_ddp(rank, worldsize)

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4

    dataset = range(30)
    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)

    num_yielded = 0
    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
            num_yielded += 1
        except StopIteration:
            break

    expected = torch.tensor([17, 27, 24]) if rank == 0 else torch.tensor([19, 5, 3])
    assert torch.equal(batches[-1], expected)

    # 4 batches of size 4 were yielded on this rank, so 4 * 4 = 16 indices are recorded as consumed.
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16

    # Rewind by one batch before reloading: 3 * 4 = 12 indices consumed.
    reload_state_dict = sampler.state_dict(num_yielded - 1)
    assert reload_state_dict[0]["current_iteration"] == 12

    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    sampler.load_state_dict(reload_state_dict)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)

    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
        except StopIteration:
            break

    assert torch.equal(batches[-1], expected)
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16
def test_fast_forward_on_random_sampler():
    """
    This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` correctly retrieves
    the right next batch on restart.
    """
    seed = 42
    seed_everything(42)

    dataset = range(15)
    generator = torch.Generator().manual_seed(seed)
    values = list(RandomSampler(dataset, generator=generator))

    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == values[:3]
    assert next(batch_sampler_iter) == values[3:6]
    assert next(batch_sampler_iter) == values[6:9]

    state_dict = sampler.state_dict(3)
    assert state_dict[0]["current_iteration"] == 9
    state_dict[0]["current_iteration"] = 6

    seed_everything(42)
    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)
    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == values[6:9]
    has_raised = False
    try:
        for _ in range(5):
            next(batch_sampler_iter)
    except StopIteration:
        has_raised = True
        assert sampler._current_iteration == 0
        sampler.load_state_dict(sampler.state_dict(0))
    assert has_raised