Example #1: CheckpointHook with epoch- and iter-based runners
import os.path as osp
import shutil

import torch
from torch.utils.data import DataLoader

from mmcv.runner import CheckpointHook


def test_checkpoint_hook():
    """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""
    # `_build_demo_runner` is a test-local helper; a minimal sketch of it
    # follows this example.

    # test epoch based runner
    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner('EpochBasedRunner', max_epochs=1)
    runner.meta = dict()
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader], [('train', 1)])
    assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
        runner.work_dir, 'epoch_1.pth')
    shutil.rmtree(runner.work_dir)

    # test iter based runner
    runner = _build_demo_runner('IterBasedRunner',
                                max_iters=1,
                                max_epochs=None)
    runner.meta = dict()
    checkpointhook = CheckpointHook(interval=1, by_epoch=False)
    runner.register_hook(checkpointhook)
    runner.run([loader], [('train', 1)])
    assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
        runner.work_dir, 'iter_1.pth')
    shutil.rmtree(runner.work_dir)
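
Examples #1, #3 and #4 call a `_build_demo_runner` helper defined elsewhere
in their test modules. A minimal sketch of what such a helper could look
like, assuming mmcv's `build_runner` API; the model, optimizer, and
work_dir choices here are illustrative, not the original helper:

import logging
import tempfile

import torch
import torch.nn as nn
from mmcv.runner import build_runner


class _DemoModel(nn.Module):
    """Tiny model exposing the train_step/val_step API the runners expect."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, x, optimizer, **kwargs):
        return dict(loss=self(x).sum())

    def val_step(self, x, optimizer, **kwargs):
        return dict(loss=self(x).sum())


def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None):
    model = _DemoModel()
    return build_runner(
        dict(type=runner_type),
        default_args=dict(
            model=model,
            optimizer=torch.optim.SGD(model.parameters(), lr=0.02),
            work_dir=tempfile.mkdtemp(),
            logger=logging.getLogger('demo'),
            max_epochs=max_epochs,
            max_iters=max_iters))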
Example #2: registering hooks for quantization-aware training (QAT)
    def register_qat_hooks(self,
                           loss,
                           metrics,
                           lr_policies,
                           qat_policies,
                           ckpt_interval=None,
                           runtime_hook=None):
        """Build and register every hook needed for quantization-aware training.

        Requires ``from itertools import chain`` at module level.
        """
        assert isinstance(loss, dict)
        assert isinstance(metrics, (tuple, list))
        assert isinstance(lr_policies, (tuple, list))
        assert isinstance(qat_policies, (tuple, list))

        loss = training.build_loss(loss)
        metrics = training.build_metrics(*metrics)
        lr_policies = training.build_lr_policies(*lr_policies)
        qat_policies = training.build_qat_policies(*qat_policies)

        # register the loss hook with HIGH priority so it is ready first,
        # immediately after `batch_processor`
        self.register_hook(loss, priority="HIGH")
        self.register_hook(IterTimerHook())
        if ckpt_interval:
            self.register_hook(CheckpointHook(interval=ckpt_interval))

        for hook in chain(metrics, qat_policies, lr_policies):
            # output-hijacking hooks run last (LOW priority) so every other
            # hook still sees the original module outputs
            if isinstance(hook, HijackModuleOutput):
                priority = "LOW"
            else:
                priority = "NORMAL"
            self.register_hook(hook, priority)

        if runtime_hook is not None:
            interval = runtime_hook["interval"]
            hooks = runtime_hook["hooks"]
            post_process = runtime_hook.get("post_process")
            self.inject_runtime_hooks(interval, hooks, post_process)
        else:
            # no runtime hooks requested: inject an empty, disabled set
            self.inject_runtime_hooks(-1, [], None)
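
A hypothetical call site, to show the expected argument shapes; the config
keys and hook names below are illustrative stand-ins, not taken from the
original project, and the real dict contents depend on what the
`training.build_*` factories accept:

# `trainer` is an instance of the class that defines register_qat_hooks.
trainer.register_qat_hooks(
    loss=dict(type='CrossEntropyLoss'),           # -> training.build_loss
    metrics=[dict(type='Accuracy', topk=(1,))],   # -> training.build_metrics
    lr_policies=[dict(type='StepLR', step=10)],   # -> training.build_lr_policies
    qat_policies=[dict(type='FreezeBN', epoch=5)],
    ckpt_interval=1,                              # checkpoint every epoch
    runtime_hook=dict(interval=50, hooks=[], post_process=None),
)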
Example #3: EMAHook save and resume
import shutil

import torch
import torch.nn as nn
from torch.nn.init import constant_
from torch.utils.data import DataLoader

from mmcv.runner import CheckpointHook, EMAHook


def test_ema_hook():
    """xdoctest -m tests/test_hooks.py test_ema_hook."""
    class DemoModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(in_channels=1,
                                  out_channels=2,
                                  kernel_size=1,
                                  padding=1,
                                  bias=True)
            self._init_weight()

        def _init_weight(self):
            constant_(self.conv.weight, 0)
            constant_(self.conv.bias, 0)

        def forward(self, x):
            return self.conv(x).sum()

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    demo_model = DemoModel()
    runner.model = demo_model
    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(emahook, priority='HIGHEST')
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
    contain_ema_buffer = False
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            contain_ema_buffer = True
            assert value.sum() == 0
            # overwrite the EMA buffers so the resumed run must restore them
            value.fill_(1)
        else:
            assert value.sum() == 0
    assert contain_ema_buffer
    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
    work_dir = runner.work_dir
    resume_ema_hook = EMAHook(momentum=0.5,
                              warm_up=0,
                              resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner(max_epochs=2)
    runner.model = demo_model
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
    contain_ema_buffer = False
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            contain_ema_buffer = True
            assert value.sum() == 2
        else:
            assert value.sum() == 1
    assert contain_ema_buffer
    shutil.rmtree(runner.work_dir)
    shutil.rmtree(work_dir)
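
The EMAHook above maintains a standard exponential moving average of the
model weights. A minimal sketch of the update rule it applies (variable
names are illustrative; the real hook also handles warm-up and buffer
registration):

import torch


def ema_update(ema_param: torch.Tensor,
               param: torch.Tensor,
               momentum: float = 0.1) -> None:
    # One EMA step: ema <- (1 - momentum) * ema + momentum * param.
    ema_param.mul_(1 - momentum).add_(param, alpha=momentum)


# With zero-initialized weights, as in the test above, the EMA stays zero:
w, ema = torch.zeros(2), torch.zeros(2)
ema_update(ema, w)
assert ema.sum() == 0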
Example #4: ExpMomentumEMAHook save and resume
import shutil

import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import constant_
from torch.utils.data import DataLoader

from mmcv.runner import CheckpointHook
# ExpMomentumEMAHook comes from the project under test (in MMDetection it
# lives under mmdet.core.hook); the exact import path depends on the
# codebase and is left as an assumption here.


def test_ema_hook():
    """xdoctest -m tests/test_hooks.py test_ema_hook."""
    class DemoModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(in_channels=1,
                                  out_channels=2,
                                  kernel_size=1,
                                  padding=1,
                                  bias=True)
            self.bn = nn.BatchNorm2d(2)

            self._init_weight()

        def _init_weight(self):
            constant_(self.conv.weight, 0)
            constant_(self.conv.bias, 0)
            constant_(self.bn.weight, 0)
            constant_(self.bn.bias, 0)

        def forward(self, x):
            return self.bn(self.conv(x)).sum()

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    demo_model = DemoModel()
    runner.model = demo_model
    ema_hook = ExpMomentumEMAHook(momentum=0.0002,
                                  total_iter=1,
                                  skip_buffers=True,
                                  interval=2,
                                  resume_from=None)
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(ema_hook, priority='HIGHEST')
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
    num_ema_params = 0
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            num_ema_params += 1
            # overwrite the EMA buffers so the resumed run must restore them
            value.fill_(1)
    # conv.weight, conv.bias, bn.weight and bn.bias each get an EMA copy
    assert num_ema_params == 4
    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')

    work_dir = runner.work_dir
    resume_ema_hook = ExpMomentumEMAHook(momentum=0.5,
                                         total_iter=10,
                                         skip_buffers=True,
                                         interval=1,
                                         resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner(max_epochs=2)
    runner.model = demo_model
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
    num_ema_params = 0
    desired_output = [0.9094, 0.9094]
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            num_ema_params += 1
            assert value.sum() == 2
        else:
            if ('weight' in name) or ('bias' in name):
                # raw (non-EMA) weights should match the trained values
                assert np.allclose(value.data.cpu().numpy().reshape(-1),
                                   desired_output, 1e-4)
    assert num_ema_params == 4
    shutil.rmtree(runner.work_dir)
    shutil.rmtree(work_dir)
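
Unlike the fixed-momentum EMAHook in Example #3, ExpMomentumEMAHook ramps
its effective momentum over the first `total_iter` iterations so that the
average tracks the raw weights closely early in training. A sketch of that
schedule, assuming the MMDetection-style formulation (the exact decay
constant is an assumption, not read from the hook's source):

import math


def exp_momentum(base_momentum, cur_iter, total_iter):
    # Starts near 1 (the EMA simply copies the weights) and decays
    # exponentially toward base_momentum, after which updates become slow.
    return (1 - base_momentum) * math.exp(-(cur_iter + 1) / total_iter) \
        + base_momentum


# With total_iter=1, as in the first run above, the momentum collapses to
# roughly base_momentum within a handful of iterations.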