def test_trunc_normal_init():
    """Check ``trunc_normal_init`` on a conv layer.

    The weights must pass a Kolmogorov-Smirnov test against the requested
    truncated normal distribution and the bias must equal the requested
    constant. A bias-free layer is also passed through the initialiser.
    """

    def _random_float(a, b):
        # Uniform sample between ``a`` and ``b`` from one ``random.random()``.
        return a + random.random() * (b - a)

    def _is_trunc_normal(tensor, mean, std, a, b):
        # scipy's ``truncnorm`` is parameterised for data drawn from N(0, 1),
        # so both the samples and the truncation bounds are standardised
        # before running the test.
        standardised = ((tensor.view(-1) - mean) / std).tolist()
        lower = (a - mean) / std
        upper = (b - mean) / std
        ks_result = stats.kstest(
            standardised, 'truncnorm', args=(lower, upper))
        # Index 1 of the KS result is the p-value.
        return ks_result[1] > 0.0001

    conv_module = nn.Conv2d(3, 16, 3)

    # Random configuration with mean - 2 * std <= a <= mean <= b
    # <= mean + 2 * std.
    mean = _random_float(-3, 3)
    std = _random_float(.01, 1)
    a = _random_float(mean - 2 * std, mean)
    b = _random_float(mean, mean + 2 * std)

    trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
    assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
    expected_bias = torch.full_like(conv_module.bias, 0.1)
    assert conv_module.bias.allclose(expected_bias)

    # No assertion here: the call only has to complete without raising.
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    trunc_normal_init(conv_module_no_bias)
def init_weights(self):
    """Initialise the weights from scratch or from ``self.pretrained``.

    * ``self.pretrained is None``: every ``nn.Linear`` gets truncated-normal
      weights (std 0.02), every ``nn.LayerNorm`` gets weight 1, and every
      ``nn.Conv2d`` gets normal weights with ``std = sqrt(2 / fan_out)``.
      All biases are set to 0.
    * ``self.pretrained`` is a ``str``: the checkpoint is loaded on CPU and
      applied non-strictly.

    Any other value of ``self.pretrained`` leaves the model untouched.
    """
    if self.pretrained is None:
        for m in self.modules():
            # NOTE: the ``*_init`` helpers are handed the *module*. They are
            # assumed to be the mmcv weight-init helpers, which look up
            # ``.weight`` / ``.bias`` on their argument; given a bare
            # parameter they find neither attribute and do nothing.
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0, bias=0)
            elif isinstance(m, nn.Conv2d):
                fan_out = (m.kernel_size[0] * m.kernel_size[1] *
                           m.out_channels)
                fan_out //= m.groups
                normal_init(m, 0, math.sqrt(2.0 / fan_out), bias=0)
    elif isinstance(self.pretrained, str):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.pretrained, logger=logger, map_location='cpu')
        # The weights may be wrapped under a ``state_dict`` key.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Non-strict load: missing / unexpected keys are tolerated.
        self.load_state_dict(state_dict, False)
def init_weights(self, pretrained=None):
    """Initialise the weights from scratch or from a checkpoint path.

    Args:
        pretrained (str | None): Checkpoint to load. With ``None`` the
            Linear / LayerNorm / Conv2d sub-modules are initialised from
            scratch.

    Raises:
        TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
    """
    if pretrained is None:
        # Training from scratch.
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                trunc_normal_init(layer, std=.02, bias=0.)
            elif isinstance(layer, nn.LayerNorm):
                constant_init(layer, 1.0)
            elif isinstance(layer, nn.Conv2d):
                k_h, k_w = layer.kernel_size[0], layer.kernel_size[1]
                fan_out = (k_h * k_w * layer.out_channels) // layer.groups
                normal_init(layer, 0, math.sqrt(2.0 / fan_out))
        return

    if not isinstance(pretrained, str):
        raise TypeError('pretrained must be a str or None')

    logger = get_root_logger()
    checkpoint = _load_checkpoint(
        pretrained, logger=logger, map_location='cpu')
    cls_name = self.__class__.__name__
    logger.warning(f'Load pre-trained model for {cls_name} '
                   f'from original repo')

    # The weights may sit under ``state_dict`` or ``model``; otherwise the
    # checkpoint itself is used as the state dict.
    for wrapper_key in ('state_dict', 'model'):
        if wrapper_key in checkpoint:
            state_dict = checkpoint[wrapper_key]
            break
    else:
        state_dict = checkpoint

    if self.convert_weights:
        # We need to convert pre-trained weights to match this
        # implementation.
        state_dict = tcformer_convert(state_dict)

    load_state_dict(self, state_dict, strict=False, logger=logger)
def init_weights(self):
    """Initialise every Linear, LayerNorm and Conv2d sub-module.

    Linear layers use a truncated normal (std 0.02) with bias 0, LayerNorm
    gets the constant 1.0 via ``constant_init``, and Conv2d layers use a
    normal distribution with ``std = sqrt(2 / fan_out)``.
    """
    for layer in self.modules():
        if isinstance(layer, nn.Linear):
            trunc_normal_init(layer, std=.02, bias=0.)
        elif isinstance(layer, nn.LayerNorm):
            constant_init(layer, 1.0)
        elif isinstance(layer, nn.Conv2d):
            # fan_out = kernel area * output channels, divided per group.
            k_h, k_w = layer.kernel_size[0], layer.kernel_size[1]
            fan_out = (k_h * k_w * layer.out_channels) // layer.groups
            conv_std = math.sqrt(2.0 / fan_out)
            normal_init(layer, 0, conv_std)
def init_weights(self):
    """Initialise every Linear, LayerNorm and Conv2d sub-module in place.

    Linear layers go through ``trunc_normal_init`` (std 0.02, bias 0).
    LayerNorm gets bias 0 and weight 1. Conv2d weights are drawn from a
    normal distribution with ``std = sqrt(2 / fan_out)`` and the conv bias,
    when present, is zeroed.
    """
    for layer in self.modules():
        if isinstance(layer, nn.Linear):
            trunc_normal_init(layer, std=.02, bias=0.)
        elif isinstance(layer, nn.LayerNorm):
            nn.init.constant_(layer.bias, 0)
            nn.init.constant_(layer.weight, 1.0)
        elif isinstance(layer, nn.Conv2d):
            # fan_out = kernel area * output channels, divided per group.
            k_h, k_w = layer.kernel_size[0], layer.kernel_size[1]
            fan_out = (k_h * k_w * layer.out_channels) // layer.groups
            conv_std = math.sqrt(2.0 / fan_out)
            layer.weight.data.normal_(0, conv_std)
            if layer.bias is not None:
                layer.bias.data.zero_()
def init_weights(self):
    """Initialise the backbone from scratch or from ``self.init_cfg``.

    * ``self.init_cfg is None``: a warning is logged and the Linear,
      LayerNorm and Conv2d sub-modules are initialised; every
      ``AbsolutePositionEmbedding`` runs its own ``init_weights``.
    * otherwise ``self.init_cfg`` must contain ``checkpoint``; the file is
      loaded on CPU, optionally converted with ``pvt_convert`` and applied
      non-strictly.
    """
    logger = get_root_logger()
    if self.init_cfg is None:
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training start from scratch')
        for m in self.modules():
            # NOTE: the ``*_init`` helpers are handed the *module*. They are
            # assumed to be the mmcv weight-init helpers, which look up
            # ``.weight`` / ``.bias`` on their argument; given a bare
            # parameter they find neither attribute and do nothing.
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0, bias=0)
            elif isinstance(m, nn.Conv2d):
                fan_out = (m.kernel_size[0] * m.kernel_size[1] *
                           m.out_channels)
                fan_out //= m.groups
                normal_init(m, 0, math.sqrt(2.0 / fan_out), bias=0)
            elif isinstance(m, AbsolutePositionEmbedding):
                m.init_weights()
    else:
        assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                              f'specify `Pretrained` in ' \
                                              f'`init_cfg` in ' \
                                              f'{self.__class__.__name__} '
        checkpoint = _load_checkpoint(
            self.init_cfg.checkpoint, logger=logger, map_location='cpu')
        logger.warning(f'Load pre-trained model for '
                       f'{self.__class__.__name__} from original repo')
        # The weights may sit under ``state_dict`` or ``model``.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        if self.convert_weights:
            # Because pvt backbones are not supported by mmcls,
            # so we need to convert pre-trained weights to match this
            # implementation.
            state_dict = pvt_convert(state_dict)
        load_state_dict(self, state_dict, strict=False, logger=logger)
def init_weights(self):
    """Initialize the weights in backbone.

    With ``self.init_cfg`` unset the Linear and LayerNorm sub-modules are
    initialised from scratch and each entry of ``self.layers`` runs
    ``_init_respostnorm``. Otherwise the checkpoint named in
    ``self.init_cfg`` is loaded non-strictly, after dropping the entries
    that are to be re-initialised.

    Raises:
        NotImplementedError: If a checkpoint is requested while
            ``self.ape`` is truthy.
    """
    logger = get_root_logger()
    cls_name = self.__class__.__name__

    if self.init_cfg is None:
        logger.warning(f'No pre-trained weights for {cls_name}, '
                       f'training start from scratch')
        for module in self.modules():
            if isinstance(module, nn.Linear):
                trunc_normal_init(module, std=.02, bias=0.)
            elif isinstance(module, nn.LayerNorm):
                constant_init(module, 1.0)
        for stage in self.layers:
            stage._init_respostnorm()
        return

    if self.ape:
        raise NotImplementedError
    assert 'checkpoint' in self.init_cfg, (
        f'Only support specify `Pretrained` in `init_cfg` in '
        f'{self.__class__.__name__} ')
    ckpt = _load_checkpoint(
        self.init_cfg.checkpoint, logger=logger, map_location='cpu')

    # The weights may sit under ``state_dict`` or ``model``; otherwise the
    # checkpoint itself is used as the state dict.
    for wrapper_key in ('state_dict', 'model'):
        if wrapper_key in ckpt:
            state_dict = ckpt[wrapper_key]
            break
    else:
        state_dict = ckpt

    # delete keys for reinitialization
    reinit_keys = ('relative_position_index', 'relative_coords_table')
    for name in list(state_dict.keys()):
        if any(fragment in name for fragment in reinit_keys):
            del state_dict[name]

    load_state_dict(self, state_dict, strict=False, logger=logger)
def init_weights(self):
    """Initialise the weights from scratch or from ``self.pretrained``.

    * ``self.pretrained is None``: every ``nn.Linear`` gets truncated-normal
      weights (std 0.02), every ``nn.LayerNorm`` gets weight 1, and every
      ``nn.Conv2d`` gets normal weights with ``std = sqrt(2 / fan_out)``.
      All biases are set to 0.
    * ``self.pretrained`` is a ``str``: the checkpoint is loaded on CPU,
      converted with ``mit_convert`` when ``self.pretrain_style`` is
      ``'official'``, and applied non-strictly.

    Any other value of ``self.pretrained`` leaves the model untouched.
    """
    if self.pretrained is None:
        for m in self.modules():
            # NOTE: the ``*_init`` helpers are handed the *module*. They are
            # assumed to be the mmcv weight-init helpers, which look up
            # ``.weight`` / ``.bias`` on their argument; given a bare
            # parameter they find neither attribute and do nothing.
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0, bias=0)
            elif isinstance(m, nn.Conv2d):
                fan_out = (m.kernel_size[0] * m.kernel_size[1] *
                           m.out_channels)
                fan_out //= m.groups
                normal_init(m, 0, math.sqrt(2.0 / fan_out), bias=0)
    elif isinstance(self.pretrained, str):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.pretrained, logger=logger, map_location='cpu')
        # The weights may sit under ``state_dict`` or ``model``.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        if self.pretrain_style == 'official':
            # Because segformer backbone is not support by mmcls,
            # so we need to convert pretrain weights to match this
            # implementation.
            state_dict = mit_convert(state_dict)

        # Non-strict load: missing / unexpected keys are tolerated.
        self.load_state_dict(state_dict, False)
def init_weights(self):
    """Initialise the weights from ``self.pretrained`` or from scratch.

    * ``self.pretrained`` is a ``str``: the checkpoint is loaded on CPU,
      converted with ``vit_convert`` when ``self.pretrain_style`` is
      ``'timm'``, its ``pos_embed`` is resized if its shape differs from
      ``self.pos_embed``, and the result is applied non-strictly.
    * ``self.pretrained is None``: the parent initialisation runs first,
      then the position embedding, class token, Linear, Conv2d and
      normalisation layers are initialised explicitly.

    Any other value of ``self.pretrained`` leaves the model untouched.
    """
    if isinstance(self.pretrained, str):
        logger = get_root_logger()
        checkpoint = _load_checkpoint(
            self.pretrained, logger=logger, map_location='cpu')
        # The weights may sit under ``state_dict`` or ``model``.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        if self.pretrain_style == 'timm':
            # Because the refactor of vit is blocked by mmcls,
            # so we firstly use timm pretrain weights to train
            # downstream model.
            state_dict = vit_convert(state_dict)

        if 'pos_embed' in state_dict.keys():
            if self.pos_embed.shape != state_dict['pos_embed'].shape:
                logger.info(msg=f'Resize the pos_embed shape from '
                            f'{state_dict["pos_embed"].shape} to '
                            f'{self.pos_embed.shape}')
                h, w = self.img_size
                # One token is subtracted before taking the square root of
                # the pretrained sequence length (presumably the class
                # token -- confirm against ``resize_pos_embed``).
                pos_size = int(
                    math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'],
                    (h // self.patch_size, w // self.patch_size),
                    (pos_size, pos_size), self.interpolate_mode)

        # Non-strict load: missing / unexpected keys are tolerated.
        self.load_state_dict(state_dict, False)

    elif self.pretrained is None:
        super(VisionTransformer, self).init_weights()
        # We only implement the 'jax_impl' initialization implemented at
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
        #
        # NOTE: raw tensors are initialised with ``nn.init.*``; the
        # ``*_init`` helpers are only used on whole modules. They are
        # assumed to be the mmcv weight-init helpers, which look up
        # ``.weight`` / ``.bias`` on their argument and do nothing when
        # handed a bare parameter.
        # ``cls_token`` is assumed to be a bare parameter like
        # ``pos_embed`` -- confirm against ``__init__``.
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    if 'ffn' in n:
                        # Linear layers whose name contains 'ffn' get a
                        # tiny random bias instead of zero.
                        nn.init.normal_(m.bias, mean=0., std=1e-6)
                    else:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                kaiming_init(m, mode='fan_in', bias=0.)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                constant_init(m, 1.0, bias=0)
def init_weights(self):
    """Initialise the backbone from scratch or from ``self.init_cfg``.

    * ``self.init_cfg is None``: a warning is logged, the absolute position
      embedding (when enabled) and every Linear layer are drawn from a
      truncated normal with std 0.02, and LayerNorm is set through
      ``constant_init(m, 1.0)``.
    * otherwise ``self.init_cfg`` must contain ``checkpoint``. The file is
      loaded on CPU and optionally converted with ``swin_converter``. Only
      keys starting with ``backbone.`` are kept, with that prefix (and a
      following ``module.`` prefix, if any) stripped. The absolute position
      embedding is reshaped and the relative position bias tables are
      bicubically resized where sizes differ. The result is applied
      non-strictly.
    """
    logger = get_root_logger()
    if self.init_cfg is None:
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
        logger.warning(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training start from scratch')
        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0)
    else:
        assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                              f'specify `Pretrained` in ' \
                                              f'`init_cfg` in ' \
                                              f'{self.__class__.__name__} '
        ckpt = _load_checkpoint(
            self.init_cfg.checkpoint, logger=logger, map_location='cpu')
        # The weights may sit under ``state_dict`` or ``model``.
        if 'state_dict' in ckpt:
            _state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            _state_dict = ckpt['model']
        else:
            _state_dict = ckpt
        if self.convert_weights:
            # supported loading weight from original repo,
            _state_dict = swin_converter(_state_dict)

        # Keep only the backbone entries and drop the 'backbone.' prefix.
        state_dict = OrderedDict()
        for k, v in _state_dict.items():
            if k.startswith('backbone.'):
                state_dict[k[9:]] = v

        if not state_dict:
            # Without this guard the prefix check below fails with an
            # opaque IndexError on an empty dict.
            logger.warning(
                'No key starting with `backbone.` found in the checkpoint, '
                'no pre-trained weight will be loaded')
        elif next(iter(state_dict)).startswith('module.'):
            # strip prefix of state_dict
            state_dict = {k[7:]: v for k, v in state_dict.items()}

        # reshape absolute position embedding
        if state_dict.get('absolute_pos_embed') is not None:
            absolute_pos_embed = state_dict['absolute_pos_embed']
            N1, L, C1 = absolute_pos_embed.size()
            N2, C2, H, W = self.absolute_pos_embed.size()
            if N1 != N2 or C1 != C2 or L != H * W:
                logger.warning('Error in loading absolute_pos_embed, pass')
            else:
                state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                    N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

        # interpolate position bias table if needed
        relative_position_bias_table_keys = [
            k for k in state_dict.keys()
            if 'relative_position_bias_table' in k
        ]
        # The model is not modified inside the loop, so its state dict is
        # built once instead of once per table.
        current_state = self.state_dict()
        for table_key in relative_position_bias_table_keys:
            table_pretrained = state_dict[table_key]
            table_current = current_state[table_key]
            L1, nH1 = table_pretrained.size()
            L2, nH2 = table_current.size()
            if nH1 != nH2:
                logger.warning(f'Error in loading {table_key}, pass')
            elif L1 != L2:
                # The table length is treated as a square S x S grid and
                # resized to the current grid.
                S1 = int(L1**0.5)
                S2 = int(L2**0.5)
                table_pretrained_resized = F.interpolate(
                    table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                    size=(S2, S2),
                    mode='bicubic')
                state_dict[table_key] = table_pretrained_resized.view(
                    nH2, L2).permute(1, 0).contiguous()

        # load state_dict
        self.load_state_dict(state_dict, False)
def init_weights(self):
    """Draw ``self.pos_embed`` from a truncated normal with std 0.02.

    NOTE(review): ``trunc_normal_init`` is assumed to be the mmcv
    module-level helper, which looks up ``.weight`` / ``.bias`` on its
    argument and does nothing when handed a bare tensor. A tensor-valued
    ``pos_embed`` is therefore initialised in place with
    ``torch.nn.init.trunc_normal_``; any other object keeps the original
    helper call.
    """
    # Function-scope imports keep this method self-contained.
    from torch import Tensor
    from torch.nn.init import trunc_normal_

    pos_embed = self.pos_embed
    if isinstance(pos_embed, Tensor):
        trunc_normal_(pos_embed, std=0.02)
    else:
        trunc_normal_init(pos_embed, std=0.02)
def init_weights(self):
    """Initialise ``self.fc_cls`` from scratch.

    The layer is passed to ``trunc_normal_init`` with ``self.init_std`` as
    the standard deviation.
    """
    cls_layer = self.fc_cls
    trunc_normal_init(cls_layer, std=self.init_std)
def init_weights(self):
    """Draw ``self.relative_position_bias_table`` from a truncated normal
    with std 0.02.

    NOTE(review): ``trunc_normal_init`` is assumed to be the mmcv
    module-level helper, which looks up ``.weight`` / ``.bias`` on its
    argument and does nothing when handed a bare tensor. A tensor-valued
    table is therefore initialised in place with
    ``torch.nn.init.trunc_normal_``; any other object keeps the original
    helper call.
    """
    # Function-scope imports keep this method self-contained.
    from torch import Tensor
    from torch.nn.init import trunc_normal_

    bias_table = self.relative_position_bias_table
    if isinstance(bias_table, Tensor):
        trunc_normal_(bias_table, std=0.02)
    else:
        trunc_normal_init(bias_table, std=0.02)
def init_weights(self):
    """Initialise the weights from scratch or from ``self.pretrained``.

    * ``self.pretrained is None``: the parent initialisation runs first.
      Then the absolute position embedding (when enabled) and every
      ``Linear`` weight are drawn from a truncated normal with std 0.02,
      Linear biases are zeroed, and ``LayerNorm`` gets weight 1 / bias 0.
    * ``self.pretrained`` is a ``str``: the checkpoint is loaded on CPU and
      converted with ``swin_convert`` when ``self.pretrain_style`` is
      ``'official'``. A ``module.`` prefix is stripped, the absolute
      position embedding is reshaped, and the relative position bias tables
      are bicubically resized where sizes differ. The result is applied
      non-strictly.

    Any other value of ``self.pretrained`` leaves the model untouched.
    """
    if self.pretrained is None:
        super().init_weights()
        # NOTE: the ``*_init`` helpers are assumed to be the mmcv
        # weight-init helpers, which look up ``.weight`` / ``.bias`` on
        # their argument and do nothing when handed a bare parameter. The
        # raw embedding tensor is therefore initialised with torch's
        # tensor-level function and modules are passed whole.
        if self.use_abs_pos_embed:
            from torch.nn.init import trunc_normal_
            trunc_normal_(self.absolute_pos_embed, std=0.02)
        for m in self.modules():
            if isinstance(m, Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, LayerNorm):
                constant_init(m, 1.0, bias=0)
    elif isinstance(self.pretrained, str):
        logger = get_root_logger()
        ckpt = _load_checkpoint(
            self.pretrained, logger=logger, map_location='cpu')
        # The weights may sit under ``state_dict`` or ``model``.
        if 'state_dict' in ckpt:
            state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            state_dict = ckpt['model']
        else:
            state_dict = ckpt

        if self.pretrain_style == 'official':
            state_dict = swin_convert(state_dict)

        # strip prefix of state_dict
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k[7:]: v for k, v in state_dict.items()}

        # reshape absolute position embedding
        if state_dict.get('absolute_pos_embed') is not None:
            absolute_pos_embed = state_dict['absolute_pos_embed']
            N1, L, C1 = absolute_pos_embed.size()
            N2, C2, H, W = self.absolute_pos_embed.size()
            if N1 != N2 or C1 != C2 or L != H * W:
                logger.warning('Error in loading absolute_pos_embed, pass')
            else:
                state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                    N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

        # interpolate position bias table if needed
        relative_position_bias_table_keys = [
            k for k in state_dict.keys()
            if 'relative_position_bias_table' in k
        ]
        # The model is not modified inside the loop, so its state dict is
        # built once instead of once per table.
        current_state = self.state_dict()
        for table_key in relative_position_bias_table_keys:
            table_pretrained = state_dict[table_key]
            table_current = current_state[table_key]
            L1, nH1 = table_pretrained.size()
            L2, nH2 = table_current.size()
            if nH1 != nH2:
                logger.warning(f'Error in loading {table_key}, pass')
            else:
                if L1 != L2:
                    # The table length is treated as a square S x S grid
                    # and resized to the current grid.
                    S1 = int(L1**0.5)
                    S2 = int(L2**0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).reshape(
                            1, nH1, S1, S1),
                        size=(S2, S2),
                        mode='bicubic')
                    state_dict[table_key] = table_pretrained_resized.view(
                        nH2, L2).permute(1, 0).contiguous()

        # load state_dict
        self.load_state_dict(state_dict, False)