def test_update_init_info():
    """Check that ``update_init_info`` re-tags parameters after they change.

    A baseline ``_params_init_info`` entry (name, ``'init'`` tag, current
    mean) is seeded by hand for every parameter. Every parameter is then
    overwritten with 1. The test asserts that ``update_init_info`` stores the
    new ``'fill_1'`` tag and a mean value of 1 for each of them.
    """
    from collections import defaultdict

    class DummyModel(BaseModule):

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    model = DummyModel()

    # Seed one record per parameter, keyed by the parameter tensor itself.
    model._params_init_info = defaultdict(dict)
    for pname, p in model.named_parameters():
        model._params_init_info[p] = dict(
            param_name=pname,
            init_info='init',
            tmp_mean_value=p.data.mean(),
        )

    # Overwrite every element of every parameter with 1, so each mean is 1.
    with torch.no_grad():
        for p in model.parameters():
            p.fill_(1)

    update_init_info(model, init_info='fill_1')

    for record in model._params_init_info.values():
        assert record['init_info'] == 'fill_1'
        assert record['tmp_mean_value'] == 1
def __call__(self, module: nn.Module) -> None:
    """Apply ``uniform_init`` to ``module`` and its submodules.

    When ``self.wholemodule`` is true, every module visited by
    ``module.apply`` is initialized. Otherwise a submodule is initialized
    only if its class name, or one of the names returned by
    ``_get_bases_name`` (presumably its base-class names), appears in
    ``self.layer``. Each initialized module receives
    ``uniform_init(m, self.a, self.b, self.bias)``.

    If ``module`` carries ``_params_init_info``, the initializer description
    from ``self._get_init_info()`` is then recorded via ``update_init_info``.

    Args:
        module: Root module to initialize in place.
    """
    if self.wholemodule:
        module.apply(lambda m: uniform_init(m, self.a, self.b, self.bias))
    else:
        # Built once per call instead of once per visited submodule.
        target_layers = set(self.layer)

        def init(m):
            candidate_names = {m.__class__.__name__, *_get_bases_name(m)}
            # A non-empty intersection means m matches a requested layer type.
            if target_layers & candidate_names:
                uniform_init(m, self.a, self.b, self.bias)

        module.apply(init)

    if hasattr(module, '_params_init_info'):
        update_init_info(module, init_info=self._get_init_info())
def __call__(self, module: nn.Module) -> None:
    """Apply ``xavier_init`` to ``module`` and its submodules.

    When ``self.wholemodule`` is true, every module visited by
    ``module.apply`` is initialized. Otherwise a submodule is initialized
    only if its class name, or one of the names returned by
    ``_get_bases_name`` (presumably its base-class names), appears in
    ``self.layer``. Each initialized module receives
    ``xavier_init(m, self.gain, self.bias, self.distribution)``.

    If ``module`` carries ``_params_init_info``, the initializer description
    from ``self._get_init_info()`` is then recorded via ``update_init_info``.

    Args:
        module: Root module to initialize in place.
    """
    if self.wholemodule:
        module.apply(
            lambda m: xavier_init(m, self.gain, self.bias, self.distribution))
    else:
        # Built once per call instead of once per visited submodule.
        target_layers = set(self.layer)

        def init(m):
            candidate_names = {m.__class__.__name__, *_get_bases_name(m)}
            # A non-empty intersection means m matches a requested layer type.
            if target_layers & candidate_names:
                xavier_init(m, self.gain, self.bias, self.distribution)

        module.apply(init)

    if hasattr(module, '_params_init_info'):
        update_init_info(module, init_info=self._get_init_info())
def __call__(self, module: nn.Module) -> None:
    """Apply ``trunc_normal_init`` to ``module`` and its submodules.

    When ``self.wholemodule`` is true, every module visited by
    ``module.apply`` is initialized. Otherwise a submodule is initialized
    only if its class name, or one of the names returned by
    ``_get_bases_name`` (presumably its base-class names), appears in
    ``self.layer``. Each initialized module receives
    ``trunc_normal_init(m, self.mean, self.std, self.a, self.b, self.bias)``.

    If ``module`` carries ``_params_init_info``, the initializer description
    from ``self._get_init_info()`` is then recorded via ``update_init_info``.

    Args:
        module: Root module to initialize in place.
    """
    if self.wholemodule:
        module.apply(lambda m: trunc_normal_init(
            m, self.mean, self.std, self.a, self.b, self.bias))
    else:
        # Built once per call instead of once per visited submodule.
        target_layers = set(self.layer)

        def init(m):
            candidate_names = {m.__class__.__name__, *_get_bases_name(m)}
            # A non-empty intersection means m matches a requested layer type.
            if target_layers & candidate_names:
                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                  self.bias)

        module.apply(init)

    if hasattr(module, '_params_init_info'):
        update_init_info(module, init_info=self._get_init_info())
def __call__(self, module):
    """Load weights from ``self.checkpoint`` into ``module`` (non-strict).

    With ``self.prefix`` unset, the whole checkpoint is loaded via
    ``load_checkpoint``. Otherwise ``_load_checkpoint_with_prefix`` is used to
    obtain a state dict (presumably only the entries under that prefix),
    which is then loaded with ``load_state_dict``. Both paths use
    ``strict=False``, honour ``self.map_location`` and log to the ``'mmcv'``
    logger.

    If ``module`` carries ``_params_init_info``, the initializer description
    from ``self._get_init_info()`` is then recorded via ``update_init_info``.
    """
    # NOTE(review): imported at call time, presumably to avoid a circular
    # import with mmcv.runner — confirm before hoisting.
    from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
                             load_state_dict)

    logger = get_logger('mmcv')
    ckpt_path = self.checkpoint
    location = self.map_location
    prefix = self.prefix

    if prefix is not None:
        print_log(f'load {prefix} in model from: {ckpt_path}', logger=logger)
        partial_state = _load_checkpoint_with_prefix(
            prefix, ckpt_path, map_location=location)
        load_state_dict(module, partial_state, strict=False, logger=logger)
    else:
        print_log(f'load model from: {ckpt_path}', logger=logger)
        load_checkpoint(
            module,
            ckpt_path,
            map_location=location,
            strict=False,
            logger=logger)

    if hasattr(module, '_params_init_info'):
        update_init_info(module, init_info=self._get_init_info())