def test_checkpoint_hook():
    """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""
    # test epoch based runner
    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner('EpochBasedRunner', max_epochs=1)
    runner.meta = dict()
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader], [('train', 1)])
    assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
        runner.work_dir, 'epoch_1.pth')
    shutil.rmtree(runner.work_dir)

    # test iter based runner
    runner = _build_demo_runner(
        'IterBasedRunner', max_iters=1, max_epochs=None)
    runner.meta = dict()
    checkpointhook = CheckpointHook(interval=1, by_epoch=False)
    runner.register_hook(checkpointhook)
    runner.run([loader], [('train', 1)])
    assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
        runner.work_dir, 'iter_1.pth')
    shutil.rmtree(runner.work_dir)
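# NOTE: the tests in this section rely on a `_build_demo_runner` helper that is
# not defined in these snippets. Below is a minimal sketch of what such a helper
# might look like, assuming mmcv's `build_runner` API and a throwaway linear
# model; the actual helper used by the test suite may differ.
def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None):
    import logging
    import tempfile

    import torch
    import torch.nn as nn
    from mmcv.runner import build_runner

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        # `train_step`/`val_step` return a loss dict, as the runners expect
        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x).sum())

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x).sum())

    model = Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
    runner = build_runner(
        dict(type=runner_type, max_epochs=max_epochs, max_iters=max_iters),
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=tempfile.mkdtemp(),
            logger=logging.getLogger(__name__)))
    return runner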
def register_qat_hooks(self,
                       loss,
                       metrics,
                       lr_policies,
                       qat_policies,
                       ckpt_interval=None,
                       runtime_hook=None):
    assert isinstance(loss, dict)
    assert isinstance(metrics, (tuple, list))
    assert isinstance(lr_policies, (tuple, list))
    assert isinstance(qat_policies, (tuple, list))

    loss = training.build_loss(loss)
    metrics = training.build_metrics(*metrics)
    lr_policies = training.build_lr_policies(*lr_policies)
    qat_policies = training.build_qat_policies(*qat_policies)

    # make sure the loss hook runs first, right after `batch_processor`
    self.register_hook(loss, priority="HIGH")
    self.register_hook(IterTimerHook())
    if ckpt_interval:
        self.register_hook(CheckpointHook(interval=ckpt_interval))
    for hook in chain(metrics, qat_policies, lr_policies):
        if isinstance(hook, HijackModuleOutput):
            priority = "LOW"
        else:
            priority = "NORMAL"
        self.register_hook(hook, priority)

    if runtime_hook is not None:
        interval = runtime_hook["interval"]
        hooks = runtime_hook["hooks"]
        post_process = runtime_hook.get("post_process")
        self.inject_runtime_hooks(interval, hooks, post_process)
    else:
        self.inject_runtime_hooks(-1, [], None)
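# A hypothetical invocation of `register_qat_hooks`, assuming `runner` is an
# instance of the class defining this method. Every config dict below is an
# illustrative placeholder only; the real keys accepted by the project's
# `training.build_*` factories are not shown in this snippet.
runner.register_qat_hooks(
    loss=dict(type='CrossEntropyLoss'),                  # placeholder loss cfg
    metrics=[dict(type='Accuracy', top_k=1)],            # placeholder metric cfg
    lr_policies=[dict(type='StepLrPolicy', steps=[30, 60])],
    qat_policies=[dict(type='FreezeBnPolicy', start_epoch=4)],
    ckpt_interval=1,                                      # save a checkpoint every epoch
    runtime_hook=dict(
        interval=100,                                     # run runtime hooks every 100 iters
        hooks=[dict(type='TensorboardHook')],
        post_process=None))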
def test_ema_hook():
    """xdoctest -m tests/test_hooks.py test_ema_hook."""

    class DemoModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=1,
                padding=1,
                bias=True)
            self._init_weight()

        def _init_weight(self):
            constant_(self.conv.weight, 0)
            constant_(self.conv.bias, 0)

        def forward(self, x):
            return self.conv(x).sum()

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    demo_model = DemoModel()
    runner.model = demo_model
    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(emahook, priority='HIGHEST')
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])

    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
    contain_ema_buffer = False
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            contain_ema_buffer = True
            assert value.sum() == 0
            value.fill_(1)
        else:
            assert value.sum() == 0
    assert contain_ema_buffer
    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
    work_dir = runner.work_dir

    resume_ema_hook = EMAHook(
        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner(max_epochs=2)
    runner.model = demo_model
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])

    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
    contain_ema_buffer = False
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            contain_ema_buffer = True
            assert value.sum() == 2
        else:
            assert value.sum() == 1
    assert contain_ema_buffer
    shutil.rmtree(runner.work_dir)
    shutil.rmtree(work_dir)
def test_ema_hook():
    """xdoctest -m tests/test_hooks.py test_ema_hook."""

    class DemoModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=1,
                padding=1,
                bias=True)
            self.bn = nn.BatchNorm2d(2)
            self._init_weight()

        def _init_weight(self):
            constant_(self.conv.weight, 0)
            constant_(self.conv.bias, 0)
            constant_(self.bn.weight, 0)
            constant_(self.bn.bias, 0)

        def forward(self, x):
            return self.bn(self.conv(x)).sum()

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    demo_model = DemoModel()
    runner.model = demo_model
    ema_hook = ExpMomentumEMAHook(
        momentum=0.0002,
        total_iter=1,
        skip_buffers=True,
        interval=2,
        resume_from=None)
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(ema_hook, priority='HIGHEST')
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])

    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
    num_ema_params = 0
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            num_ema_params += 1
            value.fill_(1)
    assert num_ema_params == 4
    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
    work_dir = runner.work_dir

    resume_ema_hook = ExpMomentumEMAHook(
        momentum=0.5,
        total_iter=10,
        skip_buffers=True,
        interval=1,
        resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner(max_epochs=2)
    runner.model = demo_model
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])

    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
    num_ema_params = 0
    desired_output = [0.9094, 0.9094]
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            num_ema_params += 1
            assert value.sum() == 2
        else:
            if ('weight' in name) or ('bias' in name):
                assert np.allclose(
                    value.data.cpu().numpy().reshape(-1), desired_output,
                    1e-4)
    assert num_ema_params == 4
    shutil.rmtree(runner.work_dir)
    shutil.rmtree(work_dir)