コード例 #1
0
def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
    """Check that both reading and assigning ``batch_indices`` trigger the deprecation warning."""
    expected_message = "was deprecated in v1.5 and will be removed in v1.7"
    wrapper = IndexBatchSamplerWrapper(Mock())

    # Getter access.
    with pytest.deprecated_call(match=expected_message):
        _ = wrapper.batch_indices

    # Setter access.
    with pytest.deprecated_call(match=expected_message):
        wrapper.batch_indices = []
コード例 #2
0
    def _resolve_batch_sampler(
            dl_args,
            dataloader,
            sampler,
            mode: Optional[RunningStage] = None) -> Dict[str, Any]:
        """Fill ``dl_args`` with the sampler-related arguments used to re-create ``dataloader``.

        Args:
            dl_args: Dictionary of dataloader keyword arguments; it is updated in place and returned.
            dataloader: Object whose ``batch_sampler`` attribute is inspected (together with that
                batch sampler's ``batch_size`` and ``drop_last``).
            sampler: Sampler to inject, either directly or through a re-created batch sampler.
            mode: Current running stage. ``RunningStage.PREDICTING`` forces the batch sampler to be
                re-created with ``drop_last=False`` and wrapped in ``IndexBatchSamplerWrapper``.

        Returns:
            The same ``dl_args`` dictionary after the update.
        """
        batch_sampler = dataloader.batch_sampler
        is_predicting = mode == RunningStage.PREDICTING
        # Re-create the batch sampler when it is not the PyTorch default type, or when predicting.
        # A missing batch sampler must always fall through to the plain-sampler branch: there is
        # nothing to re-instantiate, and ``type(None)(...)`` / ``None.batch_size`` would raise.
        if batch_sampler is not None and (
                type(batch_sampler) is not BatchSampler or is_predicting):
            batch_sampler = type(batch_sampler)(
                sampler,
                batch_size=batch_sampler.batch_size,
                # never drop the last (possibly incomplete) batch while predicting
                drop_last=(False
                           if is_predicting else batch_sampler.drop_last),
            )
            if is_predicting:
                batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
            # ``batch_sampler`` is mutually exclusive with the options below, so they are
            # reset to their neutral values.
            dl_args['batch_sampler'] = batch_sampler
            dl_args['batch_size'] = 1
            dl_args['shuffle'] = False
            dl_args['sampler'] = None
            dl_args['drop_last'] = False
        else:
            dl_args['sampler'] = sampler
            dl_args['shuffle'] = False
            dl_args['batch_sampler'] = None

        return dl_args
コード例 #3
0
    def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
        """Build the sampler-related keyword arguments used to re-create ``dataloader``.

        Args:
            dataloader: Object whose ``batch_sampler`` and ``batch_size`` attributes are inspected.
            sampler: Sampler to inject, either directly or through a re-created batch sampler.
            mode: Current running stage. ``RunningStage.PREDICTING`` forces the batch sampler to be
                re-created with ``drop_last=False`` and wrapped in ``IndexBatchSamplerWrapper``.

        Returns:
            A dictionary with either a ``batch_sampler`` entry (plus neutral ``sampler``, ``shuffle``,
            ``batch_size`` and ``drop_last`` values) or a ``sampler`` entry with ``batch_sampler`` set to
            ``None``. When ``_fault_tolerant_enabled()`` is true, the returned sampler or batch sampler
            is wrapped in a ``FastForwardSampler``.
        """
        batch_sampler = dataloader.batch_sampler
        is_predicting = mode == RunningStage.PREDICTING
        # Re-create the batch sampler when it is not the PyTorch default type, or when predicting.
        # A missing batch sampler must always take the plain-sampler path below: there is nothing to
        # re-instantiate, and ``type(None)(...)`` / ``None.batch_size`` would raise.
        if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
            batch_sampler = type(batch_sampler)(
                sampler,
                batch_size=batch_sampler.batch_size,
                # never drop the last (possibly incomplete) batch while predicting
                drop_last=(False if is_predicting else batch_sampler.drop_last),
            )
            if is_predicting:
                batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

            if _fault_tolerant_enabled():
                fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
                fast_forward_sampler.setup(dataloader_batch_size=1)

            # ``batch_sampler`` is mutually exclusive with the other options, so they are
            # returned with neutral values.
            return {
                "sampler": None,
                "shuffle": False,
                "batch_sampler": batch_sampler,
                "batch_size": 1,
                "drop_last": False,
            }

        if _fault_tolerant_enabled():
            fast_forward_sampler = sampler = FastForwardSampler(sampler)
            fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

        return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
コード例 #4
0
def test_index_batch_sampler_methods():
    """Check that the wrapper is an ``Iterable`` and passes the ``has_len`` check."""
    inner = BatchSampler(SequentialSampler(range(15)), batch_size=3, drop_last=False)
    wrapped = IndexBatchSamplerWrapper(inner)

    assert isinstance(wrapped, Iterable)
    assert has_len(wrapped)
コード例 #5
0
def test_index_batch_sampler(tmpdir):
    """Test `IndexBatchSampler` mirrors the wrapped batch sampler and records the indices it yields."""
    base_sampler = SequentialSampler(range(15))
    inner = BatchSampler(base_sampler, batch_size=3, drop_last=False)
    wrapped = IndexBatchSamplerWrapper(inner)

    # The wrapper exposes the same configuration as the batch sampler it wraps.
    assert inner.batch_size == wrapped.batch_size
    assert inner.drop_last == wrapped.drop_last
    assert inner.sampler is base_sampler

    # Exhaust the wrapper first, then compare with what it recorded while iterating.
    yielded_batches = list(wrapped)
    assert yielded_batches == wrapped.seen_batch_indices
コード例 #6
0
def _dataloader_init_kwargs_resolve_sampler(
        dataloader: DataLoader,
        sampler: Optional[Sampler],
        mode: Optional[RunningStage] = None) -> Dict[str, Any]:
    """Work out the ``sampler`` / ``batch_sampler`` keyword arguments needed to re-instantiate a DataLoader.

    When the dataloader is used for prediction and has a batch sampler, that batch sampler is re-created and
    wrapped into an `IndexBatchSamplerWrapper` so its indices can be tracked. When the detected fault tolerant
    mode is automatic, the resulting sampler (or batch sampler) is wrapped into a `FastForwardSampler`.
    """
    ft_mode = _FaultTolerantMode.detect_current_mode()
    current_batch_sampler = dataloader.batch_sampler
    predicting = mode == RunningStage.PREDICTING

    # Plain-sampler path: no batch sampler at all, or the PyTorch default one outside of prediction.
    if current_batch_sampler is None or (
            type(current_batch_sampler) is BatchSampler and not predicting):
        if ft_mode.is_automatic:
            sampler = FastForwardSampler(sampler)
            sampler.setup(dataloader_batch_size=dataloader.batch_size)
        return {"sampler": sampler, "shuffle": False, "batch_sampler": None}

    # Batch-sampler path: re-create it (same type) around the new sampler.
    rebuilt = type(current_batch_sampler)(
        sampler,
        batch_size=current_batch_sampler.batch_size,
        drop_last=False if predicting else current_batch_sampler.drop_last,
    )
    if predicting:
        rebuilt = IndexBatchSamplerWrapper(rebuilt)

    if ft_mode.is_automatic:
        rebuilt = FastForwardSampler(rebuilt)
        rebuilt.setup(dataloader_batch_size=1)

    resolved: Dict[str, Any] = {
        "sampler": None,
        "shuffle": False,
        "batch_sampler": rebuilt,
        "batch_size": 1,
        "drop_last": False,
    }
    return resolved