def test_filter_fn_with_eval(self):
        """Check that a string ``filter_fn`` gates the wrapped callback.

        ``filter_fn`` is given as lambda source text. For every callback
        order and every event hook, a filter returning ``False`` must not let
        the wrapped ``RaiserCallback`` fire (no ``Dummy`` raised), while a
        filter returning ``True`` must let it fire (``Dummy`` raised).
        Each (event, order) pair runs in its own ``subTest`` so a failure
        reports exactly which combination broke and the remaining
        combinations are still checked.
        """
        # NOTE(review): the sibling test mocks ``stage``/``loader_key``
        # instead of ``stage_name``/``loader_name``. Unknown Mock attributes
        # resolve to child Mocks and both filters ignore their arguments, so
        # the outcome does not depend on which names ControlFlowCallback
        # reads. Confirm the intended attribute names.
        runner = Mock(stage_name="stage1", loader_name="train", epoch=1)
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Validation,
            CallbackOrder.Scheduler,
            CallbackOrder.Logging,
            CallbackOrder.External,
        )
        ignore_filter = "lambda s, e, l: False"
        raise_filter = "lambda s, e, l: True"

        # on_loader_start is checked separately. The loop below calls
        # on_loader_start before each event, and that call would itself hit
        # the raiser when the raising event *is* on_loader_start.
        for order in orders:
            with self.subTest(event="on_loader_start", order=order):
                callback = RaiserCallback(order, "on_loader_start")
                wrapper = ControlFlowCallback(
                    callback, filter_fn=ignore_filter
                )
                wrapper.on_loader_start(runner)

                callback = RaiserCallback(order, "on_loader_start")
                wrapper = ControlFlowCallback(
                    callback, filter_fn=raise_filter
                )
                with self.assertRaises(Dummy):
                    wrapper.on_loader_start(runner)

        events = (
            "on_loader_end",
            "on_stage_start",
            "on_stage_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end",
            "on_exception",
        )
        for event in events:
            for order in orders:
                with self.subTest(event=event, order=order):
                    callback = RaiserCallback(order, event)
                    wrapper = ControlFlowCallback(
                        callback, filter_fn=ignore_filter
                    )
                    # Prime the wrapper with on_loader_start first, as the
                    # original test did. Presumably the filter decision is
                    # taken there; confirm against ControlFlowCallback.
                    wrapper.on_loader_start(runner)
                    getattr(wrapper, event)(runner)

                    callback = RaiserCallback(order, event)
                    wrapper = ControlFlowCallback(
                        callback, filter_fn=raise_filter
                    )
                    wrapper.on_loader_start(runner)
                    with self.assertRaises(Dummy):
                        getattr(wrapper, event)(runner)
    # Example #2
    def test_filter_fn_with_wrong_args(self):
        """Check that a callable ``filter_fn`` gates the wrapped callback.

        For every callback order and every event hook, a filter function
        returning ``False`` must not let the wrapped ``RaiserCallback`` fire
        (no ``Dummy`` raised), while one returning ``True`` must let it fire
        (``Dummy`` raised). Each (event, order) pair runs in its own
        ``subTest`` for precise failure reports.

        NOTE(review): despite the method name, both filter functions below
        take the regular ``(stage, epoch, loader)`` signature, so no
        wrong-argument case is exercised here. Confirm whether the name or
        the body reflects the intent.
        """
        # NOTE(review): the sibling test mocks ``stage_name``/``loader_name``.
        # Unknown Mock attributes resolve to child Mocks and both filters
        # ignore their arguments, so the outcome does not depend on the names.
        runner = Mock(stage="stage1", loader_key="train", epoch=1)
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Validation,
            CallbackOrder.Scheduler,
            CallbackOrder.Logging,
            CallbackOrder.External,
        )

        def _ignore_foo(stage: str, epoch: int, loader: str) -> bool:
            """Filter that never lets the wrapped callback run."""
            return False

        def _raise_foo(stage: str, epoch: int, loader: str) -> bool:
            """Filter that always lets the wrapped callback run."""
            return True

        # on_loader_start is checked separately. The loop below calls
        # on_loader_start before each event, and that call would itself hit
        # the raiser when the raising event *is* on_loader_start.
        for order in orders:
            with self.subTest(event="on_loader_start", order=order):
                callback = RaiserCallback(order, "on_loader_start")
                wrapper = ControlFlowCallback(callback, filter_fn=_ignore_foo)
                wrapper.on_loader_start(runner)

                callback = RaiserCallback(order, "on_loader_start")
                wrapper = ControlFlowCallback(callback, filter_fn=_raise_foo)
                with self.assertRaises(Dummy):
                    wrapper.on_loader_start(runner)

        events = (
            "on_loader_end",
            "on_stage_start",
            "on_stage_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end",
            "on_exception",
        )
        for event in events:
            for order in orders:
                with self.subTest(event=event, order=order):
                    callback = RaiserCallback(order, event)
                    wrapper = ControlFlowCallback(
                        callback, filter_fn=_ignore_foo
                    )
                    # Prime the wrapper with on_loader_start first, as the
                    # original test did. Presumably the filter decision is
                    # taken there; confirm against ControlFlowCallback.
                    wrapper.on_loader_start(runner)
                    getattr(wrapper, event)(runner)

                    callback = RaiserCallback(order, event)
                    wrapper = ControlFlowCallback(
                        callback, filter_fn=_raise_foo
                    )
                    wrapper.on_loader_start(runner)
                    with self.assertRaises(Dummy):
                        getattr(wrapper, event)(runner)