def test_composite_extension_main_loop_assignment():
    """Check that children see the ``main_loop`` set on the composite.

    Two mock sub-extensions are wrapped in a ``CompositeExtension``; after
    assigning ``main_loop`` on the composite, each child's ``main_loop``
    must compare equal to it.
    """
    first, second = Mock(), Mock()

    composite = CompositeExtension([first, second])
    composite.main_loop = object()

    for child in (first, second):
        assert child.main_loop == composite.main_loop
def test_composite_extension_different_schedules():
    """Check that composite and children each fire on their own schedule.

    The composite is configured with ``before_training=True``, one child
    with ``after_training=True`` (and ``after_batch=False``), the other with
    ``after_batch=True``.  After dispatching the three callbacks in turn,
    each ``do`` mock must have been called exactly once, with the callback
    name matching its own configuration.
    """
    class Foo(SimpleExtension):
        # The class-level ``do`` is a no-op; every instance replaces it with
        # a Mock (before the base initializer runs) so calls can be checked.
        def __init__(self, **kwargs):
            self.do = Mock()
            super(Foo, self).__init__(**kwargs)

        def do(self, *args):
            pass

    end_of_training = Foo(after_batch=False, after_training=True)
    every_batch = Foo(after_batch=True)
    composite = CompositeExtension(
        [end_of_training, every_batch], before_training=True)
    composite.main_loop = Mock()
    composite.do = Mock()

    # Dispatch order matters: it mirrors the original sequence of calls.
    for callback_name in ('before_training', 'after_batch', 'after_training'):
        composite.dispatch(callback_name)

    composite.do.assert_called_once_with('before_training')
    end_of_training.do.assert_called_once_with('after_training')
    every_batch.do.assert_called_once_with('after_batch')
def test_composite_extension_dispatches():
    """Check that ``dispatch`` on the composite is forwarded to each child.

    Both the callback name and any extra positional argument must arrive
    unchanged at every mock sub-extension's ``dispatch``.
    """
    first, second = Mock(), Mock()

    composite = CompositeExtension([first, second])
    composite.main_loop = object()

    composite.dispatch('before_training')
    for child in (first, second):
        child.dispatch.assert_called_once_with('before_training')

    composite.dispatch('after_batch', 5)
    # Only the most recent call is checked here, since each child has now
    # been dispatched to twice.
    for child in (first, second):
        child.dispatch.assert_called_with('after_batch', 5)