Example #1
0
def test_controll_flow_callback_filter_fn_ignore_epochs():
    """Check that ``ignore_epochs`` disables the callback on listed epochs.

    The wrapper is driven through epochs 1..10 for two loaders. After each
    ``on_loader_start`` the wrapper's ``_is_enabled`` flag is recorded and
    compared against one per-epoch mask shared by every loader.
    """
    skipped_epochs = (3, 4, 6, 8)
    num_epochs = 10
    wraped = ControlFlowCallbackWrapper(DummyCallback(),
                                        ignore_epochs=list(skipped_epochs))
    # Enabled exactly when the epoch is not in the ignore list.
    enabled_mask = [
        epoch not in skipped_epochs for epoch in range(1, num_epochs + 1)
    ]
    loader_names = ("train", "valid")
    expected = dict.fromkeys(loader_names, enabled_mask)
    observed = {name: [] for name in loader_names}
    for epoch in range(1, num_epochs + 1):
        for name in loader_names:
            wraped.on_loader_start(_Runner(name, epoch))
            observed[name].append(wraped._is_enabled)
    assert observed == expected
Example #2
0
    def test_ignore_foo_with_wrong_args(self):
        """Invalid ``filter_fn`` values must be rejected with ``ValueError``.

        Covers a non-callable value and lambdas whose signature does not
        match the two-parameter form accepted elsewhere in these tests
        (``(epoch, loader)``). Every callback order is checked, instead of
        one picked with ``random.choice``, so the test is deterministic.
        """
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Scheduler,
            CallbackOrder.External,
        )
        bad_filter_fns = (
            12345,  # not callable at all
            lambda arg: True,  # one parameter instead of two
            lambda *args: True,  # only variadic positional parameters
            lambda one, two, three, four: True,  # four parameters
            lambda *args, **kwargs: True,  # fully variadic signature
        )

        for order in orders:
            callback = RaiserCallback(order, "on_epoch_start")
            for filter_fn in bad_filter_fns:
                with self.assertRaises(ValueError):
                    ControlFlowCallbackWrapper(callback, filter_fn=filter_fn)
Example #3
0
def test_control_flow_callback_filter_fn_loaders():
    """With ``loaders=["valid"]`` only the "valid" loader enables the callback.

    Loader names that merely resemble "valid" (such as "like_valid") are
    expected to stay disabled on every one of the five epochs.
    """
    n_epochs = 5
    wraped = ControlFlowCallbackWrapper(DummyCallback(), loaders=["valid"])
    loader_names = ("train", "valid", "another_loader", "like_valid")
    expected = {
        name: [name == "valid"] * n_epochs for name in loader_names
    }
    actual = {name: [] for name in loader_names}
    for epoch in range(1, n_epochs + 1):
        for name in loader_names:
            wraped.on_loader_start(_Runner(name, epoch))
            actual[name].append(wraped._is_enabled)
    assert actual == expected
Example #4
0
def test_controll_flow_callback_filter_fn_periodical_ignore_epochs():
    """An integer ``ignore_epochs=4`` disables the callback periodically.

    Epochs divisible by 4 are expected to be disabled for every loader, and
    all remaining epochs in 1..10 stay enabled.
    """
    period = 4
    total_epochs = 10
    wraped = ControlFlowCallbackWrapper(DummyCallback(), ignore_epochs=period)
    enabled_mask = [
        epoch % period != 0 for epoch in range(1, total_epochs + 1)
    ]
    loader_names = ("train", "valid", "another_loader", "like_valid")
    expected = dict.fromkeys(loader_names, enabled_mask)
    actual = {name: [] for name in loader_names}
    for epoch in range(1, total_epochs + 1):
        for name in loader_names:
            wraped.on_loader_start(_Runner(name, epoch))
            actual[name].append(wraped._is_enabled)
    assert actual == expected
Example #5
0
    def test_filter_fn_with_err_in_eval(self):
        """A ``filter_fn`` string that is not a complete lambda is rejected.

        ``"lambda e, l"`` has no body. Constructing the wrapper with it must
        raise ``ValueError`` for every combination of event and order.
        """
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Scheduler,
            CallbackOrder.External,
        )
        events = (
            "on_loader_start",
            "on_loader_end",
            "on_experiment_start",
            "on_experiment_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end",
            "on_exception",
        )
        broken_source = "lambda e, l"
        for event_name in events:
            for cb_order in orders:
                raiser = RaiserCallback(cb_order, event_name)
                with self.assertRaises(ValueError):
                    ControlFlowCallbackWrapper(raiser,
                                               filter_fn=broken_source)
Example #6
0
    def test_ignore_epochs_with_wrong_args(self):
        """Invalid ``ignore_epochs`` values must raise ``ValueError``.

        Every callback order is checked, instead of one picked with
        ``random.choice``, so the test is deterministic and reproducible.
        """
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Scheduler,
            CallbackOrder.External,
        )

        for order in orders:
            callback = RaiserCallback(order, "on_epoch_start")

            # NOTE(review): with every other control argument unset, passing
            # ``None`` presumably hits the "no control argument given" path
            # rather than a type check on ``ignore_epochs`` — confirm intent.
            with self.assertRaises(ValueError):
                ControlFlowCallbackWrapper(callback, ignore_epochs=None)

            with self.assertRaises(ValueError):
                ControlFlowCallbackWrapper(callback, ignore_epochs="123456")
Example #7
0
    def test_ignore_loaders_with_wrong_args(self):
        """Invalid ``ignore_loaders`` values must raise ``ValueError``.

        Covers a float, an int and a mapping whose per-loader epoch values
        are strings. Every callback order is checked, instead of one picked
        with ``random.choice``, so the test is deterministic.
        """
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Scheduler,
            CallbackOrder.External,
        )
        bad_ignore_loaders = (
            1234.56,  # float is not a loader name / collection of names
            # int: a valid type for epochs but not for loaders; this replaces
            # a verbatim duplicate of the float case above.
            1234,
            {"train": ["", "fjdskjfdk", "1234"]},  # epochs given as strings
        )

        for order in orders:
            callback = RaiserCallback(order, "on_epoch_start")
            for ignore_loaders in bad_ignore_loaders:
                with self.assertRaises(ValueError):
                    ControlFlowCallbackWrapper(callback,
                                               ignore_loaders=ignore_loaders)
Example #8
0
 def test_with_missing_args(self):
     """Wrapping a callback without any control argument raises ValueError.

     The check is repeated for each callback order.
     """
     all_orders = [
         CallbackOrder.Internal,
         CallbackOrder.Metric,
         CallbackOrder.MetricAggregation,
         CallbackOrder.Optimizer,
         CallbackOrder.Scheduler,
         CallbackOrder.External,
     ]
     for current_order in all_orders:
         wrapped_target = RaiserCallback(current_order, "on_epoch_start")
         with self.assertRaises(ValueError):
             ControlFlowCallbackWrapper(wrapped_target)
Example #9
0
    def test_filter_fn_with_eval(self):
        """String ``filter_fn`` lambdas gate whether the wrapped callback runs.

        With ``"lambda e, l: False"`` the wrapped ``RaiserCallback`` must not
        fire (no ``Dummy`` raised). With ``"lambda e, l: True"`` it must fire
        and raise ``Dummy``. This is checked for every event and order.
        """
        runner = Mock(loader_key="train", epoch=1)
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Scheduler,
            CallbackOrder.External,
        )

        # ``on_loader_start`` is the call that updates the wrapper's enabled
        # state, so it is exercised on its own without a priming call.
        for order in orders:
            callback = RaiserCallback(order, "on_loader_start")
            wrapper = ControlFlowCallbackWrapper(
                callback, filter_fn="lambda e, l: False")

            wrapper.on_loader_start(runner)

            callback = RaiserCallback(order, "on_loader_start")
            wrapper = ControlFlowCallbackWrapper(callback,
                                                 filter_fn="lambda e, l: True")

            with self.assertRaises(Dummy):
                wrapper.on_loader_start(runner)

        events = (
            "on_loader_end",
            "on_experiment_start",
            "on_experiment_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end",
            "on_exception",
        )
        for event in events:
            for order in orders:
                callback = RaiserCallback(order, event)
                wrapper = ControlFlowCallbackWrapper(
                    callback, filter_fn="lambda e, l: False")

                # Prime the enabled state before dispatching the event.
                wrapper.on_loader_start(runner)
                getattr(wrapper, event)(runner)

                callback = RaiserCallback(order, event)
                wrapper = ControlFlowCallbackWrapper(
                    callback, filter_fn="lambda e, l: True")

                wrapper.on_loader_start(runner)
                with self.assertRaises(Dummy):
                    getattr(wrapper, event)(runner)
Example #10
0
    def test_filter_fn_with_wrong_args(self):
        """Callable ``filter_fn`` functions gate whether the callback runs.

        Despite the name, both filter functions here take the accepted
        ``(epoch, loader)`` arguments. ``_ignore_foo`` (always False) must
        suppress the wrapped ``RaiserCallback``, while ``_raise_foo`` (always
        True) must let it fire and raise ``Dummy``. This is checked for every
        event and order.
        """
        runner = Mock(loader_key="train", epoch=1)
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Scheduler,
            CallbackOrder.External,
        )

        def _ignore_foo(epoch: int, loader: str) -> bool:
            return False

        def _raise_foo(epoch: int, loader: str) -> bool:
            return True

        # ``on_loader_start`` is the call that updates the wrapper's enabled
        # state, so it is exercised on its own without a priming call.
        for order in orders:
            callback = RaiserCallback(order, "on_loader_start")
            wrapper = ControlFlowCallbackWrapper(callback,
                                                 filter_fn=_ignore_foo)

            wrapper.on_loader_start(runner)

            callback = RaiserCallback(order, "on_loader_start")
            wrapper = ControlFlowCallbackWrapper(callback,
                                                 filter_fn=_raise_foo)

            with self.assertRaises(Dummy):
                wrapper.on_loader_start(runner)

        events = (
            "on_loader_end",
            "on_experiment_start",
            "on_experiment_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end",
            "on_exception",
        )
        for event in events:
            for order in orders:
                callback = RaiserCallback(order, event)
                wrapper = ControlFlowCallbackWrapper(callback,
                                                     filter_fn=_ignore_foo)

                # Prime the enabled state before dispatching the event.
                wrapper.on_loader_start(runner)
                getattr(wrapper, event)(runner)

                callback = RaiserCallback(order, event)
                wrapper = ControlFlowCallbackWrapper(callback,
                                                     filter_fn=_raise_foo)

                wrapper.on_loader_start(runner)
                with self.assertRaises(Dummy):
                    getattr(wrapper, event)(runner)