def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
    """Both reading and assigning ``batch_indices`` must raise a deprecation warning."""
    deprecation_pattern = "was deprecated in v1.5 and will be removed in v1.7"
    wrapper = IndexBatchSamplerWrapper(Mock())

    # Attribute read.
    with pytest.deprecated_call(match=deprecation_pattern):
        _ = wrapper.batch_indices

    # Attribute write.
    with pytest.deprecated_call(match=deprecation_pattern):
        wrapper.batch_indices = []
def _resolve_batch_sampler(
    dl_args: Dict[str, Any], dataloader, sampler, mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    """Fill ``dl_args`` with the sampler-related arguments used to re-create ``dataloader``.

    When the dataloader has a batch sampler that is not exactly PyTorch's ``BatchSampler``,
    or when ``mode`` is ``RunningStage.PREDICTING``, a new batch sampler of the same type is
    built around ``sampler``. While predicting, ``drop_last`` is forced to ``False`` and the
    new batch sampler is wrapped in an ``IndexBatchSamplerWrapper``. Otherwise ``sampler`` is
    passed through as a plain sampler.

    Args:
        dl_args: Keyword arguments for the new dataloader. Mutated in place.
        dataloader: The dataloader whose ``batch_sampler`` attribute is inspected.
        sampler: The sampler to use in the re-created dataloader.
        mode: The current running stage, if any.

    Returns:
        The same ``dl_args`` dict, updated with the resolved sampler arguments.
    """
    batch_sampler = dataloader.batch_sampler
    is_predicting = mode == RunningStage.PREDICTING

    # The ``is not None`` guard must cover the predicting case as well: a dataloader can have
    # no batch sampler at all (PyTorch sets it to ``None`` when automatic batching is
    # disabled), and ``type(None)(...)`` / ``None.batch_size`` would fail.
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
        # ``batch_sampler`` is mutually exclusive with the other batching options of a
        # ``DataLoader``, so those are reset to their neutral values.
        dl_args['batch_sampler'] = batch_sampler
        dl_args['batch_size'] = 1
        dl_args['shuffle'] = False
        dl_args['sampler'] = None
        dl_args['drop_last'] = False
    else:
        dl_args['sampler'] = sampler
        dl_args['shuffle'] = False
        dl_args['batch_sampler'] = None

    return dl_args
def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
    """Resolve the sampler-related keyword arguments used to re-create ``dataloader``.

    When the dataloader has a batch sampler that is not exactly PyTorch's ``BatchSampler``,
    or when ``mode`` is ``RunningStage.PREDICTING``, a new batch sampler of the same type is
    built around ``sampler``. While predicting, ``drop_last`` is forced to ``False`` and the
    new batch sampler is wrapped in an ``IndexBatchSamplerWrapper``. Otherwise ``sampler`` is
    passed through as a plain sampler.

    If ``_fault_tolerant_enabled()`` returns a truthy value, whichever sampler is returned is
    additionally wrapped in a ``FastForwardSampler`` and set up with the matching batch size.

    Args:
        dataloader: The dataloader whose ``batch_sampler`` attribute is inspected.
        sampler: The sampler to use in the re-created dataloader.
        mode: The current running stage, if any.

    Returns:
        A dict of dataloader keyword arguments describing the resolved sampler.
    """
    batch_sampler = dataloader.batch_sampler
    is_predicting = mode == RunningStage.PREDICTING

    # The ``is not None`` guard must cover the predicting case as well: a dataloader can have
    # no batch sampler at all (PyTorch sets it to ``None`` when automatic batching is
    # disabled), and ``type(None)(...)`` / ``None.batch_size`` would fail.
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if _fault_tolerant_enabled():
            batch_sampler = FastForwardSampler(batch_sampler)
            batch_sampler.setup(dataloader_batch_size=1)

        # ``batch_sampler`` is mutually exclusive with the other batching options of a
        # ``DataLoader``, so those are reset to their neutral values.
        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if _fault_tolerant_enabled():
        sampler = FastForwardSampler(sampler)
        sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
def test_index_batch_sampler_methods():
    """The wrapper must be an ``Iterable`` and must pass the ``has_len`` check."""
    wrapped = IndexBatchSamplerWrapper(
        BatchSampler(SequentialSampler(range(15)), batch_size=3, drop_last=False)
    )

    assert isinstance(wrapped, Iterable)
    assert has_len(wrapped)
def test_index_batch_sampler(tmpdir):
    """Test `IndexBatchSamplerWrapper` mirrors the wrapped sampler and records the indices it yields."""
    base_sampler = SequentialSampler(range(15))
    inner = BatchSampler(base_sampler, batch_size=3, drop_last=False)
    wrapper = IndexBatchSamplerWrapper(inner)

    # The wrapper exposes the same batching configuration as the sampler it wraps.
    for attr in ("batch_size", "drop_last"):
        assert getattr(inner, attr) == getattr(wrapper, attr)
    assert inner.sampler is base_sampler

    # Exhaust the wrapper first, then compare against what it recorded.
    yielded = list(wrapper)
    assert yielded == wrapper.seen_batch_indices
def _dataloader_init_kwargs_resolve_sampler(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    """Resolve the ``sampler`` / ``batch_sampler`` arguments needed to re-instantiate a ``DataLoader``.

    For prediction, the batch sampler is re-created with ``drop_last=False`` and wrapped into an
    `IndexBatchSamplerWrapper`, so Lightning can keep track of its indices.

    If fault tolerant training is in automatic mode, the returned sampler (or batch sampler) is
    wrapped into a `FastForwardSampler`.
    """
    ft_mode = _FaultTolerantMode.detect_current_mode()
    current_batch_sampler = dataloader.batch_sampler
    predicting = mode == RunningStage.PREDICTING

    # Nothing to rebuild when there is no batch sampler, or when it is exactly PyTorch's
    # default ``BatchSampler`` and we are not predicting: hand the plain sampler through.
    keep_plain_sampler = current_batch_sampler is None or (
        type(current_batch_sampler) is BatchSampler and not predicting
    )
    if keep_plain_sampler:
        if ft_mode.is_automatic:
            sampler = FastForwardSampler(sampler)
            sampler.setup(dataloader_batch_size=dataloader.batch_size)
        return {"sampler": sampler, "shuffle": False, "batch_sampler": None}

    # Re-create the batch sampler, with its own type, around the new sampler.
    batch_sampler_cls = type(current_batch_sampler)
    rebuilt = batch_sampler_cls(
        sampler,
        batch_size=current_batch_sampler.batch_size,
        drop_last=False if predicting else current_batch_sampler.drop_last,
    )
    if predicting:
        rebuilt = IndexBatchSamplerWrapper(rebuilt)

    if ft_mode.is_automatic:
        rebuilt = FastForwardSampler(rebuilt)
        rebuilt.setup(dataloader_batch_size=1)

    return {
        "sampler": None,
        "shuffle": False,
        "batch_sampler": rebuilt,
        "batch_size": 1,
        "drop_last": False,
    }