def test_filter_fn_with_eval(self):
    """Check ``ControlFlowCallback`` with ``filter_fn`` given as a source string.

    The filter is passed as lambda source text taking ``(s, e, l)``.
    For every callback order and every event:

    * with a filter returning ``False``, firing the event on the wrapper must
      complete without raising;
    * with a filter returning ``True``, firing the event must raise ``Dummy``
      from the wrapped ``RaiserCallback``.

    ``on_loader_start`` is checked on its own first. Every other event is fired
    after ``on_loader_start`` has been called on the same wrapper.
    """
    # The filters below ignore their arguments, so these attribute values
    # should not affect the outcome. NOTE(review): names differ from the
    # sibling test (stage/loader_key); confirm which the runner exposes.
    runner = Mock(stage_name="stage1", loader_name="train", epoch=1)
    orders = (
        CallbackOrder.Internal,
        CallbackOrder.Metric,
        CallbackOrder.MetricAggregation,
        CallbackOrder.Optimizer,
        CallbackOrder.Validation,
        CallbackOrder.Scheduler,
        CallbackOrder.Logging,
        CallbackOrder.External,
    )
    events = (
        "on_loader_end",
        "on_stage_start",
        "on_stage_end",
        "on_epoch_start",
        "on_epoch_end",
        "on_batch_start",
        "on_batch_end",
        "on_exception",
    )
    ignore_fn = "lambda s, e, l: False"
    raise_fn = "lambda s, e, l: True"

    for order in orders:
        with self.subTest(order=order, event="on_loader_start"):
            # Filter returns False: the wrapped handler must not fire.
            wrapper = ControlFlowCallback(
                RaiserCallback(order, "on_loader_start"), filter_fn=ignore_fn
            )
            wrapper.on_loader_start(runner)

            # Filter returns True: the wrapped handler fires and raises.
            wrapper = ControlFlowCallback(
                RaiserCallback(order, "on_loader_start"), filter_fn=raise_fn
            )
            with self.assertRaises(Dummy):
                wrapper.on_loader_start(runner)

    for event in events:
        for order in orders:
            with self.subTest(order=order, event=event):
                wrapper = ControlFlowCallback(
                    RaiserCallback(order, event), filter_fn=ignore_fn
                )
                # Prime via on_loader_start before firing the event under test.
                # NOTE(review): presumably the filter is evaluated there;
                # confirm against ControlFlowCallback.
                wrapper.on_loader_start(runner)
                getattr(wrapper, event)(runner)

                wrapper = ControlFlowCallback(
                    RaiserCallback(order, event), filter_fn=raise_fn
                )
                wrapper.on_loader_start(runner)
                with self.assertRaises(Dummy):
                    getattr(wrapper, event)(runner)
def test_filter_fn_with_wrong_args(self):
    """Check ``ControlFlowCallback`` with ``filter_fn`` given as a Python callable.

    For every callback order and every event:

    * with ``_ignore_foo`` (always ``False``), firing the event on the wrapper
      must complete without raising;
    * with ``_raise_foo`` (always ``True``), firing the event must raise
      ``Dummy`` from the wrapped ``RaiserCallback``.

    NOTE(review): despite the test name, both filters have the expected
    ``(stage, epoch, loader)`` signature, so no wrong-argument case is
    exercised here. Confirm the intended scope of this test.
    """
    # The filters below ignore their arguments, so these attribute values
    # should not affect the outcome.
    runner = Mock(stage="stage1", loader_key="train", epoch=1)
    orders = (
        CallbackOrder.Internal,
        CallbackOrder.Metric,
        CallbackOrder.MetricAggregation,
        CallbackOrder.Optimizer,
        CallbackOrder.Validation,
        CallbackOrder.Scheduler,
        CallbackOrder.Logging,
        CallbackOrder.External,
    )
    events = (
        "on_loader_end",
        "on_stage_start",
        "on_stage_end",
        "on_epoch_start",
        "on_epoch_end",
        "on_batch_start",
        "on_batch_end",
        "on_exception",
    )

    def _ignore_foo(stage: str, epoch: int, loader: str) -> bool:
        return False

    def _raise_foo(stage: str, epoch: int, loader: str) -> bool:
        return True

    for order in orders:
        with self.subTest(order=order, event="on_loader_start"):
            # Filter returns False: the wrapped handler must not fire.
            wrapper = ControlFlowCallback(
                RaiserCallback(order, "on_loader_start"), filter_fn=_ignore_foo
            )
            wrapper.on_loader_start(runner)

            # Filter returns True: the wrapped handler fires and raises.
            wrapper = ControlFlowCallback(
                RaiserCallback(order, "on_loader_start"), filter_fn=_raise_foo
            )
            with self.assertRaises(Dummy):
                wrapper.on_loader_start(runner)

    for event in events:
        for order in orders:
            with self.subTest(order=order, event=event):
                wrapper = ControlFlowCallback(
                    RaiserCallback(order, event), filter_fn=_ignore_foo
                )
                # Prime via on_loader_start before firing the event under test.
                # NOTE(review): presumably the filter is evaluated there;
                # confirm against ControlFlowCallback.
                wrapper.on_loader_start(runner)
                getattr(wrapper, event)(runner)

                wrapper = ControlFlowCallback(
                    RaiserCallback(order, event), filter_fn=_raise_foo
                )
                wrapper.on_loader_start(runner)
                with self.assertRaises(Dummy):
                    getattr(wrapper, event)(runner)