def init_weights(self, pretrained=None):
    """Set up the model's parameters, optionally from a checkpoint file.

    Args:
        pretrained (str | None, optional): Checkpoint path. When a string is
            given, the checkpoint is loaded non-strictly via
            ``load_checkpoint`` using the root logger. When ``None``, weights
            are initialized with ``generation_init_weights`` using this
            model's ``init_type`` and ``init_gain`` attributes.
            Default: None.

    Raises:
        TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
    """
    if pretrained is None:
        # No checkpoint: fall back to the configured initialization scheme.
        generation_init_weights(
            self, init_type=self.init_type, init_gain=self.init_gain)
        return

    if not isinstance(pretrained, str):
        raise TypeError("'pretrained' must be a str or None. "
                        f'But received {type(pretrained)}.')

    # strict=False tolerates missing/unexpected keys in the checkpoint.
    load_checkpoint(self, pretrained, strict=False, logger=get_root_logger())
def test_generation_init_weights():
    """Check that ``generation_init_weights`` re-initializes supported layers.

    Each init type is applied to a freshly constructed layer and compared
    against an untouched deep copy, so every type is verified on its own
    instead of only the last one applied. An unknown init type must raise
    ``NotImplementedError`` for Conv2d and Linear layers.
    """
    # (init_type, extra kwargs) pairs; 'kaiming' is called without init_gain.
    init_cases = (
        ('normal', {'init_gain': 0.02}),
        ('xavier', {'init_gain': 0.02}),
        ('kaiming', {}),
        ('orthogonal', {'init_gain': 0.02}),
    )
    # Conv and Linear layers: every listed init type must change the weights.
    layer_factories = (
        lambda: nn.Conv2d(3, 3, 1),
        lambda: nn.Linear(3, 1),
    )
    for make_layer in layer_factories:
        for init_type, extra_kwargs in init_cases:
            # Fresh layer per init type so each type is checked independently.
            module = make_layer()
            module_tmp = copy.deepcopy(module)
            generation_init_weights(module, init_type=init_type, **extra_kwargs)
            assert not torch.equal(module.weight.data,
                                   module_tmp.weight.data), (
                f'{init_type} init left {type(module).__name__} '
                'weights unchanged')
        with pytest.raises(NotImplementedError):
            generation_init_weights(make_layer(), init_type='abc')

    # BatchNorm2d: only the 'normal' init type is exercised.
    module = nn.BatchNorm2d(3)
    module_tmp = copy.deepcopy(module)
    generation_init_weights(module, init_type='normal', init_gain=0.02)
    assert not torch.equal(module.weight.data, module_tmp.weight.data)