def test_dataloader_source_direct_access():
    """Check a source that wraps an already-built dataloader instance.

    The source must not report itself as a module, must count as defined,
    and must hand back the very same dataloader object it was given.
    """
    loader = BoringModel().train_dataloader()
    direct_source = _DataLoaderSource(instance=loader, name="any")

    assert not direct_source.is_module()
    assert direct_source.is_defined()
    # Identity, not equality: the wrapped object itself is returned.
    assert direct_source.dataloader() is loader
def _reset_dataloader_for_stage(self, running_state: RunningStage):
    """Re-register and reset the trainer's dataloader for ``running_state``.

    The hook name is ``"<prefix>_dataloader"``, where the prefix is looked up
    in ``_STAGES_PREFIX``. Nothing happens unless the trainer's datamodule
    overrides that hook.

    Args:
        running_state: Stage whose dataloader should be reset.
    """
    trainer = self.trainer
    hook_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
    datamodule = trainer.datamodule

    # Only pick up the hook when the datamodule actually overrides it.
    hook = None
    if is_overridden(hook_name, datamodule):
        hook = getattr(datamodule, hook_name)

    if not hook:
        return

    if _PL_GREATER_EQUAL_1_5_0:
        # Point the data connector's source slot at the datamodule hook.
        target = trainer._data_connector
        attribute = f"_{hook_name}_source"
        replacement = _DataLoaderSource(datamodule, hook_name)
    else:
        # Patch the hook on the lightning module with a freshly built loader.
        target = trainer.lightning_module
        attribute = hook_name
        replacement = _PatchDataLoader(hook(), running_state)
    setattr(target, attribute, replacement)

    # Clear the dataloader attribute held on the trainer before resetting.
    setattr(trainer, hook_name, None)

    # TODO: Resolve this within PyTorch Lightning.
    reset_hook = getattr(trainer, f"reset_{hook_name}")
    try:
        reset_hook(trainer.lightning_module)
    except MisconfigurationException:
        # Best-effort reset: a misconfiguration here is deliberately ignored.
        pass
def test_dataloader_source_request_from_module():
    """Check that a source backed by a module hook fetches the loader lazily.

    Building the source must not call the hook; the hook is called exactly
    once when ``dataloader()`` is requested.
    """
    model = BoringModel()
    model.trainer = Trainer()
    model.foo = Mock(return_value=model.train_dataloader())

    module_source = _DataLoaderSource(model, "foo")
    assert module_source.is_module()

    # Constructing the source and querying it must not trigger the hook.
    model.foo.assert_not_called()

    loader = module_source.dataloader()
    assert isinstance(loader, DataLoader)
    model.foo.assert_called_once()
def test_dataloader_source_available(instance, available):
    """Check that ``is_defined()`` reports the expected availability.

    Args:
        instance: Object wrapped by the source under the name
            ``"train_dataloader"``.
        available: Expected result of ``is_defined()``; compared by identity,
            so it must be the exact ``True``/``False`` object returned.

    NOTE(review): the arguments presumably come from a parametrize decorator
    that is not visible here -- confirm.
    """
    train_source = _DataLoaderSource(instance=instance, name="train_dataloader")
    defined = train_source.is_defined()
    assert defined is available