def _patch_dataloader(model: "Task", dataloader: Union[Callable, DataLoader], stage: RunningStage):
    if isinstance(dataloader, DataLoader):
        if _PL_GREATER_EQUAL_1_4_3:
            # On newer PyTorch Lightning versions the wrapper patches itself
            # onto the model under the stage-prefixed attribute name.
            dataloader = _PatchDataLoader(dataloader, _STAGES_PREFIX[stage])
            dataloader.patch(model)
        else:
            dataloader = _PatchDataLoader(dataloader)
    return dataloader
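
# Hedged usage sketch (illustrative only, not part of the original module):
# the ``_PatchDataLoader`` returned by ``_patch_dataloader`` is callable and
# yields the wrapped loader, which is what the attach/detach methods below
# rely on when they call ``dataloader()``.
def _example_patch_dataloader_usage(model: "Task") -> None:
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = DataLoader(TensorDataset(torch.rand(8, 3)), batch_size=4)
    patched = _patch_dataloader(model, loader, RunningStage.TRAINING)
    # On PL >= 1.4.3 the wrapper has additionally patched itself onto ``model``;
    # in both branches, calling it recovers the original loader.
    assert patched() is loader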
def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[RunningStage] = None):
    if not stage:
        stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
    elif isinstance(stage, RunningStage):
        stages = [stage]

    for stage in stages:
        device_collate = None
        if isinstance(model.transfer_batch_to_device, _StageOrchestrator):
            device_collate = model.transfer_batch_to_device.unregister_stage(stage)

            # if no additional func is registered: remove the wrapper
            if model.transfer_batch_to_device.is_empty():
                model.transfer_batch_to_device = model.transfer_batch_to_device.func

        if not device_collate:
            device_collate = self._identity

        loader_name = f"{_STAGES_PREFIX[stage]}_dataloader"

        dataloader, whole_attr_name = self._get_dataloader(model, loader_name)

        if not dataloader:
            continue

        if isinstance(dataloader, _PatchDataLoader):
            dataloader = dataloader()
        elif isinstance(dataloader, Callable):
            dataloader = dataloader()

        if isinstance(dataloader, Sequence):
            was_seq = True
        else:
            dataloader = [dataloader]
            was_seq = False

        for idx, loader in enumerate(dataloader):
            if isinstance(loader, DataLoader):
                dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}

                # restore the original collate function and rebuild the loader
                if isinstance(dl_args["collate_fn"], _PreProcessor):
                    dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn

                    if isinstance(dl_args["dataset"], IterableAutoDataset):
                        del dl_args["sampler"]

                    del dl_args["batch_sampler"]

                    loader = type(loader)(**dl_args)

            dataloader[idx] = loader

        if not was_seq:
            dataloader = dataloader[0]

        if isinstance(dataloader, DataLoader):
            dataloader = _PatchDataLoader(dataloader)

        self._set_loader(model, whole_attr_name, dataloader)
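
# Hedged sketch (illustrative only, not part of the original module) isolating
# the re-instantiation trick used in ``_detach_preprocessing_from_model``:
# ``DataLoader`` keeps its constructor arguments as public attributes, so a
# filtered ``vars()`` copy can rebuild an equivalent loader once the collate
# function has been swapped back.
def _example_rebuild_loader(loader: DataLoader) -> DataLoader:
    dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}
    # ``batch_sampler`` is mutually exclusive with ``batch_size``, ``shuffle``,
    # ``sampler`` and ``drop_last`` in the DataLoader constructor, so it must
    # be dropped before rebuilding from the remaining arguments.
    del dl_args["batch_sampler"]
    return type(loader)(**dl_args)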
def _reset_dataloader_for_stage(self, running_state: RunningStage):
    dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
    # If the dataloader exists, we reset it.
    dataloader = (
        getattr(self.trainer.datamodule, dataloader_name)
        if is_overridden(dataloader_name, self.trainer.datamodule)
        else None
    )

    if dataloader:
        if _PL_GREATER_EQUAL_1_5_0:
            setattr(
                self.trainer._data_connector,
                f"_{dataloader_name}_source",
                _DataLoaderSource(self.trainer.datamodule, dataloader_name),
            )
        else:
            setattr(
                self.trainer.lightning_module,
                dataloader_name,
                _PatchDataLoader(dataloader(), running_state),
            )
        setattr(self.trainer, dataloader_name, None)
        # TODO: Resolve this within PyTorch Lightning.
        try:
            getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
        except MisconfigurationException:
            pass
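
# Hedged sketch (illustrative only, not part of the original module): flags
# such as ``_PL_GREATER_EQUAL_1_5_0`` are assumed to be module-level booleans
# derived from the installed pytorch-lightning version, along these lines.
def _example_pl_version_gate(minimum: str = "1.5.0") -> bool:
    import pytorch_lightning
    from packaging.version import Version

    return Version(pytorch_lightning.__version__) >= Version(minimum)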
def _attach_preprocess_to_model(
    self, model: "Task", stage: Optional[RunningStage] = None, device_transform_only: bool = False
) -> None:
    device_collate_fn = torch.nn.Identity()

    if not stage:
        stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
    elif isinstance(stage, RunningStage):
        stages = [stage]

    for stage in stages:
        loader_name = f"{_STAGES_PREFIX[stage]}_dataloader"

        dataloader, whole_attr_name = self._get_dataloader(model, loader_name)

        if not dataloader:
            continue

        if isinstance(dataloader, (_PatchDataLoader, Callable)):
            dataloader = dataloader()

        if dataloader is None:
            continue

        if isinstance(dataloader, Sequence):
            was_seq = True
        else:
            dataloader = [dataloader]
            was_seq = False

        for idx, loader in enumerate(dataloader):
            # TODO: See lightning for proper reinstantiation of loader
            if isinstance(loader, DataLoader):
                dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}

                dl_args["collate_fn"], device_collate_fn = self._create_collate_preprocessors(
                    stage=stage, collate_fn=dl_args["collate_fn"]
                )

                if isinstance(dl_args["dataset"], IterableDataset):
                    del dl_args["sampler"]

                # don't have to reinstantiate the loader if just rewrapping devices (happens during detach)
                if not device_transform_only:
                    del dl_args["batch_sampler"]
                    loader = type(loader)(**dl_args)

            dataloader[idx] = loader

        # don't have to set the attribute if rewrapping the device part (happens during detach)
        if not device_transform_only:
            if not was_seq:
                dataloader = dataloader[0]

            if isinstance(dataloader, DataLoader):
                dataloader = _PatchDataLoader(dataloader)

            self._set_loader(model, whole_attr_name, dataloader)

        model.transfer_batch_to_device = (
            self._model_transfer_to_device_wrapper(model.transfer_batch_to_device, device_collate_fn, model, stage)
        )
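
# Hedged usage sketch (illustrative only, not part of the original module).
# Assumes the methods above live on a ``DataPipeline``-style class: attaching
# injects the preprocessing collate into each stage's loader and wraps
# ``transfer_batch_to_device`` with a per-stage orchestrator; detaching undoes
# both, so the pair behaves as a round trip on ``model``.
def _example_attach_detach_roundtrip(pipeline, model: "Task") -> None:
    pipeline._attach_preprocess_to_model(model, RunningStage.TRAINING)
    # ... run training with per-batch on-device transforms applied ...
    pipeline._detach_preprocessing_from_model(model, RunningStage.TRAINING)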