Example #1
def _patch_dataloader(model: "Task", dataloader: Union[Callable, DataLoader], stage: RunningStage):
    if isinstance(dataloader, DataLoader):
        if _PL_GREATER_EQUAL_1_4_3:
            # On newer Lightning the wrapper carries the stage prefix and patches itself onto the model.
            dataloader = _PatchDataLoader(dataloader, _STAGES_PREFIX[stage])
            dataloader.patch(model)
        else:
            dataloader = _PatchDataLoader(dataloader)
    return dataloader
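
For orientation, the wrapper created above behaves roughly like the callable sketched below: it captures a DataLoader, hands it back when Lightning invokes the hook, and can attach itself to the model under a stage-specific name. The class name, attribute name, and patch() behaviour here are illustrative stand-ins, not Flash's actual _PatchDataLoader.

from torch.utils.data import DataLoader


class _CallableLoaderWrapper:
    """Illustrative stand-in for a _PatchDataLoader-style wrapper."""

    def __init__(self, dataloader: DataLoader, stage_prefix: str = "train"):
        self._dataloader = dataloader
        self._attr_name = f"{stage_prefix}_dataloader"

    def __call__(self) -> DataLoader:
        # Lightning calls the *_dataloader hook; simply return the captured loader.
        return self._dataloader

    def patch(self, model) -> None:
        # Attach the wrapper to the model under the stage-specific hook name.
        setattr(model, self._attr_name, self)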
Example #2
    def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[RunningStage] = None):
        if not stage:
            stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
        elif isinstance(stage, RunningStage):
            stages = [stage]

        for stage in stages:

            device_collate = None
            if isinstance(model.transfer_batch_to_device, _StageOrchestrator):
                device_collate = model.transfer_batch_to_device.unregister_stage(stage)

                # if no additional func is registered: remove the wrapper
                if model.transfer_batch_to_device.is_empty():
                    model.transfer_batch_to_device = model.transfer_batch_to_device.func

            if not device_collate:
                device_collate = self._identity

            loader_name = f'{_STAGES_PREFIX[stage]}_dataloader'

            dataloader, whole_attr_name = self._get_dataloader(model, loader_name)

            if not dataloader:
                continue

            if isinstance(dataloader, _PatchDataLoader):
                dataloader = dataloader()
            elif isinstance(dataloader, Callable):
                dataloader = dataloader()

            if isinstance(dataloader, Sequence):
                was_seq = True
            else:
                dataloader = [dataloader]
                was_seq = False

            for idx, loader in enumerate(dataloader):
                if isinstance(loader, DataLoader):
                    dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}

                    if isinstance(dl_args['collate_fn'], _PreProcessor):
                        dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn

                        if isinstance(dl_args["dataset"], IterableAutoDataset):
                            del dl_args['sampler']

                        del dl_args["batch_sampler"]

                        loader = type(loader)(**dl_args)

                dataloader[idx] = loader

            if not was_seq:
                dataloader = dataloader[0]

            if isinstance(dataloader, DataLoader):
                dataloader = _PatchDataLoader(dataloader)

            self._set_loader(model, whole_attr_name, dataloader)
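
The core trick above (and in Example #4) is rebuilding a DataLoader from its public attributes so that only the collate_fn changes. A self-contained sketch of that pattern, with a made-up dataset and collate function, could look like this:

from torch.utils.data import DataLoader


def rebuild_with_collate(loader: DataLoader, new_collate_fn) -> DataLoader:
    # Copy the constructor arguments the loader keeps as public attributes.
    dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}
    dl_args["collate_fn"] = new_collate_fn
    # batch_sampler conflicts with batch_size/sampler/drop_last, so drop it
    # and let the new loader derive one from the remaining arguments.
    del dl_args["batch_sampler"]
    return type(loader)(**dl_args)


loader = DataLoader(list(range(10)), batch_size=4)
unbatched = rebuild_with_collate(loader, new_collate_fn=lambda samples: samples)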
Example #3
    def _reset_dataloader_for_stage(self, running_state: RunningStage):
        dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
        # If the dataloader exists, we reset it.
        dataloader = (
            getattr(self.trainer.datamodule, dataloader_name)
            if is_overridden(dataloader_name, self.trainer.datamodule)
            else None
        )

        if dataloader:
            if _PL_GREATER_EQUAL_1_5_0:
                setattr(
                    self.trainer._data_connector,
                    f"_{dataloader_name}_source",
                    _DataLoaderSource(self.trainer.datamodule,
                                      dataloader_name),
                )
            else:
                setattr(
                    self.trainer.lightning_module,
                    dataloader_name,
                    _PatchDataLoader(dataloader(), running_state),
                )
            setattr(self.trainer, dataloader_name, None)
            # TODO: Resolve this within PyTorch Lightning.
            try:
                getattr(self.trainer, f"reset_{dataloader_name}")(
                    self.trainer.lightning_module)
            except MisconfigurationException:
                pass
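
The dataloader_name built here comes from the same _STAGES_PREFIX mapping used throughout these examples. Its exact contents are not shown in this listing; a plausible stand-in (an assumption, not Flash's definition) is:

from pytorch_lightning.trainer.states import RunningStage

# Illustrative mapping only; the real one lives in Flash and may differ.
_STAGES_PREFIX = {
    RunningStage.TRAINING: "train",
    RunningStage.VALIDATING: "val",
    RunningStage.TESTING: "test",
    RunningStage.PREDICTING: "predict",
}

dataloader_name = f"{_STAGES_PREFIX[RunningStage.TRAINING]}_dataloader"
assert dataloader_name == "train_dataloader"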
Example #4
    def _attach_preprocess_to_model(
        self, model: 'Task', stage: Optional[RunningStage] = None, device_transform_only: bool = False
    ) -> None:
        device_collate_fn = torch.nn.Identity()

        if not stage:
            stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]

        elif isinstance(stage, RunningStage):
            stages = [stage]

        for stage in stages:

            loader_name = f'{_STAGES_PREFIX[stage]}_dataloader'

            dataloader, whole_attr_name = self._get_dataloader(model, loader_name)

            if not dataloader:
                continue

            if isinstance(dataloader, (_PatchDataLoader, Callable)):
                dataloader = dataloader()

            if dataloader is None:
                continue

            if isinstance(dataloader, Sequence):
                was_seq = True
            else:
                dataloader = [dataloader]
                was_seq = False

            for idx, loader in enumerate(dataloader):
                # TODO: See lightning for proper reinstantiation of loader
                if isinstance(loader, DataLoader):
                    dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}

                    dl_args['collate_fn'], device_collate_fn = self._create_collate_preprocessors(
                        stage=stage, collate_fn=dl_args['collate_fn']
                    )

                    if isinstance(dl_args["dataset"], IterableDataset):
                        del dl_args["sampler"]

                    # don't have to reinstantiate loader if just rewrapping devices (happens during detach)
                    if not device_transform_only:
                        del dl_args["batch_sampler"]
                        loader = type(loader)(**dl_args)

                dataloader[idx] = loader

            # don't have to set attribute if rewrapping device part (happens during detach)
            if not device_transform_only:
                if not was_seq:
                    dataloader = dataloader[0]

                if isinstance(dataloader, DataLoader):
                    dataloader = _PatchDataLoader(dataloader)

                self._set_loader(model, whole_attr_name, dataloader)

            model.transfer_batch_to_device = (
                self._model_transfer_to_device_wrapper(model.transfer_batch_to_device, device_collate_fn, model, stage)
            )
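
The transfer_batch_to_device wrapping used in Examples #2 and #4 (register/unregister of stages, is_empty, the .func fallback) boils down to a per-stage dispatcher around the original hook. The sketch below is a simplified illustration, not Flash's _StageOrchestrator; in particular, how the currently running stage is looked up is an assumption.

class _SimpleStageOrchestrator:
    """Illustrative per-stage wrapper around a model's transfer_batch_to_device hook."""

    def __init__(self, func, model):
        self.func = func          # the original transfer_batch_to_device
        self.model = model
        self._stage_fns = {}      # RunningStage -> device collate/preprocess fn

    def register_stage(self, stage, fn):
        self._stage_fns[stage] = fn

    def unregister_stage(self, stage):
        return self._stage_fns.pop(stage, None)

    def is_empty(self):
        return not self._stage_fns

    def __call__(self, batch, device=None, dataloader_idx=0):
        # Move the batch with the original hook first.
        batch = self.func(batch, device, dataloader_idx)
        # Then apply whatever was registered for the current stage, if anything.
        stage = self.model.trainer.state.stage  # assumption: where the stage is read from
        stage_fn = self._stage_fns.get(stage)
        return stage_fn(batch) if stage_fn is not None else batch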