def test_load_checkpoint_with_prefix():
    """Test loading a sub-module's weights from a checkpoint by prefix.

    A small module is filled with constant weights, saved to disk, and the
    ``'conv2d'`` part is read back with ``_load_checkpoint_with_prefix``. The
    returned ``weight`` and ``bias`` must equal those of ``model.conv2d``. A
    prefix that is absent from the checkpoint must raise ``AssertionError``.
    """
    # Local import keeps this block self-contained.
    import os.path as osp

    class FooModule(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 2)
            self.conv2d = nn.Conv2d(3, 1, 3)
            self.conv2d_2 = nn.Conv2d(3, 2, 3)

    model = FooModule()
    # Give every parameter a distinct constant so a mix-up between
    # sub-modules would be detected by the equality checks below.
    nn.init.constant_(model.linear.weight, 1)
    nn.init.constant_(model.linear.bias, 2)
    nn.init.constant_(model.conv2d.weight, 3)
    nn.init.constant_(model.conv2d.bias, 4)
    nn.init.constant_(model.conv2d_2.weight, 5)
    nn.init.constant_(model.conv2d_2.bias, 6)

    with TemporaryDirectory() as tmp_dir:
        # Write the checkpoint inside the temporary directory so that it is
        # removed on exit instead of being left in the working directory.
        ckpt_path = osp.join(tmp_dir, 'model.pth')
        torch.save(model.state_dict(), ckpt_path)

        prefix = 'conv2d'
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path)
        assert torch.equal(model.conv2d.state_dict()['weight'],
                           state_dict['weight'])
        assert torch.equal(model.conv2d.state_dict()['bias'],
                           state_dict['bias'])

        # test whether prefix is in pretrained model
        with pytest.raises(AssertionError):
            prefix = 'back'
            _load_checkpoint_with_prefix(prefix, ckpt_path)
def _load_pretrained_model(self,
                           ckpt_path,
                           prefix='',
                           map_location='cpu',
                           strict=True):
    """Load pretrained weights from a checkpoint into this module.

    Args:
        ckpt_path (str): Location of the checkpoint, forwarded to
            ``_load_checkpoint_with_prefix``.
        prefix (str, optional): Prefix forwarded to
            ``_load_checkpoint_with_prefix`` (presumably selects the
            sub-module whose weights are extracted -- confirm against that
            helper). Defaults to ''.
        map_location (str, optional): Forwarded to
            ``_load_checkpoint_with_prefix``. Defaults to 'cpu'.
        strict (bool, optional): Forwarded to ``self.load_state_dict``.
            Defaults to True.
    """
    loaded_weights = _load_checkpoint_with_prefix(
        prefix, ckpt_path, map_location)
    self.load_state_dict(loaded_weights, strict=strict)
    # Report which checkpoint was used through the 'mmgen' logger name.
    mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
def init_weights(self, pretrained=None, init_type='ortho'):
    """Init weights for models.

    Args:
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose
            necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            The optional keys 'map_location' (default 'cpu') and 'strict'
            (default True) are read from the dict as well.
            Defaults to None.
        init_type (str, optional): The name of an initialization method:
            ortho | N02 | xavier. Defaults to 'ortho'.

    Raises:
        NotImplementedError: If ``pretrained`` is None and ``init_type`` is
            not one of the supported names (raised when the first
            Conv2d/Linear/Embedding module is visited).
        TypeError: If ``pretrained`` is neither a str, a dict nor None.
    """
    if isinstance(pretrained, str):
        logger = get_root_logger()
        # A plain checkpoint path is always loaded with strict=False.
        load_checkpoint(self, pretrained, strict=False, logger=logger)
    elif isinstance(pretrained, dict):
        ckpt_path = pretrained.get('ckpt_path', None)
        assert ckpt_path is not None, (
            "'ckpt_path' must be provided when 'pretrained' is a dict.")
        prefix = pretrained.get('prefix', '')
        map_location = pretrained.get('map_location', 'cpu')
        strict = pretrained.get('strict', True)
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
    elif pretrained is None:
        for m in self.modules():
            # Only layers with a learnable weight matrix are re-initialized.
            if not isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if init_type == 'ortho':
                nn.init.orthogonal_(m.weight)
            elif init_type == 'N02':
                normal_init(m, 0.0, 0.02)
            elif init_type == 'xavier':
                xavier_init(m)
            else:
                # Single clean literal: the previous backslash continuation
                # inside the string leaked a backslash and whitespace into
                # the message.
                raise NotImplementedError(
                    f'{init_type} initialization not supported now.')
    else:
        raise TypeError('pretrained must be a str, dict or None but'
                        f' got {type(pretrained)} instead.')
def init_weights(self, pretrained=None, strict=True):
    """Init weights for SNGAN-Proj and SAGAN.

    If ``pretrained=None`` and weight initialization would follow the
    ``INIT_TYPE`` in ``init_cfg=dict(type=INIT_TYPE)``.

    For SNGAN-Proj
    (``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
    we follow the initialization method in the official Chainer's
    implementation (https://github.com/pfnet-research/sngan_projection).

    For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
    initialization method in official tensorflow's implementation
    (https://github.com/brain-research/self-attention-gan).

    Besides the reimplementation of the official code's initialization, we
    provide BigGAN's and Pytorch-StudioGAN's style initialization
    (``INIT_TYPE.upper() == BIGGAN`` and ``INIT_TYPE.upper() == STUDIO``).
    Please refer to https://github.com/ajbrock/BigGAN-PyTorch and
    https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.

    Args:
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose
            necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            The optional keys 'map_location' (default 'cpu') and 'strict'
            (default: the ``strict`` argument) are read as well.
            Defaults to None.
        strict (bool, optional): Strictness flag passed to the checkpoint
            loader. It is used directly when ``pretrained`` is a str, and
            as the fallback when ``pretrained`` is a dict without a
            'strict' key. Defaults to True.

    Raises:
        NotImplementedError: If ``pretrained`` is None and
            ``self.init_type`` is not a known initialization method.
        TypeError: If ``pretrained`` is neither a str, a dict nor None.
    """
    if isinstance(pretrained, str):
        logger = get_root_logger()
        load_checkpoint(self, pretrained, strict=strict, logger=logger)
    elif isinstance(pretrained, dict):
        ckpt_path = pretrained.get('ckpt_path', None)
        assert ckpt_path is not None, (
            "'ckpt_path' must be provided when 'pretrained' is a dict.")
        prefix = pretrained.get('prefix', '')
        map_location = pretrained.get('map_location', 'cpu')
        # Fall back to the caller's ``strict`` argument rather than a
        # hard-coded True, so ``init_weights(cfg, strict=False)`` is honored.
        strict = pretrained.get('strict', strict)
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
    elif pretrained is None:
        init_type = self.init_type.upper()
        if init_type == 'STUDIO':
            # initialization method from Pytorch-StudioGAN
            #   * weight: orthogonal_init gain=1
            #   * bias  : 0
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                    nn.init.orthogonal_(m.weight, gain=1)
                    # nn.Embedding has no bias; bias may also be disabled.
                    if hasattr(m, 'bias') and m.bias is not None:
                        m.bias.data.fill_(0.)
        elif init_type == 'BIGGAN':
            # initialization method from BigGAN-pytorch
            #   * weight: xavier_init gain=1
            #   * bias  : default
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                    xavier_uniform_(m.weight, gain=1)
        elif init_type == 'SAGAN':
            # initialization method from official tensorflow code
            #   * weight: xavier_init gain=1
            #   * bias  : 0
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                    xavier_init(m, gain=1, distribution='uniform')
        elif init_type in ('SNGAN', 'SNGAN-PROJ', 'GAN-PROJ'):
            # initialization method from the official chainer code
            #   * embedding.weight: xavier_init gain=1
            #   * conv.weight     : xavier_init gain=sqrt(2)
            #   * shortcut.weight : xavier_init gain=1
            #   * bias            : 0
            for n, m in self.named_modules():
                if isinstance(m, nn.Conv2d):
                    # Shortcut convs are told apart by their module name.
                    if 'shortcut' in n:
                        xavier_init(m, gain=1, distribution='uniform')
                    else:
                        xavier_init(
                            m, gain=np.sqrt(2), distribution='uniform')
                elif isinstance(m, (nn.Linear, nn.Embedding)):
                    xavier_init(m, gain=1, distribution='uniform')
        else:
            raise NotImplementedError('Unknown initialization method: '
                                      f"'{self.init_type}'")
    else:
        raise TypeError("'pretrained' must be a str, dict or None. "
                        f'But receive {type(pretrained)}.')