Example #1
0
    def init_weights(self, pretrained=None):
        """Initialize model weights, optionally from a pretrained checkpoint.

        Every ``nn.Conv2d`` / ``nn.BatchNorm2d`` / ``nn.Linear`` in
        ``self.multi_stage_mspn`` and every ``nn.Conv2d`` in ``self.top`` is
        first initialized from scratch. If ``pretrained`` is a string, the
        checkpoint's ``'top'`` entry is then loaded into ``self.top`` and its
        ``'bottlenecks'`` entry into the ``downsample`` sub-module of each of
        the ``self.num_stages`` stages, all non-strictly.

        The checkpoint is loaded *after* the from-scratch initialization.
        Otherwise the initialization would overwrite the pretrained weights.
        Modules the checkpoint does not cover still receive the explicit
        initialization.

        Args:
            pretrained (str, optional): Location passed to
                ``get_state_dict``. Any non-string value skips checkpoint
                loading.
        """
        for m in self.multi_stage_mspn.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.Linear):
                normal_init(m, std=0.01)

        for m in self.top.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)

        if isinstance(pretrained, str):
            logger = get_root_logger()
            state_dict = get_state_dict(pretrained)
            # NOTE(review): assumes the checkpoint already has 'top' and
            # 'bottlenecks' sub-dicts keyed for these modules - confirm format.
            load_state_dict(
                self.top, state_dict['top'], strict=False, logger=logger)
            for i in range(self.num_stages):
                load_state_dict(
                    self.multi_stage_mspn[i].downsample,
                    state_dict['bottlenecks'],
                    strict=False,
                    logger=logger)
def get_net():
    """Build a 1-class EfficientDet-D6 training bench from pretrained weights.

    The pretrained checkpoint is loaded into the network as constructed from
    the unmodified config. Only after that is the classification head
    replaced by a freshly built 1-class ``HeadNet``. The checkpoint's head
    presumably has the original class count, so loading it into the 1-class
    head would hit shape mismatches. That head is re-created anyway, so its
    pretrained weights are never needed.

    Returns:
        DetBenchTrain wrapping the network and its (mutated) config.
    """
    config = get_efficientdet_config('tf_efficientdet_d6')
    net = EfficientDet(config, pretrained_backbone=False)
    checkpoint = torch.load('pretrained/efficientdet_d6-51cb0132.pth')
    load_state_dict(net, checkpoint)

    # Adapt to the single-class task only after the pretrained weights are in.
    config.num_classes = 1
    config.image_size = 1024
    net.class_net = HeadNet(config,
                            num_outputs=config.num_classes,
                            norm_kwargs=dict(eps=.001, momentum=.01))
    return DetBenchTrain(net, config)
Example #3
0
def get_net():
    """Return a DetBenchTrain around a 1-class EfficientDet-D6 whose weights
    are restored from the ``best-checkpoint-052epoch.bin`` training file."""
    cfg = get_efficientdet_config('tf_efficientdet_d6')
    model = EfficientDet(cfg, pretrained_backbone=False)

    # The config is mutated only after the network is constructed; the
    # replacement head is then built from the mutated config.
    cfg.num_classes = 1
    cfg.image_size = 1024
    head_norm = dict(eps=.001, momentum=.01)
    model.class_net = HeadNet(cfg,
                              num_outputs=cfg.num_classes,
                              norm_kwargs=head_norm)

    ckpt_path = ('effdet6-baseline-1024-4x8-sa-fold0/'
                 'best-checkpoint-052epoch.bin')
    saved = torch.load(ckpt_path)
    # The training file stores the weights under 'model_state_dict'.
    load_state_dict(model, saved['model_state_dict'])

    return DetBenchTrain(model, cfg)
Example #4
0
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load a checkpoint from a file or URI into ``model``.

    Args:
        model (Module): Module that receives the weights.
        filename (str): Local filepath, URL, ``torchvision://xxx`` or
            ``open-mmlab://xxx`` (see ``docs/model_zoo.md``); resolved by
            ``_load_checkpoint``.
        map_location (str): Forwarded to ``_load_checkpoint``; same meaning
            as in :func:`torch.load`.
        strict (bool): Forwarded to ``load_state_dict`` to control how
            strictly model and checkpoint params must match.
        logger (:mod:`logging.Logger` or None): Logger for error messages.

    Returns:
        dict or OrderedDict: The raw loaded checkpoint, not the
        prefix-stripped state dict.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    raw = _load_checkpoint(filename, map_location)
    # OrderedDict subclasses dict, so both are accepted here.
    if not isinstance(raw, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    source = raw['state_dict'] if 'state_dict' in raw else raw

    # Remove at most one prefix per key, the first that matches. The longer
    # 'module.backbone.' must be tried before the plain 'module.'.
    prefixes = ('module.backbone.', 'module.', 'backbone.')
    stripped = OrderedDict()
    for key, value in source.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        stripped[key] = value

    load_state_dict(model, stripped, strict, logger)
    return raw
Example #5
0
def load_pretrained(model, url, filter_fn=None, strict=False, logger=None):
    """Load pretrained weights from ``url`` into ``model``.

    The input stem ``model.conv_stem`` is adapted when the model expects a
    different number of input channels than the checkpoint provides:

    * model has 1 input channel: the pretrained stem kernel is summed over
      its input-channel dimension (dim 1);
    * any other mismatch: the pretrained stem weight is dropped and loading
      is forced to be non-strict.

    Classifier weights are passed through unchanged regardless of the
    model's number of classes.

    Args:
        model: Module with a ``conv_stem`` attribute whose ``weight`` has the
            input-channel count at dim 1.
        url (str): Checkpoint URL. If empty/None a warning is printed and
            the model is left untouched.
        filter_fn (callable, optional): Applied to the state dict after the
            stem adaptation and before loading.
        strict (bool): Forwarded to ``load_state_dict``; forced to False when
            the stem weight is discarded.
        logger (logging.Logger, optional): Logger handed to
            ``load_state_dict``; ``get_root_logger()`` is used when None.
            (Previously this argument was silently ignored.)
    """
    if not url:
        print(
            "=> Warning: Pretrained model URL is empty, using random initialization."
        )
        return

    state_dict = load_state_dict_from_url(url,
                                          progress=False,
                                          map_location='cpu')

    input_conv = 'conv_stem'
    in_chans = getattr(model, input_conv).weight.shape[1]

    input_conv_weight = input_conv + '.weight'
    pretrained_in_chans = state_dict[input_conv_weight].shape[1]
    if in_chans != pretrained_in_chans:
        if in_chans == 1:
            print(
                '=> Converting pretrained input conv {} from {} to 1 channel'.
                format(input_conv_weight, pretrained_in_chans))
            conv1_weight = state_dict[input_conv_weight]
            # Collapse the pretrained kernel to one input channel by summing.
            state_dict[input_conv_weight] = conv1_weight.sum(dim=1,
                                                             keepdim=True)
        else:
            print(
                '=> Discarding pretrained input conv {} since input channel count != {}'
                .format(input_conv_weight, pretrained_in_chans))
            del state_dict[input_conv_weight]
            # The stem weight is now missing, so a strict load would fail.
            strict = False

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if logger is None:
        logger = get_root_logger()
    load_state_dict(model, state_dict, strict=strict, logger=logger)
Example #6
0
    def init_weights(self, pretrained=None):
        """Initialize model weights.

        If ``pretrained`` is a string, the checkpoint returned by
        ``get_state_dict`` is re-keyed into two groups and loaded
        non-strictly:

        * keys starting with ``layer`` go to every stage's ``downsample``
          module, with ``downsample.0``/``downsample.1`` renamed to
          ``downsample.conv``/``downsample.bn``;
        * keys starting with ``conv1``/``bn1`` go to ``self.top``, renamed to
          ``top.0.conv``/``top.0.bn``.

        All other checkpoint keys are ignored.

        For any other ``pretrained`` value, ``multi_stage_mspn`` and ``top``
        are initialized from scratch instead. The two paths are mutually
        exclusive.
        """
        if not isinstance(pretrained, str):
            for module in self.multi_stage_mspn.modules():
                if isinstance(module, nn.Conv2d):
                    kaiming_init(module)
                elif isinstance(module, nn.BatchNorm2d):
                    constant_init(module, 1)
                elif isinstance(module, nn.Linear):
                    normal_init(module, std=0.01)
            for module in self.top.modules():
                if isinstance(module, nn.Conv2d):
                    kaiming_init(module)
            return

        logger = get_root_logger()
        raw = get_state_dict(pretrained)
        top_params = OrderedDict()
        bottleneck_params = OrderedDict()
        for name, tensor in raw.items():
            if name.startswith('layer'):
                # Rename the sequential downsample indices to named children.
                if 'downsample.0' in name:
                    name = name.replace('downsample.0', 'downsample.conv')
                elif 'downsample.1' in name:
                    name = name.replace('downsample.1', 'downsample.bn')
                bottleneck_params[name] = tensor
            elif name.startswith('conv1'):
                top_params[name.replace('conv1', 'top.0.conv')] = tensor
            elif name.startswith('bn1'):
                top_params[name.replace('bn1', 'top.0.bn')] = tensor

        load_state_dict(self.top, top_params, strict=False, logger=logger)
        for stage in range(self.num_stages):
            load_state_dict(self.multi_stage_mspn[stage].downsample,
                            bottleneck_params,
                            strict=False,
                            logger=logger)
Example #7
0
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load the checkpoint into. If it has a
            ``module`` attribute, the weights go into ``model.module``.
        filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
            ``modelzoo://`` names are looked up in the ``model_urls`` of
            every torchvision.models submodule.
        map_location (str): Same as :func:`torch.load`; only used for local
            files.
        strict (bool): Forwarded to ``load_state_dict``.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint (not modified).

    Raises:
        IOError: If ``filename`` is a local path that is not a file.
        RuntimeError: If no state dict can be found in the checkpoint.
    """
    # load checkpoint from modelzoo or file or url
    if filename.startswith('modelzoo://'):
        import torchvision
        model_urls = dict()
        for _, name, ispkg in pkgutil.walk_packages(
                torchvision.models.__path__):
            if not ispkg:
                _zoo = import_module('torchvision.models.{}'.format(name))
                if hasattr(_zoo, 'model_urls'):
                    model_urls.update(getattr(_zoo, 'model_urls'))
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_name = filename[13:]
        checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    else:
        if not osp.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
        checkpoint = torch.load(filename, map_location=map_location)

    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))

    # Strip a leading 'module.' prefix (presumably written by a
    # DataParallel-style wrapper). Read from ``state_dict`` itself: a bare
    # OrderedDict checkpoint has no 'state_dict' key to index.
    # ``next(iter(...), '')`` also tolerates an empty state dict.
    if next(iter(state_dict), '').startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # TODO(xiao): init stage3 with stage2
    # Fill each *missing* 'extra_heads.1' entry with a copy of the matching
    # 'extra_heads.0' entry. Entries already present in the checkpoint are
    # kept, so loading a checkpoint that has stage-3 weights does not
    # overwrite them with stage-2 values.
    seeded = state_dict.copy()
    for k, v in state_dict.items():
        if 'extra_heads.0' in k:
            target_key = k.replace('extra_heads.0', 'extra_heads.1')
            if target_key not in state_dict:
                seeded[target_key] = v.clone()
    state_dict = seeded

    target = model.module if hasattr(model, 'module') else model
    load_state_dict(target, state_dict, strict, logger)
    return checkpoint
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, list):
            logger = logging.getLogger()
            filename = pretrained[0]
            if filename.startswith('modelzoo://'):
                warnings.warn(
                    'The URL scheme of "modelzoo://" is deprecated, please '
                    'use "torchvision://" instead')
                model_urls = get_torchvision_models()
                model_name = filename[11:]
                checkpoint = load_url_dist(model_urls[model_name])
            elif filename.startswith('torchvision://'):
                model_urls = get_torchvision_models()
                model_name = filename[14:]
                checkpoint = load_url_dist(model_urls[model_name])
            elif filename.startswith('open-mmlab://'):
                model_name = filename[13:]
                checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
            elif filename.startswith(('http://', 'https://')):
                checkpoint = load_url_dist(filename)
            else:
                if not osp.isfile(filename):
                    raise IOError(
                        '{} is not a checkpoint file'.format(filename))
                checkpoint = torch.load(filename, map_location=map_location)
            #print(checkpoint.keys)
            #state_dict = {'deep_'+k : v for k,v in checkpoint.items() if 'layer' in k}
            #print(state_dict.keys())
            # load state_dict
            # get state_dict from checkpoint
            if isinstance(checkpoint, OrderedDict):
                state_dict = checkpoint
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(
                        filename))
            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {
                    k[7:]: v
                    for k, v in checkpoint['state_dict'].items()
                }
            if hasattr(self, 'module'):
                load_state_dict(self.module, state_dict, False, logger)
            else:
                load_state_dict(self, state_dict, False, logger)
            filename = pretrained[1]
            if filename.startswith('modelzoo://'):
                warnings.warn(
                    'The URL scheme of "modelzoo://" is deprecated, please '
                    'use "torchvision://" instead')
                model_urls = get_torchvision_models()
                model_name = filename[11:]
                checkpoint = load_url_dist(model_urls[model_name])
            elif filename.startswith('torchvision://'):
                model_urls = get_torchvision_models()
                model_name = filename[14:]
                checkpoint = load_url_dist(model_urls[model_name])
            elif filename.startswith('open-mmlab://'):
                model_name = filename[13:]
                checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
            elif filename.startswith(('http://', 'https://')):
                checkpoint = load_url_dist(filename)
            else:
                if not osp.isfile(filename):
                    raise IOError(
                        '{} is not a checkpoint file'.format(filename))
                checkpoint = torch.load(filename, map_location=map_location)
            if isinstance(checkpoint, OrderedDict):
                state_dict = checkpoint
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(
                        filename))
            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {
                    'shallow_' + k[7:]: v
                    for k, v in checkpoint['state_dict'].items()
                    if 'layer' in k
                }
            else:
                state_dict = {
                    'shallow_' + k: v
                    for k, v in checkpoint.items() if 'layer' in k
                }
            if hasattr(self, 'module'):
                load_state_dict(self.module, state_dict, False, logger)
            else:
                load_state_dict(self, state_dict, False, logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')