def __init__(self, not_supported):
        """Wrap every public ``LightningModule`` hook so that it also calls ``self.log``.

        After the wrapped hook runs, ``self.log`` is invoked. For hooks listed in
        ``not_supported`` the log call must raise ``MisconfigurationException``.

        Args:
            not_supported: mapping from hook name to the ``match`` pattern expected in
                the ``MisconfigurationException`` raised when logging from that hook.
        """
        super().__init__()
        # ``log``/``log_dict`` are the logging API itself. The batch-transfer hooks only
        # sometimes have ``self._current_fx_name`` defined, depending on where they were
        # called from, so ``self.log`` cannot be exercised reliably inside them.
        skipped = {
            "log",
            "log_dict",
            "on_before_batch_transfer",
            "transfer_batch_to_device",
            "on_after_batch_transfer",
        }
        # methods inherited from plain ``nn.Module`` are not Lightning hooks
        hook_names = get_members(LightningModule) - skipped - get_members(torch.nn.Module)

        def logging_wrapper(hook, fn, *args, **kwargs):
            # run the real hook first, then probe logging from inside it
            result = fn(*args, **kwargs)
            if hook not in not_supported:
                self.log(hook, 1)
                return result
            with pytest.raises(MisconfigurationException, match=not_supported[hook]):
                self.log("anything", 1)
            return result

        for name in hook_names:
            setattr(self, name, partial(logging_wrapper, name, getattr(self, name)))
Esempio n. 2
0
def test_lambda_call(tmpdir):
    """Every ``Callback`` hook passed to ``LambdaCallback`` must be invoked.

    A first short ``fit`` produces a checkpoint. A second ``Trainer`` resumes from it
    (the model raises ``KeyboardInterrupt`` once ``current_epoch`` exceeds 1), then
    runs ``test`` and ``predict``. Every one of these calls is expected to emit the
    ``on_keyboard_interrupt`` deprecation warning. Finally, the recorded hook names
    must equal the full set of ``Callback`` hooks.
    """
    seed_everything(42)

    class CustomModel(BoringModel):
        def on_train_epoch_start(self):
            # abort training once the epoch counter goes past 1
            if self.current_epoch > 1:
                raise KeyboardInterrupt

    checker = set()

    def record(hook, *_, **__):
        checker.add(hook)

    hooks = get_members(Callback)
    hooks_args = {name: partial(record, name) for name in hooks}
    # unlike the other hooks, this one returns a (non-None) value
    hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")]

    deprecation_msg = "on_keyboard_interrupt` callback hook was deprecated in v1.5"

    model = CustomModel()

    # successful run
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[LambdaCallback(**hooks_args)],
    )
    with pytest.deprecated_call(match=deprecation_msg):
        trainer.fit(model)

    ckpt_path = trainer.checkpoint_callback.best_model_path

    # raises KeyboardInterrupt and loads from checkpoint
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        callbacks=[LambdaCallback(**hooks_args)],
    )
    stages = (
        lambda: trainer.fit(model, ckpt_path=ckpt_path),
        lambda: trainer.test(model),
        lambda: trainer.predict(model),
    )
    for run_stage in stages:
        with pytest.deprecated_call(match=deprecation_msg):
            run_stage()

    assert checker == hooks
    def __init__(self, not_supported):
        """Attach a logging probe to every public ``Callback`` hook.

        Each probe calls ``log`` on the trainer's ``LightningModule`` (or on the
        ``model`` argument passed to the hook). For hooks listed in ``not_supported``
        that call must raise ``MisconfigurationException``.

        Args:
            not_supported: mapping from hook name to the ``match`` pattern expected in
                the ``MisconfigurationException`` raised when logging from that hook.
        """

        def probe(hook, trainer, model=None, *_, **__):
            module = trainer.lightning_module or model
            if module is None:
                # `on_init_{start,end}` do not have a `LightningModule` available
                assert hook in ("on_init_start", "on_init_end")
                return
            if hook not in not_supported:
                module.log(hook, 1)
                return
            with pytest.raises(MisconfigurationException, match=not_supported[hook]):
                module.log("anything", 1)

        for name in get_members(Callback):
            setattr(self, name, partial(probe, name))
def test_fx_validator():
    """Cross-check ``_FxValidator`` against every ``Callback`` hook.

    The expected logging permission of each hook is derived from its name.

    - Hooks that may log must pass ``check_logging`` and accept the default logging
      levels. Stage hooks that are neither "start" nor batch hooks must also reject
      ``on_step=True``.
    - Every other hook must be rejected by ``check_logging``.
    - ``not_supported`` must list exactly the rejected hooks.
    """
    funcs_name = get_members(Callback)

    callbacks_func = {
        "on_before_backward",
        "on_after_backward",
        "on_before_optimizer_step",
        "on_batch_end",
        "on_batch_start",
        "on_before_accelerator_backend_setup",
        "on_before_zero_grad",
        "on_epoch_end",
        "on_epoch_start",
        "on_fit_end",
        "on_configure_sharded_model",
        "on_fit_start",
        "on_init_end",
        "on_init_start",
        "on_keyboard_interrupt",
        "on_exception",
        "on_load_checkpoint",
        "load_state_dict",
        "on_pretrain_routine_end",
        "on_pretrain_routine_start",
        "on_sanity_check_end",
        "on_sanity_check_start",
        "state_dict",
        "on_save_checkpoint",
        "on_test_batch_end",
        "on_test_batch_start",
        "on_test_end",
        "on_test_epoch_end",
        "on_test_epoch_start",
        "on_test_start",
        "on_train_batch_end",
        "on_train_batch_start",
        "on_train_end",
        "on_train_epoch_end",
        "on_train_epoch_start",
        "on_train_start",
        "on_validation_batch_end",
        "on_validation_batch_start",
        "on_validation_end",
        "on_validation_epoch_end",
        "on_validation_epoch_start",
        "on_validation_start",
        "on_predict_batch_end",
        "on_predict_batch_start",
        "on_predict_end",
        "on_predict_epoch_end",
        "on_predict_epoch_start",
        "on_predict_start",
        "setup",
        "teardown",
    }

    not_supported = {
        "on_before_accelerator_backend_setup",
        "on_fit_end",
        "on_fit_start",
        "on_configure_sharded_model",
        "on_init_end",
        "on_init_start",
        "on_keyboard_interrupt",
        "on_exception",
        "on_load_checkpoint",
        "load_state_dict",
        "on_pretrain_routine_end",
        "on_pretrain_routine_start",
        "on_sanity_check_end",
        "on_sanity_check_start",
        "on_predict_batch_end",
        "on_predict_batch_start",
        "on_predict_end",
        "on_predict_epoch_end",
        "on_predict_epoch_start",
        "on_predict_start",
        "state_dict",
        "on_save_checkpoint",
        "on_test_end",
        "on_train_end",
        "on_validation_end",
        "setup",
        "teardown",
    }

    # Detected new callback function. Need to add its logging permission to FxValidator and update this test
    assert funcs_name == callbacks_func

    validator = _FxValidator()
    # hooks the name-based rule below classifies as "cannot log"
    rejected = set()

    for func_name in funcs_name:
        # This summarizes where and what is currently possible to log using `self.log`
        is_stage = "train" in func_name or "test" in func_name or "validation" in func_name
        # "start" and batch hooks are exempt from the `on_step=True` rejection below
        is_start = "start" in func_name or "batch" in func_name
        is_epoch = "epoch" in func_name
        on_step = is_stage and not is_start and not is_epoch
        on_epoch = True
        # creating allowed condition
        allowed = (
            is_stage
            or "batch" in func_name
            or "epoch" in func_name
            or "grad" in func_name
            or "backward" in func_name
            or "optimizer_step" in func_name
        )
        # "pretrain" contains "train", so it must be excluded explicitly
        allowed = (
            allowed
            and "pretrain" not in func_name
            and "predict" not in func_name
            and func_name not in ["on_train_end", "on_test_end", "on_validation_end"]
        )
        if allowed:
            # an allowed hook must also pass the generic "may this hook log?" check
            validator.check_logging(fx_name=func_name)
            validator.check_logging_levels(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
            if not is_start and is_stage:
                with pytest.raises(MisconfigurationException, match="must be one of"):
                    validator.check_logging_levels(fx_name=func_name, on_step=True, on_epoch=on_epoch)
        else:
            rejected.add(func_name)
            assert func_name in not_supported
            with pytest.raises(MisconfigurationException, match="You can't"):
                validator.check_logging(fx_name=func_name)

    # `not_supported` must list exactly the rejected hooks. The membership assert above
    # alone would not catch a stale entry for a hook that is actually allowed.
    assert rejected == not_supported

    with pytest.raises(RuntimeError, match="Logging inside `foo` is not implemented"):
        validator.check_logging("foo")