Beispiel #1
0
def test_update_init_info():
    """Check that ``update_init_info`` refreshes every tracked record.

    A small model is given a hand-built ``_params_init_info`` table, all of
    its parameters are overwritten with ones, and ``update_init_info`` is
    then expected to leave each record with the new ``init_info`` label and
    a ``tmp_mean_value`` equal to 1.
    """
    from collections import defaultdict

    class DummyModel(BaseModule):
        """Tiny model: two 1x1 convolutions plus one linear layer."""

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg)
            self.conv1 = nn.Conv2d(
                in_channels=1, out_channels=1, kernel_size=1, stride=1)
            self.conv3 = nn.Conv2d(
                in_channels=1, out_channels=1, kernel_size=1, stride=1)
            self.fc1 = nn.Linear(in_features=1, out_features=1)

    model = DummyModel()

    # Seed one record per parameter, keyed by the parameter object itself.
    model._params_init_info = defaultdict(dict)
    records = model._params_init_info
    for param_name, tensor in model.named_parameters():
        records[tensor].update(
            param_name=param_name,
            init_info='init',
            tmp_mean_value=tensor.data.mean(),
        )

    # Overwrite every parameter in place so its mean becomes exactly 1.
    with torch.no_grad():
        for weight in model.parameters():
            weight.fill_(1)

    update_init_info(model, init_info='fill_1')

    for record in model._params_init_info.values():
        assert record['init_info'] == 'fill_1'
        assert record['tmp_mean_value'] == 1
Beispiel #2
0
    def __call__(self, module):
        """Run ``uniform_init`` over ``module`` and refresh its init records.

        The nested ``init`` callback is handed to ``module.apply``. It calls
        ``uniform_init(m, self.a, self.b, self.bias)`` on every ``m`` it
        receives when ``self.wholemodule`` is truthy; otherwise only on an
        ``m`` whose class name, or one of the names returned by
        ``_get_bases_name(m)``, appears in ``self.layer``.

        Args:
            module: Object exposing ``apply`` (presumably a
                ``torch.nn.Module`` -- confirm against callers).
        """

        def init(m):
            if not self.wholemodule:
                candidates = {m.__class__.__name__, *_get_bases_name(m)}
                if candidates.isdisjoint(self.layer):
                    # None of the configured layer names matches ``m``.
                    return
            uniform_init(m, self.a, self.b, self.bias)

        module.apply(init)

        # Only objects that carry an init-info table get it updated.
        if not hasattr(module, '_params_init_info'):
            return
        update_init_info(module, init_info=self._get_init_info())
Beispiel #3
0
    def __call__(self, module):
        """Run ``xavier_init`` over ``module`` and refresh its init records.

        The nested ``init`` callback is handed to ``module.apply``. It calls
        ``xavier_init(m, self.gain, self.bias, self.distribution)`` on every
        ``m`` it receives when ``self.wholemodule`` is truthy; otherwise only
        on an ``m`` whose class name, or one of the names returned by
        ``_get_bases_name(m)``, appears in ``self.layer``.

        Args:
            module: Object exposing ``apply`` (presumably a
                ``torch.nn.Module`` -- confirm against callers).
        """

        def init(m):
            if not self.wholemodule:
                candidates = {m.__class__.__name__, *_get_bases_name(m)}
                if candidates.isdisjoint(self.layer):
                    # None of the configured layer names matches ``m``.
                    return
            xavier_init(m, self.gain, self.bias, self.distribution)

        module.apply(init)

        # Only objects that carry an init-info table get it updated.
        if not hasattr(module, '_params_init_info'):
            return
        update_init_info(module, init_info=self._get_init_info())
Beispiel #4
0
    def __call__(self, module: nn.Module) -> None:
        """Run ``trunc_normal_init`` over ``module`` and refresh init records.

        The nested ``init`` callback is handed to ``module.apply``. It calls
        ``trunc_normal_init`` with ``self.mean``, ``self.std``, ``self.a``,
        ``self.b`` and ``self.bias`` on every ``m`` it receives when
        ``self.wholemodule`` is truthy; otherwise only on an ``m`` whose
        class name, or one of the names returned by ``_get_bases_name(m)``,
        appears in ``self.layer``.

        Args:
            module: The module whose ``apply`` method receives the callback.
        """

        def init(m):
            if not self.wholemodule:
                candidates = {m.__class__.__name__, *_get_bases_name(m)}
                if candidates.isdisjoint(self.layer):
                    # None of the configured layer names matches ``m``.
                    return
            trunc_normal_init(
                m, self.mean, self.std, self.a, self.b, self.bias)

        module.apply(init)

        # Only objects that carry an init-info table get it updated.
        if not hasattr(module, '_params_init_info'):
            return
        update_init_info(module, init_info=self._get_init_info())
Beispiel #5
0
    def __call__(self, module):
        """Populate ``module`` from ``self.checkpoint`` and refresh init records.

        When ``self.prefix`` is ``None`` the checkpoint is passed to
        ``load_checkpoint`` together with ``module``. Otherwise a state dict
        is obtained from ``_load_checkpoint_with_prefix`` and handed to
        ``load_state_dict``. Both paths use ``strict=False`` and log through
        the ``'mmcv'`` logger.

        Args:
            module: Target object passed to the ``mmcv.runner`` loaders.
        """
        # NOTE(review): imported at call time rather than at module level,
        # presumably to avoid a circular import -- confirm before moving.
        from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
                                 load_state_dict)
        logger = get_logger('mmcv')

        if self.prefix is not None:
            print_log(
                f'load {self.prefix} in model from: {self.checkpoint}',
                logger=logger)
            prefixed_state = _load_checkpoint_with_prefix(
                self.prefix, self.checkpoint, map_location=self.map_location)
            load_state_dict(
                module, prefixed_state, strict=False, logger=logger)
        else:
            print_log(f'load model from: {self.checkpoint}', logger=logger)
            load_checkpoint(
                module,
                self.checkpoint,
                map_location=self.map_location,
                strict=False,
                logger=logger)

        # Only objects that carry an init-info table get it updated.
        if not hasattr(module, '_params_init_info'):
            return
        update_init_info(module, init_info=self._get_init_info())