Example #1
    def call_hook(self, hook_name, *args, **kwargs):
        # set hook_name to model + reset Result obj
        skip = self._reset_result_and_set_hook_fx_name(hook_name)

        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                trainer_hook = getattr(self, hook_name)
                trainer_hook(*args, **kwargs)

            # next call the hook in the LightningModule
            output = None
            model_ref = self.lightning_module
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.accelerator, hook_name):
                accelerator_hook = getattr(self.accelerator, hook_name)
                output = accelerator_hook(*args, **kwargs)

        if not skip:
            self._cache_logged_metrics()
        return output
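Every snippet in this collection hinges on `is_overridden`. As a point of reference, here is a minimal, self-contained sketch of such a check (assumed semantics only; the real helper in `pytorch_lightning.utilities.model_helpers` also unwraps decorated or partial methods and can infer the parent class from the instance):

    # Simplified sketch of an ``is_overridden``-style check (not the Lightning implementation).
    def is_overridden_sketch(method_name, instance, parent) -> bool:
        if instance is None:
            return False
        instance_attr = getattr(type(instance), method_name, None)
        parent_attr = getattr(parent, method_name, None)
        if instance_attr is None or parent_attr is None:
            return False
        # an inherited method shares the parent's code object; an override does not
        instance_code = getattr(instance_attr, "__code__", instance_attr)
        parent_code = getattr(parent_attr, "__code__", parent_attr)
        return instance_code is not parent_code

    class ParentModule:
        def training_step(self, batch):
            ...

    class ChildModule(ParentModule):
        def training_step(self, batch):
            return batch

    print(is_overridden_sketch("training_step", ChildModule(), ParentModule))    # True
    print(is_overridden_sketch("validation_step", ChildModule(), ParentModule))  # False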
Example #2
    def on_train_epoch_end(self,
                           epoch_output: List[List[List[Result]]]) -> None:
        # inform logger the batch loop has finished
        self.trainer.logger_connector.on_train_epoch_end()

        # prepare epoch output
        processed_epoch_output = TrainLoop._prepare_outputs(epoch_output,
                                                            batch_mode=False)

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module

        if is_overridden('training_epoch_end', model=model):
            # run training_epoch_end
            # refresh the result for custom logging at the epoch level
            model._current_fx_name = 'training_epoch_end'
            training_epoch_end_output = model.training_epoch_end(
                processed_epoch_output)

            if training_epoch_end_output is not None:
                raise MisconfigurationException(
                    'training_epoch_end expects a return of None. '
                    'HINT: remove the return statement in training_epoch_end')

            # capture logging
            self.trainer.logger_connector.cache_logged_metrics()

        # call train epoch end hooks
        self._on_train_epoch_end_hook(processed_epoch_output)
        self.trainer.call_hook('on_epoch_end')
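For context on the return-value check above: `training_epoch_end` is expected to aggregate the per-batch outputs and log or store the result rather than return it. A toy hook that satisfies the contract (hypothetical module, not a runnable Trainer setup):

    class ToyModule:
        def training_epoch_end(self, epoch_outputs):
            # aggregate per-batch results and keep them on the module;
            # returning a value here would trip the MisconfigurationException above
            losses = [out["loss"] for out in epoch_outputs]
            self.last_epoch_loss = sum(losses) / len(losses)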
Example #3
    def run_sanity_check(self, ref_model):
        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        if should_sanity_check:
            stage = self._running_stage
            self.sanity_checking = True

            # hook and callback
            self.on_sanity_check_start()

            # run eval step
            _, eval_results = self.run_evaluation()

            # allow no returns from eval
            if eval_results is not None and len(eval_results) > 0:
                # when we get a list back, use only the last item
                if isinstance(eval_results, list):
                    eval_results = eval_results[-1]

                _, _, _, callback_metrics, _ = self.process_dict_result(eval_results)
                self.logger_connector.callback_metrics = callback_metrics

            self.on_sanity_check_end()

            self._running_stage = stage
Example #4
    def attach_datamodule(
            self,
            model: "pl.LightningModule",
            datamodule: Optional["pl.LightningDataModule"] = None) -> None:
        # If we have a datamodule, attach necessary hooks + dataloaders
        if datamodule is None:
            return

        self._train_dataloader_source = _DataLoaderSource(
            datamodule, "train_dataloader")
        self._val_dataloader_source = _DataLoaderSource(
            datamodule, "val_dataloader")
        self._test_dataloader_source = _DataLoaderSource(
            datamodule, "test_dataloader")
        self._predict_dataloader_source = _DataLoaderSource(
            datamodule, "predict_dataloader")

        # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
        batch_transfer_hooks = ("on_before_batch_transfer",
                                "transfer_batch_to_device",
                                "on_after_batch_transfer")
        for hook in batch_transfer_hooks:
            if is_overridden(hook, datamodule):
                setattr(model, hook, getattr(datamodule, hook))

        self.trainer.datamodule = datamodule
        datamodule.trainer = self.trainer

        # experimental feature for Flash
        if hasattr(datamodule, "data_pipeline"):
            model.data_pipeline = datamodule.data_pipeline
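The hook-forwarding idea above can be shown in isolation with toy classes (hypothetical names, not the pytorch_lightning implementation): hooks that the datamodule overrides are bound onto the model, so the training loop only ever calls the model.

    class BaseDataModule:
        def transfer_batch_to_device(self, batch, device):
            return batch

    class CustomDataModule(BaseDataModule):
        def transfer_batch_to_device(self, batch, device):  # user override
            return [str(x).upper() for x in batch]           # stand-in for real `.to(device)` logic

    class ToyModel:
        def transfer_batch_to_device(self, batch, device):
            return batch

    def attach(model, dm, hooks=("transfer_batch_to_device",)):
        for hook in hooks:
            if getattr(type(dm), hook) is not getattr(BaseDataModule, hook):  # overridden?
                setattr(model, hook, getattr(dm, hook))  # bind the datamodule's method onto the model

    model, dm = ToyModel(), CustomDataModule()
    attach(model, dm)
    print(model.transfer_batch_to_device(["a", "b"], device=None))  # ['A', 'B']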
Example #5
    def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                   results: Any) -> Optional["_SpawnOutput"]:
        rank_zero_debug("Finalizing the TPU spawn environment.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

        # the state_dict must be computed on all processes in case Metrics are present
        state_dict = self.lightning_module.state_dict()

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
            self.checkpoint_io.save_checkpoint(state_dict, weights_path)

        # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
        if self.local_rank != 0:
            return

        # adds the `callback_metrics` to the queue
        extra = _FakeQueue()
        if is_overridden("add_to_queue", self.lightning_module):
            # TODO: Remove the if in v1.7
            self.lightning_module.add_to_queue(extra)
        self.add_to_queue(trainer, extra)

        return _SpawnOutput(best_model_path, weights_path, trainer.state,
                            results, extra)
Example #6
    def init_deepspeed(self):
        # deepspeed handles gradient clipping internally
        if is_overridden("configure_gradient_clipping", self.lightning_module,
                         pl.LightningModule):
            rank_zero_warn(
                "Since DeepSpeed handles gradient clipping internally, the default"
                " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
                " The hook will still be called. Consider setting"
                " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
                " which will use the internal mechanism.")

        if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            raise MisconfigurationException(
                "DeepSpeed does not support clipping gradients by value.")

        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

        if accumulation_scheduler.epochs != [0]:
            raise MisconfigurationException(
                "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
            )

        model = LightningDeepSpeedModule(
            pl_module=self.model, precision=self.precision_plugin.precision)

        if self.lightning_module.trainer and self.lightning_module.trainer.training:
            self._initialize_deepspeed_train(model)
        else:
            self._initialize_deepspeed_inference(model)
Example #7
def test_dm_apply_batch_transfer_handler(get_module_mock):
    expected_device = torch.device('cuda', 0)

    class CustomBatch:
        def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]

    class CurrentTestDM(LightningDataModule):
        rank = 0
        transfer_batch_to_device_hook_rank = None
        on_before_batch_transfer_hook_rank = None
        on_after_batch_transfer_hook_rank = None

        def on_before_batch_transfer(self, batch, dataloader_idx):
            self.on_before_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.samples += 1
            return batch

        def on_after_batch_transfer(self, batch, dataloader_idx):
            assert batch.samples.device == batch.targets.device == expected_device
            self.on_after_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.targets *= 2
            return batch

        def transfer_batch_to_device(self, batch, device):
            self.transfer_batch_to_device_hook_rank = self.rank
            self.rank += 1
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
            return batch

    dm = CurrentTestDM()
    model = BoringModel()

    batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1,
                                                        dtype=torch.long)))

    trainer = Trainer(gpus=1)
    # running .fit() would require us to implement custom data loaders, so we mock the model reference instead
    get_module_mock.return_value = model
    if is_overridden('transfer_batch_to_device', dm):
        model.transfer_batch_to_device = dm.transfer_batch_to_device

    model.on_before_batch_transfer = dm.on_before_batch_transfer
    model.transfer_batch_to_device = dm.transfer_batch_to_device
    model.on_after_batch_transfer = dm.on_after_batch_transfer

    batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)

    assert dm.on_before_batch_transfer_hook_rank == 0
    assert dm.transfer_batch_to_device_hook_rank == 1
    assert dm.on_after_batch_transfer_hook_rank == 2
    assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device
    assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32))
    assert torch.allclose(batch_gpu.targets.cpu(),
                          torch.ones(5, 1, dtype=torch.long) * 2)
Example #8
    def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
        """Overrides the model's :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`
        method if a single optimizer and optionally a scheduler argument groups are added to the parser as
        'AUTOMATIC'."""
        parser = self._parser(subcommand)

        def get_automatic(
            class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]
        ) -> List[str]:
            automatic = []
            for key, (base_class, link_to) in register.items():
                if not isinstance(base_class, tuple):
                    base_class = (base_class,)
                if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class):
                    automatic.append(key)
            return automatic

        optimizers = get_automatic(Optimizer, parser._optimizers)
        lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers)

        if len(optimizers) == 0:
            return

        if len(optimizers) > 1 or len(lr_schedulers) > 1:
            raise MisconfigurationException(
                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
                f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
                "is expected to link the argument groups and implement `configure_optimizers`, see "
                "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
                "#optimizers-and-learning-rate-schedulers"
            )

        optimizer_class = parser._optimizers[optimizers[0]][0]
        optimizer_init = self._get(self.config_init, optimizers[0])
        if not isinstance(optimizer_class, tuple):
            optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
        if not optimizer_init:
            # optimizers were registered automatically but not passed by the user
            return

        lr_scheduler_init = None
        if lr_schedulers:
            lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
            lr_scheduler_init = self._get(self.config_init, lr_schedulers[0])
            if not isinstance(lr_scheduler_class, tuple):
                lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

        if is_overridden("configure_optimizers", self.model):
            _warn(
                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
                f"`{self.__class__.__name__}.configure_optimizers`."
            )

        optimizer = instantiate_class(self.model.parameters(), optimizer_init)
        lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
        fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
        update_wrapper(fn, self.configure_optimizers)  # necessary for `is_overridden`
        # override the existing method
        self.model.configure_optimizers = MethodType(fn, self.model)
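The closing lines above use a generic Python idiom worth isolating: build the replacement with `functools.partial`, copy metadata with `update_wrapper` so override detection can see through the wrapper, and bind it to a single instance with `types.MethodType`. A hedged sketch with toy names:

    from functools import partial, update_wrapper
    from types import MethodType

    class ToyModel:
        def configure_optimizers(self):
            return None  # default implementation to be replaced

    def _auto_configure_optimizers(model, lr):
        # stand-in for the generated optimizer/scheduler construction
        return {"optimizer": f"SGD(lr={lr})"}

    model = ToyModel()
    fn = partial(_auto_configure_optimizers, lr=0.1)
    update_wrapper(fn, ToyModel.configure_optimizers)   # keep the original name/metadata on the partial
    model.configure_optimizers = MethodType(fn, model)  # override on this instance only
    print(model.configure_optimizers())                 # {'optimizer': 'SGD(lr=0.1)'}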
Example #9
    def add_configure_optimizers_method_to_model(self) -> None:
        """
        Adds to the model an automatically generated configure_optimizers method

        If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC',
        then a `configure_optimizers` method is automatically implemented in the model class.
        """

        def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
            automatic = []
            for key, (base_class, link_to) in self.parser.optimizers_and_lr_schedulers.items():
                if not isinstance(base_class, tuple):
                    base_class = (base_class, )
                if link_to == 'AUTOMATIC' and any(issubclass(c, class_type) for c in base_class):
                    automatic.append(key)
            return automatic

        optimizers = get_automatic(Optimizer)
        lr_schedulers = get_automatic(LRSchedulerTypeTuple)

        if len(optimizers) == 0:
            return

        if len(optimizers) > 1 or len(lr_schedulers) > 1:
            raise MisconfigurationException(
                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
                f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
                "is expected to link the argument groups and implement `configure_optimizers`, see "
                "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
                "#optimizers-and-learning-rate-schedulers"
            )

        if is_overridden('configure_optimizers', self.model):
            warnings.warn(
                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
            )

        optimizer_class = self.parser.optimizers_and_lr_schedulers[optimizers[0]][0]
        optimizer_init = self.config_init.get(optimizers[0], {})
        if not isinstance(optimizer_class, tuple):
            optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
        lr_scheduler_init = None
        if lr_schedulers:
            lr_scheduler_class = self.parser.optimizers_and_lr_schedulers[lr_schedulers[0]][0]
            lr_scheduler_init = self.config_init.get(lr_schedulers[0], {})
            if not isinstance(lr_scheduler_class, tuple):
                lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

        def configure_optimizers(
            self: LightningModule
        ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]:
            optimizer = instantiate_class(self.parameters(), optimizer_init)
            if not lr_scheduler_init:
                return optimizer
            lr_scheduler = instantiate_class(optimizer, lr_scheduler_init)
            return [optimizer], [lr_scheduler]

        self.model.configure_optimizers = MethodType(configure_optimizers, self.model)
Example #10
    def _should_add_batch_output_to_epoch_output(self) -> bool:
        """
        We add to the epoch outputs if
        1. The model defines training_epoch_end OR
        2. The model overrides on_train_epoch_end which has `outputs` in the signature
        """
        # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
        lightning_module = self.trainer.lightning_module
        if is_overridden("training_epoch_end", lightning_module):
            return True

        if is_overridden("on_train_epoch_end", lightning_module):
            model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
            if is_param_in_hook_signature(model_hook_fx, "outputs"):
                return True

        return False
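The `is_param_in_hook_signature` check above can be approximated with `inspect.signature`; a rough stand-in, assuming we only care about explicitly named parameters (the real Lightning helper covers more cases):

    import inspect

    def param_in_signature(fn, name):
        # rough stand-in for is_param_in_hook_signature: is `name` an explicit parameter?
        return name in inspect.signature(fn).parameters

    class ToyModule:
        def on_train_epoch_end(self, outputs):  # old-style signature that still takes `outputs`
            ...

    print(param_in_signature(ToyModule.on_train_epoch_end, "outputs"))  # True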
Example #11
    def _disable_zero_grad(self) -> None:
        lightning_module = self.lightning_module
        if is_overridden("optimizer_zero_grad", lightning_module):
            assert lightning_module is not None  # `is_overridden` returns False otherwise
            rank_zero_warn(
                "You have overridden the `LightningModule.optimizer_zero_grad` hook but it will be ignored since"
                " IPUs handle the zeroing of gradients internally.")
        lightning_module.optimizer_zero_grad = None  # type: ignore[assignment]
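Stripped of the Lightning specifics, the snippet above follows a small reusable pattern: detect that the user overrode a hook, warn that it will be ignored, and then neutralise the hook because the strategy handles the work itself. A toy version with hypothetical classes:

    import warnings

    class BaseModule:
        def optimizer_zero_grad(self):
            pass

    class UserModule(BaseModule):
        def optimizer_zero_grad(self):  # override that the strategy must ignore
            print("custom zero_grad")

    def disable_zero_grad(module):
        if type(module).optimizer_zero_grad is not BaseModule.optimizer_zero_grad:
            warnings.warn("optimizer_zero_grad is overridden but will be ignored by this strategy")
        module.optimizer_zero_grad = None  # the strategy zeroes gradients itself

    disable_zero_grad(UserModule())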
Example #12
def _check_on_pretrain_routine(model: "pl.LightningModule") -> None:
    hooks = (("on_pretrain_routine_start", "on_fit_start"), ("on_pretrain_routine_end", "on_fit_start"))
    for hook, alternative_hook in hooks:
        if is_overridden(hook, model):
            rank_zero_deprecation(
                f"The `LightningModule.{hook}` hook was deprecated in v1.6 and"
                f" will be removed in v1.8. Please use `LightningModule.{alternative_hook}` instead."
            )
Example #13
    def __verify_eval_loop_configuration(self, model):
        stage = "val" if self.trainer.validating else "test"

        loader_name = f'{stage}_dataloader'
        step_name = f'{stage}_step'

        has_loader = is_overridden(loader_name, model)
        has_step = is_overridden(step_name, model)

        if has_loader and not has_step:
            rank_zero_warn(
                f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop'
            )
        if has_step and not has_loader:
            rank_zero_warn(
                f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop'
            )
Example #14
    def backward(self, model: "pl.LightningModule", closure_loss: Tensor,
                 *args: Any, **kwargs: Any) -> None:
        if is_overridden("backward", model):
            warning_cache.warn(
                "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
                " the backward logic internally.")
        deepspeed_engine: DeepSpeedEngine = model.trainer.model
        deepspeed_engine.backward(closure_loss, *args, **kwargs)
Example #15
def _check_on_keyboard_interrupt(trainer: "pl.Trainer") -> None:
    """Checks if on_keyboard_interrupt is overriden and sends a deprecation warning."""
    for callback in trainer.callbacks:
        if is_overridden(method_name="on_keyboard_interrupt",
                         instance=callback):
            rank_zero_deprecation(
                "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
                " Please use the `on_exception` callback hook instead.")
Example #16
def _check_setup_method(trainer: "pl.Trainer") -> None:
    for obj in [trainer.lightning_module, trainer.datamodule
                ] + trainer.callbacks:
        if is_overridden("setup", obj) and not is_param_in_hook_signature(
                obj.setup, "stage"):
            raise MisconfigurationException(
                f"`{obj.__class__.__name__}.setup` does not have a `stage` argument."
            )
Example #17
    def can_prepare_data(self):
        should_call_dm_prepare_data = True
        if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
            should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data

        if self.trainer.prepare_data_per_node:
            return self.trainer.local_rank == 0 and should_call_dm_prepare_data
        return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data
Example #18
    def make_petastorm_reader(model,
                              data_path,
                              dataloader_attr,
                              reader_worker_count,
                              reader_pool_type,
                              should_read=True):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        import horovod.torch as hvd

        is_loader_overridden = False
        if LooseVersion(pl.__version__) >= LooseVersion('1.0.0'):
            from pytorch_lightning.utilities.model_helpers import is_overridden
            is_loader_overridden = is_overridden(dataloader_attr, model)

        if not should_read or is_loader_overridden:
            yield
            return

        transform_spec = TransformSpec(
            transformation) if transformation else None

        # In general, make_batch_reader is faster than make_reader for reading the dataset.
        # However, we found out that make_reader performs data transformations much faster than
        # make_batch_reader with parallel worker processes. Therefore, the default reader
        # we choose is make_batch_reader unless there are data transformations.
        reader_factory_kwargs = dict()
        if transform_spec:
            reader_factory = make_reader
            reader_factory_kwargs['pyarrow_serialize'] = True
        else:
            reader_factory = make_batch_reader

        # Petastorm: read data from the store with the correct shard for this rank
        # setting num_epochs=None will cause an infinite iterator
        # and enables ranks to perform training and validation with
        # unequal number of samples
        with reader_factory(data_path,
                            num_epochs=1,
                            cur_shard=hvd.rank(),
                            shard_count=hvd.size(),
                            reader_pool_type=reader_pool_type,
                            workers_count=reader_worker_count,
                            hdfs_driver=PETASTORM_HDFS_DRIVER,
                            schema_fields=schema_fields,
                            transform_spec=transform_spec,
                            **reader_factory_kwargs) as reader:

            def dataloader_fn():
                return dataloader_cls(
                    reader,
                    batch_size=batch_size,
                    shuffling_queue_capacity=calculate_shuffle_buffer_size())

            try:
                setattr(model, dataloader_attr, dataloader_fn)
                yield
            finally:
                setattr(model, dataloader_attr, None)
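The generator above is meant to be used as a context manager that temporarily installs a dataloader factory on the model and clears it afterwards. The underlying pattern, in isolation (hypothetical names):

    from contextlib import contextmanager

    @contextmanager
    def temporarily_set(obj, attr_name, value):
        # install the attribute for the duration of the block, then clear it,
        # mirroring the setattr/finally/setattr sequence in make_petastorm_reader
        try:
            setattr(obj, attr_name, value)
            yield
        finally:
            setattr(obj, attr_name, None)

    class ToyModel:
        train_dataloader = None

    model = ToyModel()
    with temporarily_set(model, "train_dataloader", lambda: ["batch_0", "batch_1"]):
        print(model.train_dataloader())  # ['batch_0', 'batch_1']
    print(model.train_dataloader)        # None again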
Example #19
def __verify_eval_loop_configuration(trainer: "pl.Trainer",
                                     model: "pl.LightningModule",
                                     stage: str) -> None:
    loader_name = f"{stage}_dataloader"
    step_name = "validation_step" if stage == "val" else f"{stage}_step"
    trainer_method = "validate" if stage == "val" else stage
    on_eval_hook = f"on_{loader_name}"

    has_loader = getattr(trainer._data_connector,
                         f"_{stage}_dataloader_source").is_defined()
    has_step = is_overridden(step_name, model)
    has_on_eval_dataloader = is_overridden(on_eval_hook, model)

    # ----------------------------------------------
    # verify model does not have on_eval_dataloader
    # ----------------------------------------------
    if has_on_eval_dataloader:
        rank_zero_deprecation(
            f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
            f" be removed in v1.7.0. Please use `{loader_name}()` directly.")

    # -----------------------------------
    # verify model has an eval_dataloader
    # -----------------------------------
    if not has_loader:
        raise MisconfigurationException(
            f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`."
        )

    # predict_step is not required to be overridden
    if stage == "predict":
        if model.predict_step is None:
            raise MisconfigurationException(
                "`predict_step` cannot be None to run `Trainer.predict`")
        elif not has_step and not is_overridden("forward", model):
            raise MisconfigurationException(
                "`Trainer.predict` requires `forward` method to run.")
    else:
        # -----------------------------------
        # verify model has an eval_step
        # -----------------------------------
        if not has_step:
            raise MisconfigurationException(
                f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`."
            )
Example #20
    def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None:
        # -----------------------------------
        # verify model has a training step
        # -----------------------------------
        has_training_step = is_overridden("training_step", model)
        if not has_training_step:
            raise MisconfigurationException(
                "No `training_step()` method defined. Lightning `Trainer` expects as minimum a"
                " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
            )

        # -----------------------------------
        # verify model has a train dataloader
        # -----------------------------------
        has_train_dataloader = is_overridden("train_dataloader", model)
        if not has_train_dataloader:
            raise MisconfigurationException(
                "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
                " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
            )

        # -----------------------------------
        # verify model has optimizer
        # -----------------------------------
        has_optimizers = is_overridden("configure_optimizers", model)
        if not has_optimizers:
            raise MisconfigurationException(
                "No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a"
                " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
            )

        trainer = self.trainer

        trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
        trainer.overriden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model)
        automatic_optimization = model.automatic_optimization
        going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

        has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
        if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
            rank_zero_warn(
                "When using `Trainer(accumulate_grad_batches != 1)` and overriding"
                "`LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
                "(rather, they are called on every optimization step)."
            )
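For reference, a minimal LightningModule that satisfies all three checks above could look like the sketch below (a BoringModel-style module with arbitrary shapes and data; adjust to your dataset and PyTorch Lightning version):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class MinimalModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def train_dataloader(self):
            dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 2))
            return DataLoader(dataset, batch_size=8)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)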
Example #21
    def attach_datamodule(self, model,
                          datamodule: Optional[LightningDataModule]) -> None:

        # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
        datamodule = datamodule or getattr(model, 'datamodule', None)

        # If we have a datamodule, attach necessary hooks + dataloaders
        if datamodule:

            # Override loader hooks
            if is_overridden('train_dataloader', datamodule):
                model.train_dataloader = datamodule.train_dataloader
            if is_overridden('val_dataloader', datamodule):
                model.val_dataloader = datamodule.val_dataloader
            if is_overridden('test_dataloader', datamodule):
                model.test_dataloader = datamodule.test_dataloader
            if is_overridden('predict_dataloader', datamodule):
                model.predict_dataloader = datamodule.predict_dataloader

            # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
            if is_overridden('on_before_batch_transfer', datamodule):
                model.on_before_batch_transfer = datamodule.on_before_batch_transfer
            if is_overridden('transfer_batch_to_device', datamodule):
                model.transfer_batch_to_device = datamodule.transfer_batch_to_device
            if is_overridden('on_after_batch_transfer', datamodule):
                model.on_after_batch_transfer = datamodule.on_after_batch_transfer

            self.trainer.datamodule = datamodule
            datamodule.trainer = self.trainer
Example #22
    def __verify_train_loop_configuration(self, model: 'pl.LightningModule') -> None:
        # -----------------------------------
        # verify model has a training step
        # -----------------------------------
        has_training_step = is_overridden('training_step', model)
        if not has_training_step:
            raise MisconfigurationException(
                'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        # -----------------------------------
        # verify model has a train dataloader
        # -----------------------------------
        has_train_dataloader = is_overridden('train_dataloader', model)
        if not has_train_dataloader:
            raise MisconfigurationException(
                'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        # -----------------------------------
        # verify model has optimizer
        # -----------------------------------
        has_optimizers = is_overridden('configure_optimizers', model)
        if not has_optimizers:
            raise MisconfigurationException(
                'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        trainer = self.trainer

        trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
        trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)
        automatic_optimization = model.automatic_optimization
        going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

        has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
        if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
            raise MisconfigurationException(
                'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,'
                ' `accumulate_grad_batches` in `Trainer` should be 1.'
                ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
            )
Example #23
    def reset_predict_dataloader(self, model) -> None:
        """Resets the predict dataloader and determines the number of batches.

        Args:
            model: The current `LightningModule`
        """
        has_loader = is_overridden('predict_dataloader', model)
        if has_loader:
            self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict')
Example #24
    def __run_eval_epoch_end(self, num_dataloaders):
        model = self.trainer.lightning_module

        # with a single dataloader don't pass an array
        outputs = self.outputs

        # free memory
        self.outputs = []

        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.trainer.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True

        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()
        # deprecation warning
        if eval_results is not None and user_reduced:
            step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end'
            self.warning_cache.warn(
                f'The {step} should not return anything as of 9.1.'
                ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
            )

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        # track deprecated metrics
        self.trainer.logger_connector.track_metrics_deprecated(eval_results)

        return eval_results
Example #25
    def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        # unset dataloader_idx in model
        self.trainer.logger_connector.evaluation_epoch_end()

        # call the model epoch end
        model = self.trainer.lightning_module

        if self.trainer.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                model.test_epoch_end(outputs)

        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                model.validation_epoch_end(outputs)

        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()
Example #26
    def __verify_eval_loop_configuration(self, model, eval_loop_name):
        step_name = f'{eval_loop_name}_step'

        # map the dataloader name
        loader_name = f'{eval_loop_name}_dataloader'
        if eval_loop_name == 'validation':
            loader_name = 'val_dataloader'

        has_loader = is_overridden(loader_name, model)
        has_step = is_overridden(step_name, model)

        if has_loader and not has_step:
            rank_zero_warn(
                f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop'
            )
        if has_step and not has_loader:
            rank_zero_warn(
                f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop'
            )
Example #27
    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
        """Resets the predict dataloader and determines the number of batches.

        Args:
            model: The `LightningModule` if called outside of the trainer scope.
        """
        pl_module = self.lightning_module or model
        has_loader = is_overridden("predict_dataloader", pl_module)
        if has_loader:
            self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader("predict", model=pl_module)
Example #28
    def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        # inform logger the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()

        # call the model epoch end
        model = self.trainer.lightning_module

        # unset dataloader_idx in model
        model._current_dataloader_idx = None

        if self.trainer.testing:
            if is_overridden('test_epoch_end', model):
                model._current_fx_name = 'test_epoch_end'
                model.test_epoch_end(outputs)

        else:
            if is_overridden('validation_epoch_end', model):
                model._current_fx_name = 'validation_epoch_end'
                model.validation_epoch_end(outputs)
Example #29
    def _disable_zero_grad(self) -> None:
        lightning_module = self.lightning_module
        if is_overridden("optimizer_zero_grad", lightning_module):
            assert lightning_module is not None  # `is_overridden` returns False otherwise
            rank_zero_warn(
                "You have overridden `optimizer_zero_grad` which will be disabled."
                " When `HivemindStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad,"
                " as this would delete the gradients before they are averaged."
            )
        assert lightning_module is not None
        lightning_module.optimizer_zero_grad = None  # type: ignore[assignment]
Example #30
    def _validate_data_hooks(self, model):
        # Raise Misconfiguration exception since these hooks are not supported in DP mode
        # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
        batch_transfer_hooks = ('on_before_batch_transfer',
                                'transfer_batch_to_device',
                                'on_after_batch_transfer')
        for hook in batch_transfer_hooks:
            if self.trainer.accelerator_connector.use_dp and is_overridden(
                    hook, model):
                raise MisconfigurationException(
                    f'Overriding `{hook}` is not supported in DP mode.')