Example #1
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
            revise_keys = [(r'^module\.', '')]
            checkpoint = _load_checkpoint(pretrained,
                                          map_location='cpu',
                                          logger=logger)
            # OrderedDict is a subclass of dict
            if not isinstance(checkpoint, dict):
                raise RuntimeError(
                    f'No state_dict found in checkpoint file {pretrained}')
            # get state_dict from checkpoint
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:  # for our model
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint
            # strip prefix of state_dict
            for p, r in revise_keys:
                state_dict = {
                    re.sub(p, r, k): v
                    for k, v in state_dict.items()
                }
            # load state_dict
            load_state_dict(self, state_dict, strict=False, logger=logger)
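The `revise_keys` pattern above strips the `module.` prefix that `torch.nn.DataParallel` prepends to parameter names when a wrapped model is saved. A minimal standalone sketch of the same idea (the key names below are made up for illustration):

import re

# Hypothetical keys as they might appear in a DataParallel checkpoint.
state_dict = {
    'module.backbone.conv1.weight': 'w0',
    'module.head.fc.bias': 'b0',
}

revise_keys = [(r'^module\.', '')]
for pattern, replacement in revise_keys:
    # Rewrite every key; values are carried over untouched.
    state_dict = {re.sub(pattern, replacement, k): v
                  for k, v in state_dict.items()}

print(list(state_dict))  # ['backbone.conv1.weight', 'head.fc.bias']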
Example #2
    def init_weights(self):
        if self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
                    if m.bias is not None:
                        constant_init(m.bias, 0)
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(self.pretrained,
                                          logger=logger,
                                          map_location='cpu')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            self.load_state_dict(state_dict, False)
Example #3
    def _load_torchvision_checkpoint(self, logger=None):
        """Initiate the parameters from torchvision pretrained checkpoint."""
        state_dict_torchvision = _load_checkpoint(self.pretrained)
        if 'state_dict' in state_dict_torchvision:
            state_dict_torchvision = state_dict_torchvision['state_dict']

        loaded_param_names = []
        for name, module in self.named_modules():
            if isinstance(module, ConvModule):
                # we use a ConvModule to wrap conv+bn+relu layers, thus the
                # name mapping is needed
                if 'downsample' in name:
                    # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0
                    original_conv_name = name + '.0'
                    # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1
                    original_bn_name = name + '.1'
                else:
                    # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n}
                    original_conv_name = name
                    # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n}
                    original_bn_name = name.replace('conv', 'bn')
                self._load_conv_params(module.conv, state_dict_torchvision,
                                       original_conv_name, loaded_param_names)
                self._load_bn_params(module.bn, state_dict_torchvision,
                                     original_bn_name, loaded_param_names)

        # check if any parameters in the 2d checkpoint are not loaded
        remaining_names = set(
            state_dict_torchvision.keys()) - set(loaded_param_names)
        if remaining_names:
            logger.info(
                f'These parameters in the pretrained checkpoint are not '
                f'loaded: {remaining_names}')
Example #4
    def init_weights(self, pretrained=None):
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
        else:
            if self.init_cfg is not None:
                assert 'checkpoint' in self.init_cfg, \
                    f'Only support specifying `Pretrained` in `init_cfg` of ' \
                    f'{self.__class__.__name__}'
                ckpt_path = self.init_cfg['checkpoint']
            else:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(ckpt_path,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)
Example #5
    def init_weights(self):
        """Initialize the weights in backbone."""
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            self.apply(self._init_weights)
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` of ' \
                f'{self.__class__.__name__}'
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
            else:
                state_dict = ckpt

            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)
            logger.warning(f'missing_keys: {missing_keys}')
            logger.warning(f'unexpected_keys: {unexpected_keys}')
Example #6
    def init_weights(self, pretrained=None):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        if pretrained:
            self.pretrained = pretrained
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {self.pretrained}')

            state_dict = _load_checkpoint(self.pretrained)
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

            if self.attention_type == 'divided_space_time':
                # modify the key names of norm layers
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'norms' in old_key:
                        new_key = old_key.replace('norms.0',
                                                  'attentions.0.norm')
                        new_key = new_key.replace('norms.1', 'ffns.0.norm')
                        state_dict[new_key] = state_dict.pop(old_key)

                # copy the parameters of space attention to time attention
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'attentions.0' in old_key:
                        new_key = old_key.replace('attentions.0',
                                                  'attentions.1')
                        state_dict[new_key] = state_dict[old_key].clone()

            load_state_dict(self, state_dict, strict=False, logger=logger)
Example #7
    def _load_file(self, filename):
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # file is in Detectron2 model zoo format
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # assume file is from Caffe2 / Detectron1 model zoo
                if "blobs" in data:
                    # Detection models have "blobs", but ImageNet models don't
                    data = data["blobs"]
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                if "weight_order" in data:
                    del data["weight_order"]
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}

        if filename.startswith("torchvision://") or filename.startswith(("http://", "https://")):
            loaded = _load_checkpoint(filename)  # load torchvision pretrained model using mmcv
        else:
            loaded = super()._load_file(filename)  # load native pth checkpoint
        if "model" not in loaded:
            loaded = {"model": loaded}

        basename = os.path.basename(filename).lower()
        if "lpf" in basename or "dla" in basename:
            loaded["matching_heuristics"] = True
        return loaded
Example #8
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()

            checkpoint = _load_checkpoint(pretrained,
                                          logger=logger,
                                          map_location='cpu')
            logger.warning(f'Load pre-trained model for '
                           f'{self.__class__.__name__} from original repo')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint

            if self.convert_weights:
                # We need to convert pre-trained weights to match this
                # implementation.
                state_dict = tcformer_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)

        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
        else:
            raise TypeError('pretrained must be a str or None')
Example #9
def _check_backbone(config, print_cfg=True):
    """Check out backbone whether successfully load pretrained model, by using
    `backbone.init_cfg`.

    First, using `mmcv._load_checkpoint` to load the checkpoint without
        loading models.
    Then, using `build_detector` to build models, and using
        `model.init_weights()` to initialize the parameters.
    Finally, assert weights and bias of each layer loaded from pretrained
        checkpoint are equal to the weights and bias of original checkpoint.
        For the convenience of comparison, we sum up weights and bias of
        each loaded layer separately.

    Args:
        config (str): Config file path.
        print_cfg (bool): Whether print logger and return the result.

    Returns:
        results (str or None): If backbone successfully load pretrained
            checkpoint, return None; else, return config file path.
    """
    if print_cfg:
        print('-' * 15 + 'loading ', config)
    cfg = Config.fromfile(config)
    init_cfg = None
    try:
        init_cfg = cfg.model.backbone.init_cfg
        init_flag = True
    except AttributeError:
        init_flag = False
    if init_cfg is None or init_cfg.get('type') != 'Pretrained':
        init_flag = False
    if init_flag:
        checkpoint = _load_checkpoint(init_cfg.checkpoint)
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        model = build_detector(cfg.model,
                               train_cfg=cfg.get('train_cfg'),
                               test_cfg=cfg.get('test_cfg'))
        model.init_weights()

        checkpoint_layers = state_dict.keys()
        for name, value in model.backbone.state_dict().items():
            if name in checkpoint_layers:
                assert value.equal(state_dict[name])

        if print_cfg:
            print('-' * 10 + 'Successfully loaded checkpoint' + '-' * 10 +
                  '\n')
            return None
    else:
        if print_cfg:
            print(config + '\n' + '-' * 10 +
                  'config file does not have init_cfg' + '-' * 10 + '\n')
            return config
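A typical invocation, assuming an illustrative config path whose backbone carries a `Pretrained` `init_cfg`:

# Returns None on success; with the default print_cfg=True it returns the
# config path when no `Pretrained` init_cfg is found.
failed = _check_backbone('configs/retinanet/retinanet_r50_fpn_1x_coco.py')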
Example #10
    def inflate_weights(self, logger):
        """Inflate the resnet2d parameters to resnet3d pathway.

        The differences between resnet3d and resnet2d mainly lie in an extra
        axis of the conv kernel. To utilize the pretrained parameters of the
        2d model, the weights of the conv2d model should be inflated to fit
        the shapes of the 3d counterpart. For the pathway, the
        ``lateral_connection`` part should not be inflated from 2d weights.

        Args:
            logger (logging.Logger): The logger used to print
                debugging information.
        """

        state_dict_r2d = _load_checkpoint(self.pretrained)
        if 'state_dict' in state_dict_r2d:
            state_dict_r2d = state_dict_r2d['state_dict']

        inflated_param_names = []
        for name, module in self.named_modules():
            if 'lateral' in name:
                continue
            if isinstance(module, ConvModule):
                # we use a ConvModule to wrap conv+bn+relu layers, thus the
                # name mapping is needed
                if 'downsample' in name:
                    # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0
                    original_conv_name = name + '.0'
                    # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1
                    original_bn_name = name + '.1'
                else:
                    # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n}
                    original_conv_name = name
                    # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n}
                    original_bn_name = name.replace('conv', 'bn')
                if original_conv_name + '.weight' not in state_dict_r2d:
                    logger.warning(f'Module does not exist in state_dict_r2d'
                                   f': {original_conv_name}')
                else:
                    self._inflate_conv_params(module.conv, state_dict_r2d,
                                              original_conv_name,
                                              inflated_param_names)
                if original_bn_name + '.weight' not in state_dict_r2d:
                    logger.warning(f'Module does not exist in state_dict_r2d'
                                   f': {original_bn_name}')
                else:
                    self._inflate_bn_params(module.bn, state_dict_r2d,
                                            original_bn_name,
                                            inflated_param_names)

        # check if any parameters in the 2d checkpoint are not loaded
        remaining_names = set(
            state_dict_r2d.keys()) - set(inflated_param_names)
        if remaining_names:
            logger.info(f'These parameters in the 2d checkpoint are not loaded'
                        f': {remaining_names}')
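`_inflate_conv_params` itself is not shown here, but the heart of the 2d-to-3d inflation (also asserted in Example #19) is repeating the 2d kernel along the new temporal axis and dividing by the temporal kernel size, so a temporally constant input yields the same response as the 2d convolution. A minimal sketch with a plain Conv2d/Conv3d pair:

import torch.nn as nn

conv2d = nn.Conv2d(3, 64, kernel_size=7, bias=False)
conv3d = nn.Conv3d(3, 64, kernel_size=(5, 7, 7), bias=False)

# (64, 3, 7, 7) -> (64, 3, 1, 7, 7) -> broadcast to (64, 3, 5, 7, 7), rescaled.
t = conv3d.weight.shape[2]
inflated = conv2d.weight.data.unsqueeze(2).expand_as(conv3d.weight) / t
conv3d.weight.data.copy_(inflated)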
Example #11
    def init_weights(self):
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(self.pretrained,
                                          logger=logger,
                                          map_location='cpu')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint

            if self.pretrain_style == 'timm':
                # The ViT refactor is blocked by mmcls, so we first use timm
                # pre-trained weights to train the downstream model.
                state_dict = vit_convert(state_dict)

            if 'pos_embed' in state_dict.keys():
                if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    logger.info(msg=f'Resize the pos_embed shape from '
                                f'{state_dict["pos_embed"].shape} to '
                                f'{self.pos_embed.shape}')
                    h, w = self.img_size
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
                        state_dict['pos_embed'],
                        (h // self.patch_size, w // self.patch_size),
                        (pos_size, pos_size), self.interpolate_mode)

            self.load_state_dict(state_dict, False)

        elif self.pretrained is None:
            super(VisionTransformer, self).init_weights()
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            trunc_normal_init(self.pos_embed, std=.02)
            trunc_normal_init(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            normal_init(m.bias, std=1e-6)
                        else:
                            constant_init(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m.weight, mode='fan_in')
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
Example #12
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(pretrained, logger=logger)
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint

            if 'pos_embed' in state_dict.keys():
                if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    logger.info(msg=f'Resize the pos_embed shape from '
                                f'{state_dict["pos_embed"].shape} to '
                                f'{self.pos_embed.shape}')
                    h, w = self.img_size
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
                        state_dict['pos_embed'], (h, w), (pos_size, pos_size),
                        self.patch_size, self.interpolate_mode)

            self.load_state_dict(state_dict, False)

        elif pretrained is None:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            trunc_normal_(self.pos_embed, std=.02)
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'mlp' in n:
                            normal_init(m.bias, std=1e-6)
                        else:
                            constant_init(m.bias, 0)
                elif isinstance(m, Conv2d):
                    kaiming_init(m.weight, mode='fan_in')
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
        else:
            raise TypeError('pretrained must be a str or None')
Example #13
    def init_weights(self):
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            if 'pos_embed' in state_dict.keys():
                if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    logger.info(msg=f'Resize the pos_embed shape from '
                                f'{state_dict["pos_embed"].shape} to '
                                f'{self.pos_embed.shape}')
                    h, w = self.img_size
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
                        state_dict['pos_embed'],
                        (h // self.patch_size, w // self.patch_size),
                        (pos_size, pos_size), self.interpolate_mode)

            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super(VisionTransformer, self).init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            trunc_normal_(self.pos_embed, std=.02)
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
Example #14
    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, AbsolutePositionEmbedding):
                    m.init_weights()
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` of ' \
                f'{self.__class__.__name__}'
            checkpoint = _load_checkpoint(
                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
            logger.warning(f'Load pre-trained model for '
                           f'{self.__class__.__name__} from original repo')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint
            if self.convert_weights:
                # pvt backbones are not supported by mmcls, so we need to
                # convert the pre-trained weights to match this
                # implementation.
                state_dict = pvt_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)
Example #15
    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
        self.fix_init_weight()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
            state_dict = self.resize_rel_pos_embed(checkpoint)
            state_dict = self.resize_abs_pos_embed(state_dict)
            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super(MAE, self).init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            # Copyright 2019 Ross Wightman
            # Licensed under the Apache License, Version 2.0 (the "License")
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
Example #16
    def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
        from mmcv.runner import (_load_checkpoint,
                                 _load_checkpoint_with_prefix, load_state_dict)
        from mmcv.utils import print_log

        logger = get_root_logger()

        if prefix is None:
            print_log(f'load model from: {checkpoint}', logger=logger)
            checkpoint = _load_checkpoint(checkpoint, map_location, logger)
            # get state_dict from checkpoint
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint
        else:
            print_log(
                f'load {prefix} in model from: {checkpoint}', logger=logger)
            state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
                                                      map_location)

        if 'pos_embed' in state_dict.keys():
            ckpt_pos_embed_shape = state_dict['pos_embed'].shape
            if self.pos_embed.shape != ckpt_pos_embed_shape:
                print_log(
                    f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                    f'to {self.pos_embed.shape}.',
                    logger=logger)

                ckpt_pos_embed_shape = to_2tuple(
                    int(np.sqrt(ckpt_pos_embed_shape[1] - 1)))
                pos_embed_shape = self.patch_embed.patches_resolution

                state_dict['pos_embed'] = self.resize_pos_embed(
                    state_dict['pos_embed'], ckpt_pos_embed_shape,
                    pos_embed_shape, self.interpolate_mode)

        # load state_dict
        load_state_dict(self, state_dict, strict=False, logger=logger)
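`resize_pos_embed` is not shown in these snippets; the usual approach, assumed in this sketch, is to split off the class token, reshape the remaining tokens to the source patch grid, interpolate bicubically to the target grid, and flatten back:

import torch
import torch.nn.functional as F


def resize_pos_embed_sketch(pos_embed, src_hw, dst_hw, mode='bicubic'):
    """pos_embed: (1, 1 + src_h * src_w, C), with a leading cls token."""
    src_h, src_w = src_hw
    dst_h, dst_w = dst_hw
    cls_token, grid = pos_embed[:, :1], pos_embed[:, 1:]
    grid = grid.reshape(1, src_h, src_w, -1).permute(0, 3, 1, 2)  # (1, C, H, W)
    grid = F.interpolate(grid, size=(dst_h, dst_w), mode=mode,
                         align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, dst_h * dst_w, -1)
    return torch.cat([cls_token, grid], dim=1)


# e.g. a 14x14 pre-trained grid resized for a 32x32 target patch grid
pe = torch.randn(1, 1 + 14 * 14, 768)
print(resize_pos_embed_sketch(pe, (14, 14), (32, 32)).shape)  # (1, 1025, 768)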
Example #17
    def init_weights(self):
        """Initialize the weights in backbone."""
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
            for bly in self.layers:
                bly._init_respostnorm()
        else:
            if self.ape:
                raise NotImplementedError
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` of ' \
                f'{self.__class__.__name__}'
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
            else:
                state_dict = ckpt

            # delete keys for reinitialization
            reinit_keys = ('relative_position_index', 'relative_coords_table')
            for reinit_key in reinit_keys:
                for k in list(state_dict.keys()):
                    if reinit_key in k:
                        del state_dict[k]

            load_state_dict(self, state_dict, strict=False, logger=logger)
Example #18
    def init_weights(self):
        if self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
                    if m.bias is not None:
                        constant_init(m.bias, 0)
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(self.pretrained,
                                          logger=logger,
                                          map_location='cpu')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint

            if self.pretrain_style == 'official':
                # The segformer backbone is not supported by mmcls, so we
                # need to convert the pretrained weights to match this
                # implementation.
                state_dict = mit_convert(state_dict)

            self.load_state_dict(state_dict, False)
Example #19
def test_resnet3d_backbone():
    """Test resnet3d backbone."""
    with pytest.raises(AssertionError):
        # In ResNet3d: 1 <= num_stages <= 4
        ResNet3d(34, None, num_stages=0)

    with pytest.raises(AssertionError):
        # In ResNet3d: 1 <= num_stages <= 4
        ResNet3d(34, None, num_stages=5)

    with pytest.raises(AssertionError):
        # In ResNet3d: 1 <= num_stages <= 4
        ResNet3d(50, None, num_stages=0)

    with pytest.raises(AssertionError):
        # In ResNet3d: 1 <= num_stages <= 4
        ResNet3d(50, None, num_stages=5)

    with pytest.raises(AssertionError):
        # len(spatial_strides) == len(temporal_strides)
        # == len(dilations) == num_stages
        ResNet3d(50,
                 None,
                 spatial_strides=(1, ),
                 temporal_strides=(1, 1),
                 dilations=(1, 1, 1),
                 num_stages=4)

    with pytest.raises(AssertionError):
        # len(spatial_strides) == len(temporal_strides)
        # == len(dilations) == num_stages
        ResNet3d(34,
                 None,
                 spatial_strides=(1, ),
                 temporal_strides=(1, 1),
                 dilations=(1, 1, 1),
                 num_stages=4)

    with pytest.raises(TypeError):
        # pretrain must be str or None.
        resnet3d_34 = ResNet3d(34, ['resnet', 'bninception'])
        resnet3d_34.init_weights()

    with pytest.raises(TypeError):
        # pretrain must be str or None.
        resnet3d_50 = ResNet3d(50, ['resnet', 'bninception'])
        resnet3d_50.init_weights()

    # resnet3d with depth 34, no pretrained, norm_eval True
    resnet3d_34 = ResNet3d(34, None, pretrained2d=False, norm_eval=True)
    resnet3d_34.init_weights()
    resnet3d_34.train()
    assert check_norm_state(resnet3d_34.modules(), False)

    # resnet3d with depth 50, no pretrained, norm_eval True
    resnet3d_50 = ResNet3d(50, None, pretrained2d=False, norm_eval=True)
    resnet3d_50.init_weights()
    resnet3d_50.train()
    assert check_norm_state(resnet3d_50.modules(), False)

    # resnet3d with depth 50, pretrained2d, norm_eval True
    resnet3d_50_pretrain = ResNet3d(50,
                                    'torchvision://resnet50',
                                    norm_eval=True)
    resnet3d_50_pretrain.init_weights()
    resnet3d_50_pretrain.train()
    assert check_norm_state(resnet3d_50_pretrain.modules(), False)
    from mmcv.runner import _load_checkpoint
    chkp_2d = _load_checkpoint('torchvision://resnet50')
    for name, module in resnet3d_50_pretrain.named_modules():
        if len(name.split('.')) == 4:
            # layer.block.module.submodule
            prefix = name.split('.')[:2]
            module_type = name.split('.')[2]
            submodule_type = name.split('.')[3]

            if module_type == 'downsample':
                name2d = name.replace('conv', '0').replace('bn', '1')
            else:
                layer_id = name.split('.')[2][-1]
                name2d = prefix[0] + '.' + prefix[1] + '.' + \
                    submodule_type + layer_id

            if isinstance(module, nn.Conv3d):
                conv2d_weight = chkp_2d[name2d + '.weight']
                conv3d_weight = getattr(module, 'weight').data
                assert torch.equal(
                    conv3d_weight,
                    conv2d_weight.data.unsqueeze(2).expand_as(conv3d_weight) /
                    conv3d_weight.shape[2])
                if getattr(module, 'bias') is not None:
                    conv2d_bias = chkp_2d[name2d + '.bias']
                    conv3d_bias = getattr(module, 'bias').data
                    assert torch.equal(conv2d_bias, conv3d_bias)

            elif isinstance(module, nn.BatchNorm3d):
                for pname in ['weight', 'bias', 'running_mean', 'running_var']:
                    param_2d = chkp_2d[name2d + '.' + pname]
                    param_3d = getattr(module, pname).data
                    assert torch.equal(param_2d, param_3d)

    conv3d = resnet3d_50_pretrain.conv1.conv
    assert torch.equal(
        conv3d.weight,
        chkp_2d['conv1.weight'].unsqueeze(2).expand_as(conv3d.weight) /
        conv3d.weight.shape[2])
    conv3d = resnet3d_50_pretrain.layer3[2].conv2.conv
    assert torch.equal(
        conv3d.weight, chkp_2d['layer3.2.conv2.weight'].unsqueeze(2).expand_as(
            conv3d.weight) / conv3d.weight.shape[2])

    # resnet3d with depth 34, no pretrained, norm_eval False
    resnet3d_34_no_bn_eval = ResNet3d(34,
                                      None,
                                      pretrained2d=False,
                                      norm_eval=False)
    resnet3d_34_no_bn_eval.init_weights()
    resnet3d_34_no_bn_eval.train()
    assert check_norm_state(resnet3d_34_no_bn_eval.modules(), True)

    # resnet3d with depth 50, no pretrained, norm_eval False
    resnet3d_50_no_bn_eval = ResNet3d(50,
                                      None,
                                      pretrained2d=False,
                                      norm_eval=False)
    resnet3d_50_no_bn_eval.init_weights()
    resnet3d_50_no_bn_eval.train()
    assert check_norm_state(resnet3d_50_no_bn_eval.modules(), True)

    # resnet3d with depth 34, no pretrained, frozen_stages, norm_eval False
    frozen_stages = 1
    resnet3d_34_frozen = ResNet3d(34,
                                  None,
                                  pretrained2d=False,
                                  frozen_stages=frozen_stages)
    resnet3d_34_frozen.init_weights()
    resnet3d_34_frozen.train()
    assert resnet3d_34_frozen.conv1.bn.training is False
    for param in resnet3d_34_frozen.conv1.parameters():
        assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(resnet3d_34_frozen, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False
    # test zero_init_residual
    for m in resnet3d_34_frozen.modules():
        if hasattr(m, 'conv2'):
            assert torch.equal(m.conv2.bn.weight,
                               torch.zeros_like(m.conv2.bn.weight))
            assert torch.equal(m.conv2.bn.bias,
                               torch.zeros_like(m.conv2.bn.bias))

    # resnet3d with depth 50, no pretrained, frozen_stages, norm_eval False
    frozen_stages = 1
    resnet3d_50_frozen = ResNet3d(50,
                                  None,
                                  pretrained2d=False,
                                  frozen_stages=frozen_stages)
    resnet3d_50_frozen.init_weights()
    resnet3d_50_frozen.train()
    assert resnet3d_50_frozen.conv1.bn.training is False
    for param in resnet3d_50_frozen.conv1.parameters():
        assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(resnet3d_50_frozen, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False
    # test zero_init_residual
    for m in resnet3d_50_frozen.modules():
        if hasattr(m, 'conv3'):
            assert torch.equal(m.conv3.bn.weight,
                               torch.zeros_like(m.conv3.bn.weight))
            assert torch.equal(m.conv3.bn.bias,
                               torch.zeros_like(m.conv3.bn.bias))

    # resnet3d frozen with depth 34 inference
    input_shape = (1, 3, 6, 64, 64)
    imgs = _demo_inputs(input_shape)
    # parrots 3dconv is only implemented on gpu
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            resnet3d_34_frozen = resnet3d_34_frozen.cuda()
            imgs_gpu = imgs.cuda()
            feat = resnet3d_34_frozen(imgs_gpu)
            assert feat.shape == torch.Size([1, 512, 1, 2, 2])
    else:
        feat = resnet3d_34_frozen(imgs)
        assert feat.shape == torch.Size([1, 512, 1, 2, 2])

    # resnet3d with depth 50 inference
    input_shape = (1, 3, 6, 64, 64)
    imgs = _demo_inputs(input_shape)
    # parrots 3dconv is only implemented on gpu
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            resnet3d_50_frozen = resnet3d_50_frozen.cuda()
            imgs_gpu = imgs.cuda()
            feat = resnet3d_50_frozen(imgs_gpu)
            assert feat.shape == torch.Size([1, 2048, 1, 2, 2])
    else:
        feat = resnet3d_50_frozen(imgs)
        assert feat.shape == torch.Size([1, 2048, 1, 2, 2])

    # resnet3d with depth 50 in caffe style inference
    resnet3d_50_caffe = ResNet3d(50, None, pretrained2d=False, style='caffe')
    resnet3d_50_caffe.init_weights()
    resnet3d_50_caffe.train()

    # parrots 3dconv is only implemented on gpu
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            resnet3d_50_caffe = resnet3d_50_caffe.cuda()
            imgs_gpu = imgs.cuda()
            feat = resnet3d_50_caffe(imgs_gpu)
            assert feat.shape == torch.Size([1, 2048, 1, 2, 2])
    else:
        feat = resnet3d_50_caffe(imgs)
        assert feat.shape == torch.Size([1, 2048, 1, 2, 2])

    # resnet3d with depth 34 in caffe style inference
    resnet3d_34_caffe = ResNet3d(34, None, pretrained2d=False, style='caffe')
    resnet3d_34_caffe.init_weights()
    resnet3d_34_caffe.train()
    # parrots 3dconv is only implemented on gpu
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            resnet3d_34_caffe = resnet3d_34_caffe.cuda()
            imgs_gpu = imgs.cuda()
            feat = resnet3d_34_caffe(imgs_gpu)
            assert feat.shape == torch.Size([1, 512, 1, 2, 2])
    else:
        feat = resnet3d_34_caffe(imgs)
        assert feat.shape == torch.Size([1, 512, 1, 2, 2])

    # resnet3d with depth 50 and 3x3x3 inflate_style inference
    resnet3d_50_1x1x1 = ResNet3d(50,
                                 None,
                                 pretrained2d=False,
                                 inflate_style='3x3x3')
    resnet3d_50_1x1x1.init_weights()
    resnet3d_50_1x1x1.train()
    # parrots 3dconv is only implemented on gpu
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            resnet3d_50_1x1x1 = resnet3d_50_1x1x1.cuda()
            imgs_gpu = imgs.cuda()
            feat = resnet3d_50_1x1x1(imgs_gpu)
            assert feat.shape == torch.Size([1, 2048, 1, 2, 2])
    else:
        feat = resnet3d_50_1x1x1(imgs)
        assert feat.shape == torch.Size([1, 2048, 1, 2, 2])

    resnet3d_34_1x1x1 = ResNet3d(34,
                                 None,
                                 pretrained2d=False,
                                 inflate_style='3x3x3')
    resnet3d_34_1x1x1.init_weights()
    resnet3d_34_1x1x1.train()

    # parrots 3dconv is only implemented on gpu
    if torch.__version__ == 'parrots':
        if torch.cuda.is_available():
            resnet3d_34_1x1x1 = resnet3d_34_1x1x1.cuda()
            imgs_gpu = imgs.cuda()
            feat = resnet3d_34_1x1x1(imgs_gpu)
            assert feat.shape == torch.Size([1, 512, 1, 2, 2])
    else:
        feat = resnet3d_34_1x1x1(imgs)
        assert feat.shape == torch.Size([1, 512, 1, 2, 2])
Example #20
def model_from_checkpoint(
    filename: Union[Path, str],
    model_name=None,
    backbone_name=None,
    classes=None,
    is_coco=False,
    img_size=None,
    map_location=None,
    strict=False,
    revise_keys=[
        (r"^module\.", ""),
    ],
    eval_mode=True,
    logger=None,
):
    """load checkpoint through URL scheme path.
    Args:
        filename (str): checkpoint file name with given prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None
    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    if isinstance(filename, Path):
        filename = str(filename)

    checkpoint = _load_checkpoint(filename=filename,
                                  map_location=map_location,
                                  logger=logger)

    if is_coco and classes:
        err_msg = "`is_coco` cannot be set to True if `classes` is passed and `not None`. `classes` has priority. `is_coco` will be ignored."
        if logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)

    if classes is None:
        if is_coco:
            classes = CLASSES
        else:
            classes = checkpoint["meta"].get("classes", None)

    class_map = None
    num_classes = None
    if classes:
        class_map = ClassMap(classes)
        num_classes = len(class_map)

    if img_size is None:
        img_size = checkpoint["meta"].get("img_size", None)

    if model_name is None:
        model_name = checkpoint["meta"].get("model_name", None)

    model_type = None
    if model_name:
        model = model_name.split(".")
        # If model_name contains three or more components, the library and model are the second to last and last components, respectively (e.g. models.mmdet.retinanet)
        if len(model) >= 3:
            model_type = getattr(getattr(models, model[-2]), model[-1])
        # If model_name follows the default convention, the library and model are the first and second components, respectively (e.g. mmdet.retinanet)
        else:
            model_type = getattr(getattr(models, model[0]), model[1])

    if backbone_name is None:
        backbone_name = checkpoint["meta"].get("backbone_name", None)
    backbone = None
    if model_type and backbone_name:
        backbone = getattr(model_type.backbones, backbone_name)

    extra_args = {}
    if img_size is None:
        img_size = checkpoint["meta"].get("img_size", None)

    models_with_img_size = ("yolov5", "efficientdet")
    # if 'efficientdet' in model_name:
    if (model_name) and (any(m in model_name for m in models_with_img_size)):
        extra_args["img_size"] = img_size

    # Instantiate model
    if model_type and backbone:
        model = model_type.model(backbone=backbone(pretrained=False),
                                 num_classes=num_classes,
                                 **extra_args)
    else:
        model = None

    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f"No state_dict found in checkpoint file {filename}")
    # get state_dict from checkpoint
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    for p, r in revise_keys:
        state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}

    # load state_dict
    if model:
        load_state_dict(model, state_dict, strict, logger)
        if eval_mode:
            model.eval()

    checkpoint_and_model = {
        "model": model,
        "model_type": model_type,
        "backbone": backbone,
        "class_map": class_map,
        "img_size": img_size,
        "checkpoint": checkpoint,
    }
    return checkpoint_and_model
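A typical call might look like the following; the checkpoint path is hypothetical, and `model_name`, `backbone_name`, `classes` and `img_size` fall back to the values stored in the checkpoint's `meta` when not passed explicitly:

loaded = model_from_checkpoint("checkpoints/retinanet_r50_fpn.pth")  # hypothetical path
model = loaded["model"]          # instantiated model (or None if meta was incomplete)
class_map = loaded["class_map"]  # built from the stored class names, if any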
Example #21
def test_checkpoint_loader():
    from mmcv.runner import _load_checkpoint, save_checkpoint, CheckpointLoader
    import tempfile
    import os
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    model = Model()
    save_checkpoint(model, checkpoint_path)
    checkpoint = _load_checkpoint(checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']
    # remove the temp file
    os.remove(checkpoint_path)

    filenames = [
        'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
        'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth',
        'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth',
        'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth',
        'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth'
    ]
    fn_names = [
        'load_from_http', 'load_from_http', 'load_from_torchvision',
        'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab',
        'load_from_mmcls', 'load_from_pavi', 'load_from_ceph',
        'load_from_local', 'load_from_local'
    ]

    for filename, fn_name in zip(filenames, fn_names):
        loader = CheckpointLoader._get_checkpoint_loader(filename)
        assert loader.__name__ == fn_name

    @CheckpointLoader.register_scheme(prefixes='ftp://')
    def load_from_ftp(filename, map_location):
        return dict(filename=filename)

    # test register_loader
    filename = 'ftp://xx.xx/xx.pth'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp'

    def load_from_ftp1(filename, map_location):
        return dict(filename=filename)

    # test duplicate registered error
    with pytest.raises(KeyError):
        CheckpointLoader.register_scheme('ftp://', load_from_ftp1)

    # test force param
    CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True)
    checkpoint = CheckpointLoader.load_checkpoint(filename)
    assert checkpoint['filename'] == filename

    # test print function name
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp1'

    # test sort
    @CheckpointLoader.register_scheme(prefixes='a/b')
    def load_from_ab(filename, map_location):
        return dict(filename=filename)

    @CheckpointLoader.register_scheme(prefixes='a/b/c')
    def load_from_abc(filename, map_location):
        return dict(filename=filename)

    filename = 'a/b/c/d'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_abc'
Example #22
    def init_weights(self):
        if self.pretrained is None:
            super().init_weights()
            if self.use_abs_pos_embed:
                trunc_normal_init(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, Linear):
                    trunc_normal_init(m.weight, std=.02)
                    if m.bias is not None:
                        constant_init(m.bias, 0)
                elif isinstance(m, LayerNorm):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            ckpt = _load_checkpoint(self.pretrained,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
            else:
                state_dict = ckpt

            if self.pretrain_style == 'official':
                state_dict = swin_convert(state_dict)

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                else:
                    if L1 != L2:
                        S1 = int(L1**0.5)
                        S2 = int(L2**0.5)
                        table_pretrained_resized = F.interpolate(
                            table_pretrained.permute(1, 0).reshape(
                                1, nH1, S1, S1),
                            size=(S2, S2),
                            mode='bicubic')
                        state_dict[table_key] = table_pretrained_resized.view(
                            nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)
Example #23
def inflate_weights(model, state_dict_2d, logger=None):
    if isinstance(state_dict_2d, str):
        state_dict_2d = _load_checkpoint(state_dict_2d, map_location='cpu')
        if 'state_dict' in state_dict_2d:
            state_dict_2d = state_dict_2d['state_dict']
    assert isinstance(state_dict_2d, dict)

    first_element = list(state_dict_2d.keys())[0]
    if first_element.startswith('module.'):
        state_dict_2d = {k[7:]: v for k, v in state_dict_2d.items()}
    elif first_element.startswith('backbone.'):
        state_dict_2d = {k[9:]: v for k, v in state_dict_2d.items()}

    missing_keys = []
    shape_mismatch_pairs = []
    shape_inflated_pairs = []
    copied_pairs = []

    for name, module in model.named_modules():
        trg_name = name + '.weight'

        if isinstance(module, nn.Conv3d) and trg_name in state_dict_2d:
            old_weight = state_dict_2d[name + '.weight'].data
            old_weight_shape = old_weight.shape
            assert len(old_weight_shape) in [2, 4]

            if len(old_weight.shape) == 2:
                old_weight = old_weight.unsqueeze(2).unsqueeze(3)
            old_weight = old_weight.unsqueeze(2)

            if not _is_compatible(old_weight.shape, module.weight.data.shape):
                shape_mismatch_pairs.append(
                    [name,
                     list(old_weight_shape),
                     list(module.weight.size())])
                continue

            new_weight = old_weight.expand_as(
                module.weight) / module.weight.data.shape[2]
            module.weight.data.copy_(new_weight)
            shape_inflated_pairs.append(
                [trg_name,
                 list(old_weight_shape),
                 list(module.weight.size())])

            if hasattr(module, 'bias') and module.bias is not None:
                trg_name = name + '.bias'
                new_bias = state_dict_2d[trg_name].data
                module.bias.data.copy_(new_bias)
                copied_pairs.append([trg_name, list(new_bias.size())])
        elif isinstance(module, nn.BatchNorm3d) and trg_name in state_dict_2d:
            for attr_name in ['weight', 'bias', 'running_mean', 'running_var']:
                trg_name = name + '.' + attr_name
                old_attr = state_dict_2d[trg_name].data

                new_attr = getattr(module, attr_name)
                new_attr.data.copy_(old_attr)
                copied_pairs.append([trg_name, list(new_attr.size())])
        elif isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            missing_keys.append(name)

    if len(missing_keys) > 0:
        msg = 'Missing keys in source state_dict: {}\n'.format(
            ', '.join(missing_keys))
        if logger is not None:
            logger.warning(msg)

    if shape_mismatch_pairs:
        header = ['key', '2d shape', '3d shape']
        table_data = [header] + shape_mismatch_pairs
        table = AsciiTable(table_data)
        if logger is not None:
            logger.warning('These keys have mismatched shape:\n' + table.table)

    if copied_pairs:
        header = ['key', 'shape']
        table_data = [header] + copied_pairs
        table = AsciiTable(table_data)
        if logger is not None:
            logger.info('These keys have been copied:\n' + table.table)

    if shape_inflated_pairs:
        header = ['key', '2d shape', '3d shape']
        table_data = [header] + shape_inflated_pairs
        table = AsciiTable(table_data)
        if logger is not None:
            logger.info('These keys have been shape inflated:\n' + table.table)
Example #24
    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` of ' \
                f'{self.__class__.__name__}'
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt
            if self.convert_weights:
                # Support loading weights from the original repo by
                # converting them to match this implementation.
                _state_dict = swin_converter(_state_dict)

            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                elif L1 != L2:
                    S1 = int(L1**0.5)
                    S2 = int(L2**0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                        size=(S2, S2),
                        mode='bicubic')
                    state_dict[table_key] = table_pretrained_resized.view(
                        nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)