コード例 #1
0
    def test_ignore_foo_with_wrong_args(self):
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Validation,
            CallbackOrder.Scheduler,
            CallbackOrder.Logging,
            CallbackOrder.External,
        )
        order = random.choice(orders)

        callback = RaiserCallback(order, "on_epoch_start")

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback, filter_fn=12345)

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback, filter_fn=lambda arg: True)

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback, filter_fn=lambda *args: True)

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback,
                                filter_fn=lambda one, two, three, four: True)

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback,
                                filter_fn=lambda *args, **kwargs: True)
コード例 #2
0
def test_controll_flow_callback_filter_fn_epochs():
    wraped = ControlFlowCallback(DummyCallback(), epochs=[3, 4, 6])
    mask = [
        False,
        False,
        True,
        True,
        False,
        True,
        False,
        False,
        False,
        False,
    ]
    expected = {
        "train": mask,
        "valid": mask,
    }
    actual = {loader: [] for loader in expected.keys()}
    for epoch in range(1, 10 + 1):
        for loader in expected.keys():
            runner = _Runner("stage", loader, epoch, epoch)
            wraped.on_loader_start(runner)
            actual[loader].append(wraped._is_enabled)
    assert actual == expected
コード例 #3
0
def test_controll_flow_callback_filter_fn_global_ignore_epochs():
    wraped = ControlFlowCallback(DummyCallback(),
                                 ignore_epochs=[3, 4, 7, 10],
                                 use_global_epochs=True)
    mask = [
        True,
        True,
        False,
        False,
        True,
        True,
        False,
        True,
        True,
        False,
    ]
    expected = {
        "train": mask,
        "valid": mask,
    }
    actual = {loader: [] for loader in expected.keys()}
    for stage_num, stage in enumerate(["stage1", "stage2"]):
        for epoch in range(1, 5 + 1):
            for loader in expected.keys():
                runner = _Runner(stage, loader, epoch + stage_num * 5, epoch)
                wraped.on_loader_start(runner)
                actual[loader].append(wraped._is_enabled)
    assert actual == expected
コード例 #4
0
    def test_ignore_loaders_with_wrong_args(self):
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Validation,
            CallbackOrder.Scheduler,
            CallbackOrder.Logging,
            CallbackOrder.External,
        )
        order = random.choice(orders)

        callback = RaiserCallback(order, "on_epoch_start")

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback, ignore_loaders=1234.56)

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback, ignore_loaders=1234.56)

        with self.assertRaises(ValueError):
            ControlFlowCallback(
                callback, ignore_loaders={"train": ["", "fjdskjfdk", "1234"]}
            )
コード例 #5
0
def test_control_flow_callback_filter_fn_loaders():
    wraped = ControlFlowCallback(DummyCallback(), loaders=["valid"])
    expected = {
        "train": [False] * 5,
        "valid": [True] * 5,
        "another_loader": [False] * 5,
        "like_valid": [False] * 5,
    }
    actual = {loader: [] for loader in expected.keys()}
    for epoch in range(1, 5 + 1):
        for loader in expected.keys():
            runner = _Runner("stage", loader, epoch, epoch)
            wraped.on_loader_start(runner)
            actual[loader].append(wraped._is_enabled)
    assert actual == expected
コード例 #6
0
def test_controll_flow_callback_filter_fn_periodical_ignore_epochs():
    wraped = ControlFlowCallback(DummyCallback(), ignore_epochs=4)
    mask = [i % 4 != 0 for i in range(1, 10 + 1)]
    expected = {
        "train": mask,
        "valid": mask,
        "another_loader": mask,
        "like_valid": mask,
    }
    actual = {loader: [] for loader in expected.keys()}
    for epoch in range(1, 10 + 1):
        for loader in expected.keys():
            runner = _Runner("stage", loader, epoch, epoch)
            wraped.on_loader_start(runner)
            actual[loader].append(wraped._is_enabled)
    assert actual == expected
コード例 #7
0
    def test_filter_fn_with_err_in_eval(self):
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Validation,
            CallbackOrder.Scheduler,
            CallbackOrder.Logging,
            CallbackOrder.External,
        )

        events = (
            "on_loader_start",
            "on_loader_end",
            "on_stage_start",
            "on_stage_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end",
            "on_exception",
        )
        for event in events:
            for order in orders:
                callback = RaiserCallback(order, event)
                with self.assertRaises(ValueError):
                    ControlFlowCallback(callback, filter_fn="lambda s, e, l")
コード例 #8
0
    def test_ignore_epochs_with_wrong_args(self):
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Scheduler,
            CallbackOrder.External,
        )
        order = random.choice(orders)

        callback = RaiserCallback(order, "on_epoch_start")

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback, ignore_epochs=None)

        with self.assertRaises(ValueError):
            ControlFlowCallback(callback, ignore_epochs="123456")
コード例 #9
0
 def test_with_missing_args(self):
     orders = (
         CallbackOrder.Internal,
         CallbackOrder.Metric,
         CallbackOrder.MetricAggregation,
         CallbackOrder.Optimizer,
         CallbackOrder.Scheduler,
         CallbackOrder.External,
     )
     for order in orders:
         callback = RaiserCallback(order, "on_epoch_start")
         with self.assertRaises(ValueError):
             ControlFlowCallback(callback)
コード例 #10
0
    def test_filter_fn_with_eval(self):
        runner = Mock(stage="stage1", loader_key="train", epoch=1)
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Validation,
            CallbackOrder.Scheduler,
            CallbackOrder.Logging,
            CallbackOrder.External,
        )

        for order in orders:
            callback = RaiserCallback(order, "on_loader_start")
            wrapper = ControlFlowCallback(callback,
                                          filter_fn="lambda s, e, l: False")

            wrapper.on_loader_start(runner)

            callback = RaiserCallback(order, "on_loader_start")
            wrapper = ControlFlowCallback(callback,
                                          filter_fn="lambda s, e, l: True")

            with self.assertRaises(Dummy):
                wrapper.on_loader_start(runner)

        events = (
            "on_loader_end",
            "on_stage_start",
            "on_stage_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end",
            "on_exception",
        )
        for event in events:
            for order in orders:
                callback = RaiserCallback(order, event)
                wrapper = ControlFlowCallback(
                    callback, filter_fn="lambda s, e, l: False")

                wrapper.on_loader_start(runner)
                wrapper.__getattribute__(event)(runner)

                callback = RaiserCallback(order, event)
                wrapper = ControlFlowCallback(callback,
                                              filter_fn="lambda s, e, l: True")

                wrapper.on_loader_start(runner)
                with self.assertRaises(Dummy):
                    wrapper.__getattribute__(event)(runner)
コード例 #11
0
    def test_filter_fn_with_wrong_args(self):
        runner = Mock(stage="stage1", loader_key="train", epoch=1)
        orders = (
            CallbackOrder.Internal,
            CallbackOrder.Metric,
            CallbackOrder.MetricAggregation,
            CallbackOrder.Optimizer,
            CallbackOrder.Validation,
            CallbackOrder.Scheduler,
            CallbackOrder.Logging,
            CallbackOrder.External,
        )

        def _ignore_foo(stage: str, epoch: int, loader: str) -> bool:
            return False

        def _raise_foo(stage: str, epoch: int, loader: str) -> bool:
            return True

        for order in orders:
            callback = RaiserCallback(order, "on_loader_start")
            wrapper = ControlFlowCallback(callback, filter_fn=_ignore_foo)

            wrapper.on_loader_start(runner)

            callback = RaiserCallback(order, "on_loader_start")
            wrapper = ControlFlowCallback(callback, filter_fn=_raise_foo)

            with self.assertRaises(Dummy):
                wrapper.on_loader_start(runner)

        events = (
            "on_loader_end",
            "on_stage_start",
            "on_stage_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end",
            "on_exception",
        )
        for event in events:
            for order in orders:
                callback = RaiserCallback(order, event)
                wrapper = ControlFlowCallback(callback, filter_fn=_ignore_foo)

                wrapper.on_loader_start(runner)
                wrapper.__getattribute__(event)(runner)

                callback = RaiserCallback(order, event)
                wrapper = ControlFlowCallback(callback, filter_fn=_raise_foo)

                wrapper.on_loader_start(runner)
                with self.assertRaises(Dummy):
                    wrapper.__getattribute__(event)(runner)