Exemplo n.º 1
def test_trunc_normal_init():
    """Check that ``trunc_normal_init`` fills a conv layer as requested.

    A random mean, standard deviation and truncation interval are drawn,
    the layer is initialized, and the weights are compared against the
    matching truncated normal distribution with a Kolmogorov-Smirnov test.
    The bias is expected to equal the constant passed as ``bias``.
    """

    def _uniform(low, high):
        # Uniform sample between ``low`` and ``high``.
        span = high - low
        return span * random.random() + low

    def _passes_ks_test(weights, mean, std, a, b):
        # scipy's 'truncnorm' describes a truncated standard normal, so both
        # the samples and the bounds are standardized before testing.
        lower = (a - mean) / std
        upper = (b - mean) / std
        standardized = ((weights.view(-1) - mean) / std).tolist()
        ks_result = stats.kstest(
            standardized, 'truncnorm', args=(lower, upper))
        p_value = ks_result[1]
        return p_value > 0.0001

    conv_module = nn.Conv2d(3, 16, 3)
    mean = _uniform(-3, 3)
    std = _uniform(.01, 1)
    # Bounds fall on either side of the mean, within two standard deviations.
    a = _uniform(mean - 2 * std, mean)
    b = _uniform(mean, mean + 2 * std)
    trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
    assert _passes_ks_test(conv_module.weight, mean, std, a, b)
    expected_bias = torch.full_like(conv_module.bias, 0.1)
    assert conv_module.bias.allclose(expected_bias)

    # NOTE(review): not used in the visible part of the test; the snippet
    # appears to stop here.
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
Exemplo n.º 2
    def init_weights(self):
        """Initialize the model from scratch or from a checkpoint.

        When ``self.pretrained`` is ``None`` every ``nn.Linear``,
        ``nn.LayerNorm`` and ``nn.Conv2d`` submodule is re-initialized.
        When it is a string it is used as the checkpoint location and the
        weights found there are loaded non-strictly. Any other value leaves
        the model untouched.
        """
        if self.pretrained is None:
            for m in self.modules():
                # The init helpers are handed the module itself (they set
                # weight and bias together), matching how they are called in
                # the other examples of this file.
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            # NOTE(review): the original call was left unterminated; the
            # keyword arguments are restored on the assumption that the
            # helper accepts ``logger`` and ``map_location`` -- confirm.
            checkpoint = _load_checkpoint(
                self.pretrained, logger=logger, map_location='cpu')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                # The checkpoint is treated as a bare state dict.
                state_dict = checkpoint

            self.load_state_dict(state_dict, False)
Exemplo n.º 3
    def init_weights(self, pretrained=None):
        """Initialize the model from a checkpoint or from scratch.

        Args:
            pretrained (str | None): Checkpoint location to load weights
                from. With ``None`` the ``nn.Linear``, ``nn.LayerNorm`` and
                ``nn.Conv2d`` submodules are re-initialized instead.

        Raises:
            TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()

            # NOTE(review): the original call was left unterminated; the
            # keyword arguments are restored on the assumption that the
            # helper accepts ``logger`` and ``map_location`` -- confirm.
            checkpoint = _load_checkpoint(
                pretrained, logger=logger, map_location='cpu')
            logger.warning(f'Load pre-trained model for '
                           f'{self.__class__.__name__} from original repo')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                # The checkpoint is treated as a bare state dict.
                state_dict = checkpoint

            if self.convert_weights:
                # We need to convert pre-trained weights to match this
                # implementation.
                state_dict = tcformer_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)

        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
        else:
            # Only reached for unsupported argument types, never while
            # iterating over submodules.
            raise TypeError('pretrained must be a str or None')
Exemplo n.º 4
 def init_weights(self):
     """Re-initialize every linear, layer-norm and 2D-conv submodule.

     Linear layers get a truncated normal with std 0.02 and zero bias,
     layer norms a constant of 1.0, and convolutions a zero-mean normal
     whose std is derived from the per-group fan-out.
     """
     for layer in self.modules():
         if isinstance(layer, nn.Linear):
             trunc_normal_init(layer, std=.02, bias=0.)
             continue
         if isinstance(layer, nn.LayerNorm):
             constant_init(layer, 1.0)
             continue
         if not isinstance(layer, nn.Conv2d):
             continue
         kernel_h, kernel_w = layer.kernel_size[0], layer.kernel_size[1]
         # Fan-out: kernel area times output channels, split across groups.
         fan_out = (kernel_h * kernel_w * layer.out_channels) // layer.groups
         normal_init(layer, 0, math.sqrt(2.0 / fan_out))
Exemplo n.º 5
 def init_weights(self):
     """Re-initialize the linear, layer-norm and 2D-conv submodules.

     NOTE(review): the snippet is cut off after the final
     ``if m.bias is not None:``; the statement that handles the conv bias
     is not visible and has to be restored from the upstream source.
     """
     for m in self.modules():
         if isinstance(m, nn.Linear):
             # Truncated normal (std 0.02) for the layer, bias set to 0.
             trunc_normal_init(m, std=.02, bias=0.)
         elif isinstance(m, nn.LayerNorm):
             # Bias to 0 and weight to 1.
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
             # Fan-out: kernel area times output channels, divided by the
             # number of groups.
             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
             fan_out //= m.groups
             # Zero-mean normal with std sqrt(2 / fan_out), applied in place.
             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
             if m.bias is not None:
Exemplo n.º 6
 def init_weights(self):
     """Initialize weights from scratch or load them from a checkpoint.

     When ``self.init_cfg`` is ``None`` the submodules are re-initialized;
     the remaining code loads the checkpoint named by
     ``self.init_cfg.checkpoint`` and applies it non-strictly.

     NOTE(review): this snippet looks incomplete. The ``elif`` on
     ``AbsolutePositionEmbedding`` has no body, and the ``else:`` lines that
     presumably separate the two paths (and guard the final
     ``state_dict = checkpoint`` fallback) are missing. Restore them from
     the upstream source before relying on this code.
     """
     logger = get_root_logger()
     if self.init_cfg is None:
         # NOTE(review): if this is a standard ``logging.Logger``, ``warn``
         # is a deprecated alias of ``warning``.
         logger.warn(f'No pre-trained weights for '
                     f'{self.__class__.__name__}, '
                     f'training start from scratch')
         for m in self.modules():
             # NOTE(review): the init helpers receive bare tensors here
             # (``m.weight`` / ``m.bias``); presumably they expect the
             # module itself -- verify against the helper signatures.
             if isinstance(m, nn.Linear):
                 trunc_normal_init(m.weight, std=.02)
                 if m.bias is not None:
                     constant_init(m.bias, 0)
             elif isinstance(m, nn.LayerNorm):
                 constant_init(m.bias, 0)
                 constant_init(m.weight, 1.0)
             elif isinstance(m, nn.Conv2d):
                 # Fan-out: kernel area times output channels, divided by
                 # the number of groups.
                 fan_out = m.kernel_size[0] * m.kernel_size[
                     1] * m.out_channels
                 fan_out //= m.groups
                 normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
                 if m.bias is not None:
                     constant_init(m.bias, 0)
             # NOTE(review): the body of this branch is missing.
             elif isinstance(m, AbsolutePositionEmbedding):
         # Checkpoint path: ``init_cfg`` must name a checkpoint.
         assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                               f'specify `Pretrained` in ' \
                                               f'`init_cfg` in ' \
                                               f'{self.__class__.__name__} '
         checkpoint = _load_checkpoint(
             self.init_cfg.checkpoint, logger=logger, map_location='cpu')
         logger.warn(f'Load pre-trained model for '
                     f'{self.__class__.__name__} from original repo')
         # Pick the weights out of the checkpoint container.
         if 'state_dict' in checkpoint:
             state_dict = checkpoint['state_dict']
         elif 'model' in checkpoint:
             state_dict = checkpoint['model']
             # NOTE(review): as written this line always overrides the
             # 'model' entry; an ``else:`` is presumably missing above it.
             state_dict = checkpoint
         if self.convert_weights:
             # Because pvt backbones are not supported by mmcls,
             # so we need to convert pre-trained weights to match this
             # implementation.
             state_dict = pvt_convert(state_dict)
         load_state_dict(self, state_dict, strict=False, logger=logger)
Exemplo n.º 7
    def init_weights(self):
        """Initialize the weights in backbone.

        With ``self.init_cfg`` set to ``None`` the linear and layer-norm
        submodules are re-initialized; the remaining code loads the
        checkpoint named by ``self.init_cfg.checkpoint``, drops the keys
        listed in ``reinit_keys`` and applies the rest non-strictly.

        NOTE(review): this snippet looks incomplete. The loop over
        ``self.layers`` has no body, the ``_load_checkpoint(`` call is never
        closed, and the ``else:`` lines that presumably separate the two
        paths are missing. Restore them from the upstream source.
        """
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training start from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    # Truncated normal (std 0.02) for the layer, bias 0.
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
            # NOTE(review): the loop body is missing.
            for bly in self.layers:
            # NOTE(review): the indentation suggests a missing enclosing
            # branch; it is unclear from here which path this belongs to.
            if self.ape:
                raise NotImplementedError
            # Checkpoint path: ``init_cfg`` must name a checkpoint.
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            # NOTE(review): the call below is left unterminated.
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
                # NOTE(review): as written this always overrides the 'model'
                # entry; an ``else:`` is presumably missing above it.
                state_dict = ckpt

            # delete keys for reinitialization
            reinit_keys = ('relative_position_index', 'relative_coords_table')
            for reinit_key in reinit_keys:
                # Iterate over a copy of the keys because entries are
                # deleted during the loop.
                for k in list(state_dict.keys()):
                    if reinit_key in k:
                        del state_dict[k]

            load_state_dict(self, state_dict, strict=False, logger=logger)
Exemplo n.º 8
    def init_weights(self):
        """Initialize the model from scratch or from a checkpoint.

        When ``self.pretrained`` is ``None`` every ``nn.Linear``,
        ``nn.LayerNorm`` and ``nn.Conv2d`` submodule is re-initialized.
        When it is a string it is used as the checkpoint location; the
        weights are optionally converted (``self.pretrain_style ==
        'official'``) and then loaded non-strictly. Any other value leaves
        the model untouched.
        """
        if self.pretrained is None:
            for m in self.modules():
                # The init helpers are handed the module itself (they set
                # weight and bias together), matching how they are called in
                # the other examples of this file.
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            # NOTE(review): the original call was left unterminated; the
            # keyword arguments are restored on the assumption that the
            # helper accepts ``logger`` and ``map_location`` -- confirm.
            checkpoint = _load_checkpoint(
                self.pretrained, logger=logger, map_location='cpu')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                # The checkpoint is treated as a bare state dict.
                state_dict = checkpoint

            if self.pretrain_style == 'official':
                # Because segformer backbone is not support by mmcls,
                # so we need to convert pretrain weights to match this
                # implementation.
                state_dict = mit_convert(state_dict)

            self.load_state_dict(state_dict, False)
Exemplo n.º 9
    def init_weights(self):
        """Load pretrained weights or initialize the model from scratch.

        A string ``self.pretrained`` is used as the checkpoint location: the
        weights are optionally converted (``self.pretrain_style == 'timm'``),
        a mismatching ``pos_embed`` is resized, and the result is loaded
        non-strictly. With ``self.pretrained`` set to ``None`` the parameters
        are initialized from scratch.

        NOTE(review): this snippet looks incomplete. The
        ``_load_checkpoint(`` and ``logger.info(`` calls are never closed and
        several ``else:`` lines appear to be missing. Restore them from the
        upstream source before relying on this code.
        """
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            # NOTE(review): the call below is left unterminated.
            checkpoint = _load_checkpoint(self.pretrained,
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
                # NOTE(review): as written this always overrides the 'model'
                # entry; an ``else:`` is presumably missing above it.
                state_dict = checkpoint

            if self.pretrain_style == 'timm':
                # Because the refactor of vit is blocked by mmcls,
                # so we firstly use timm pretrain weights to train
                # downstream model.
                state_dict = vit_convert(state_dict)

            if 'pos_embed' in state_dict.keys():
                if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    # NOTE(review): the end of this log call is missing.
                    logger.info(msg=f'Resize the pos_embed shape from '
                                f'{state_dict["pos_embed"].shape} to '
                    h, w = self.img_size
                    # Side length of the pretrained grid, computed from the
                    # second dimension of pos_embed minus one.
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    # NOTE(review): the pretrained ``pos_embed`` tensor is
                    # not passed here; presumably an argument line was lost
                    # -- check against ``resize_pos_embed``'s signature.
                    state_dict['pos_embed'] = self.resize_pos_embed(
                        (h // self.patch_size, w // self.patch_size),
                        (pos_size, pos_size), self.interpolate_mode)

            self.load_state_dict(state_dict, False)

        elif self.pretrained is None:
            super(VisionTransformer, self).init_weights()
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            # NOTE(review): throughout this branch the init helpers receive
            # bare tensors (``self.pos_embed``, ``m.weight``, ``m.bias``);
            # presumably they expect a module -- verify against the helper
            # signatures.
            trunc_normal_init(self.pos_embed, std=.02)
            trunc_normal_init(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            normal_init(m.bias, std=1e-6)
                            # NOTE(review): as written both initializers run
                            # for 'ffn' layers and none for the others; an
                            # ``else:`` is presumably missing above this line.
                            constant_init(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m.weight, mode='fan_in')
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
Exemplo n.º 10
    def init_weights(self):
        """Initialize weights from scratch or load them from a checkpoint.

        With ``self.init_cfg`` set to ``None`` the absolute position
        embedding (if enabled) and the linear / layer-norm submodules are
        re-initialized. The remaining code loads the checkpoint named by
        ``self.init_cfg.checkpoint``, adapts its keys and position tables to
        this model and applies it non-strictly.

        NOTE(review): this snippet looks incomplete. The
        ``_load_checkpoint(`` and ``F.interpolate(`` calls and the list
        comprehension are never closed, and several ``else:`` lines appear
        to be missing. Restore them from the upstream source.
        """
        logger = get_root_logger()
        if self.init_cfg is None:
            # NOTE(review): if this is a standard ``logging.Logger``,
            # ``warn`` is a deprecated alias of ``warning``.
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
            # Checkpoint path: ``init_cfg`` must name a checkpoint.
            # NOTE(review): an ``else:`` presumably belongs before this
            # assert; as written it also runs when ``init_cfg`` is None.
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            # NOTE(review): the call below is left unterminated.
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
                # NOTE(review): as written this always overrides the 'model'
                # entry; an ``else:`` is presumably missing above it.
                _state_dict = ckpt
            if self.convert_weights:
                # supported loading weight from original repo,
                _state_dict = swin_converter(_state_dict)

            # Keep only the keys under 'backbone.' and drop that 9-character
            # prefix; all other keys are discarded.
            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v

            # strip prefix of state_dict
            # NOTE(review): indexing ``[0]`` fails if no key started with
            # 'backbone.' and ``state_dict`` is therefore empty.
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                # The checkpoint tensor unpacks into three sizes, the model's
                # own embedding into four.
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                    # NOTE(review): as written the reshape runs only on a
                    # size mismatch; an ``else:`` is presumably missing.
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            # NOTE(review): the closing bracket of this list comprehension
            # is missing.
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                elif L1 != L2:
                    # Each table is viewed as a square grid of side
                    # sqrt(L) before resizing.
                    S1 = int(L1**0.5)
                    S2 = int(L2**0.5)
                    # NOTE(review): the remaining arguments and the closing
                    # parenthesis of this call are missing.
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                        size=(S2, S2),
                    state_dict[table_key] = table_pretrained_resized.view(
                        nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)
Exemplo n.º 11
 def init_weights(self):
     """Apply truncated-normal initialization (std 0.02) to ``pos_embed``.

     NOTE(review): presumably ``trunc_normal_init`` expects a module; if
     ``pos_embed`` is a bare parameter, confirm the helper handles it.
     """
     pos_embed = self.pos_embed
     trunc_normal_init(pos_embed, std=0.02)
Exemplo n.º 12
 def init_weights(self):
     """Initialize ``fc_cls`` from scratch.

     Uses a truncated normal whose standard deviation is ``self.init_std``.
     """
     classifier, std = self.fc_cls, self.init_std
     trunc_normal_init(classifier, std=std)
Exemplo n.º 13
 def init_weights(self):
     """Apply truncated-normal initialization (std 0.02) to the bias table.

     NOTE(review): presumably ``trunc_normal_init`` expects a module; if
     ``relative_position_bias_table`` is a bare parameter, confirm the
     helper handles it.
     """
     bias_table = self.relative_position_bias_table
     trunc_normal_init(bias_table, std=0.02)
Exemplo n.º 14
    def init_weights(self):
        """Initialize the model from scratch or from a checkpoint.

        With ``self.pretrained`` set to ``None`` the absolute position
        embedding (if enabled) and the ``Linear`` / ``LayerNorm`` submodules
        are re-initialized. A string ``self.pretrained`` is used as the
        checkpoint location; its keys and position tables are adapted to
        this model before a non-strict load.

        NOTE(review): this snippet looks incomplete. The
        ``_load_checkpoint(`` and ``F.interpolate(`` calls and the list
        comprehension are never closed, and several ``else:`` lines appear
        to be missing. Restore them from the upstream source.
        """
        if self.pretrained is None:
            # NOTE(review): throughout this branch the init helpers receive
            # bare tensors (``absolute_pos_embed``, ``m.weight``,
            # ``m.bias``); presumably they expect a module -- verify against
            # the helper signatures.
            if self.use_abs_pos_embed:
                trunc_normal_init(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, LayerNorm):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            # NOTE(review): the call below is left unterminated.
            ckpt = _load_checkpoint(self.pretrained,
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
                # NOTE(review): as written this always overrides the 'model'
                # entry; an ``else:`` is presumably missing above it.
                state_dict = ckpt

            if self.pretrain_style == 'official':
                state_dict = swin_convert(state_dict)

            # strip prefix of state_dict
            # Only the first key is inspected to decide whether the
            # 7-character 'module.' prefix is cut from every key.
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                # The checkpoint tensor unpacks into three sizes, the model's
                # own embedding into four.
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                    # NOTE(review): as written the reshape runs only on a
                    # size mismatch; an ``else:`` is presumably missing.
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            # NOTE(review): the closing bracket of this list comprehension
            # is missing.
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                    # NOTE(review): as written the resize below runs only
                    # after the head-count mismatch warning; an ``else:`` is
                    # presumably missing above the next line.
                    if L1 != L2:
                        # Each table is viewed as a square grid of side
                        # sqrt(L) before resizing.
                        S1 = int(L1**0.5)
                        S2 = int(L2**0.5)
                        # NOTE(review): the remaining arguments and the
                        # closing parenthesis of this call are missing.
                        table_pretrained_resized = F.interpolate(
                            table_pretrained.permute(1, 0).reshape(
                                1, nH1, S1, S1),
                            size=(S2, S2),
                        state_dict[table_key] = table_pretrained_resized.view(
                            nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)