def test_train(self):
    """Run an end-to-end training job with a PreciseBatchNormHook attached.

    This is a smoke test: it makes no explicit assertions and passes as long
    as preparing and training the task completes without raising.
    """
    task = build_task(get_test_mlp_task_config())
    # Sample count handed to the hook's constructor.
    hook_num_samples = 10
    task.set_hooks([PreciseBatchNormHook(hook_num_samples)])
    task.prepare()
    ClassyTrainer().train(task)
def test_bn_stats(self):
    """Check when batch-norm statistics change during training with the hook.

    A helper hook snapshots the base model's batch-norm statistics (via
    ``self._get_bn_stats``) on every step. It keeps the most recent snapshot
    taken during a train step and the most recent one taken during a test
    step. After training finishes:

    * the final stats must differ from the last train-step snapshot, and
    * the final stats must match the last test-step snapshot.
    """
    base_self = self

    class TestHook(ClassyHook):
        # Only on_step is of interest; every other hook event is a no-op.
        on_start = ClassyHook._noop
        on_phase_start = ClassyHook._noop
        on_phase_end = ClassyHook._noop
        on_end = ClassyHook._noop

        def __init__(self):
            self.train_bn_stats = None
            self.test_bn_stats = None

        def on_step(self, task):
            # Overwritten on every step, so after training each attribute
            # holds the snapshot from the last step of that phase type.
            if task.train:
                self.train_bn_stats = base_self._get_bn_stats(task.base_model)
            else:
                self.test_bn_stats = base_self._get_bn_stats(task.base_model)

    config = get_test_mlp_task_config()
    task = build_task(config)
    num_samples = 10
    precise_batch_norm_hook = PreciseBatchNormHook(num_samples)
    test_hook = TestHook()
    task.set_hooks([precise_batch_norm_hook, test_hook])
    trainer = ClassyTrainer()
    trainer.train(task)

    # Guard against a vacuous pass: both a train step and a test step must
    # have run, otherwise the comparisons below would be made against None.
    self.assertIsNotNone(test_hook.train_bn_stats)
    self.assertIsNotNone(test_hook.test_bn_stats)

    updated_bn_stats = self._get_bn_stats(task.base_model)
    # the stats should be modified after train steps but not after test steps
    self.assertFalse(
        self._compare_bn_stats(test_hook.train_bn_stats, updated_bn_stats)
    )
    self.assertTrue(
        self._compare_bn_stats(test_hook.test_bn_stats, updated_bn_stats)
    )