コード例 #1
0
def build_shufflenetv2_backbone(cfg):
    """
    Create a ShuffleNetV2 instance from config.

    Reads ``MODEL.BACKBONE.{PRETRAIN, PRETRAIN_PATH, NORM, DEPTH}``; ``DEPTH``
    is forwarded as ``model_size``. When ``PRETRAIN`` is true, the checkpoint
    at ``PRETRAIN_PATH`` is loaded, its ``"state_dict"`` entry is taken, a
    leading ``module.`` is stripped from every key, and the result is loaded
    with ``strict=False``; missing and unexpected keys are logged.

    Returns:
        ShuffleNetV2: a :class:`ShuffleNetV2` instance.
    """
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    model_size    = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    model = ShuffleNetV2(bn_norm, model_size=model_size)

    if pretrain:
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts,
        # matching the other backbone builders.
        state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))["state_dict"]
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Checkpoints saved through a (Distributed)DataParallel wrapper
            # usually prefix every key with ``module.``; drop it so names match.
            if k.startswith(prefix):
                k = k[len(prefix):]
            new_state_dict[k] = v

        incompatible = model.load_state_dict(new_state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(
                    incompatible.unexpected_keys))

    return model
コード例 #2
0
ファイル: osnet.py プロジェクト: 15575432921/fast-reid-group
def build_osnet_backbone(cfg):
    """
    Build an OSNet backbone as described by ``cfg.MODEL.BACKBONE``.

    ``DEPTH`` selects the per-stage channel widths (``x1_0``, ``x0_75``,
    ``x0_5`` or ``x0_25``) and ``WITH_IBN`` is forwarded as the ``IN`` flag.
    When ``PRETRAIN`` is true, weights come from ``PRETRAIN_PATH`` if it is
    non-empty, otherwise from ``init_pretrained_weights`` using the key
    ``osnet_<depth>`` or ``osnet_ibn_<depth>``. Weights are loaded with
    ``strict=False`` and missing/unexpected keys are logged.

    Returns:
        OSNet: an :class:`OSNet` instance
    """
    backbone_cfg = cfg.MODEL.BACKBONE
    # fmt: off
    pretrain      = backbone_cfg.PRETRAIN
    pretrain_path = backbone_cfg.PRETRAIN_PATH
    with_ibn      = backbone_cfg.WITH_IBN
    bn_norm       = backbone_cfg.NORM
    depth         = backbone_cfg.DEPTH
    # fmt: on

    blocks = [2, 2, 2]
    channels_by_depth = {
        "x1_0": [64, 256, 384, 512],
        "x0_75": [48, 192, 288, 384],
        "x0_5": [32, 128, 192, 256],
        "x0_25": [16, 64, 96, 128],
    }
    channels = channels_by_depth[depth]
    model = OSNet([OSBlock] * 3, blocks, channels, bn_norm, IN=with_ibn)

    if not pretrain:
        return model

    if not pretrain_path:
        # No explicit checkpoint: fall back to the named pretrained weights.
        prefix = "osnet_ibn_" if with_ibn else "osnet_"
        state_dict = init_pretrained_weights(model, prefix + depth)
    else:
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
            logger.info(f"Loading pretrained model from {pretrain_path}")
        except FileNotFoundError as e:
            logger.info(f'{pretrain_path} is not found! Please check this path.')
            raise e
        except KeyError as e:
            logger.info("State dict keys error! Please check the state dict.")
            raise e

    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if missing:
        logger.info(get_missing_parameters_message(missing))
    if unexpected:
        logger.info(get_unexpected_parameters_message(unexpected))
    return model
コード例 #3
0
ファイル: resnest.py プロジェクト: xhuljl/fast-reid
def build_resnest_backbone(cfg):
    """
    Build a ResNeSt backbone from ``cfg.MODEL.BACKBONE``.

    ``DEPTH`` must be one of ``"50x"``, ``"101x"``, ``"200x"``, ``"269x"``; it
    selects the per-stage block counts and the deep-stem width. When
    ``PRETRAIN`` is true, weights are read from ``PRETRAIN_PATH`` if it is
    non-empty, otherwise downloaded from ``model_urls['resnest<N>']``.
    Loading is non-strict and missing/unexpected keys are logged.

    Returns:
        ResNeSt: a :class:`ResNeSt` instance.
    """
    backbone_cfg = cfg.MODEL.BACKBONE
    # fmt: off
    pretrain      = backbone_cfg.PRETRAIN
    pretrain_path = backbone_cfg.PRETRAIN_PATH
    last_stride   = backbone_cfg.LAST_STRIDE
    bn_norm       = backbone_cfg.NORM
    depth         = backbone_cfg.DEPTH
    # fmt: on

    # depth -> (blocks per stage, deep-stem width)
    arch_table = {
        "50x": ([3, 4, 6, 3], 32),
        "101x": ([3, 4, 23, 3], 64),
        "200x": ([3, 24, 36, 3], 64),
        "269x": ([3, 30, 48, 8], 64),
    }
    num_blocks_per_stage, stem_width = arch_table[depth]

    model = ResNeSt(
        last_stride, Bottleneck, num_blocks_per_stage,
        radix=2, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=stem_width, avg_down=True,
        avd=True, avd_first=False, norm_layer=bn_norm,
    )
    if not pretrain:
        return model

    if not pretrain_path:
        # Strip the trailing "x" of the depth tag to form the URL key.
        url = model_urls['resnest' + depth[:-1]]
        state_dict = torch.hub.load_state_dict_from_url(
            url, progress=True, check_hash=True, map_location=torch.device('cpu'))
    else:
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
            logger.info(f"Loading pretrained model from {pretrain_path}")
        except FileNotFoundError as e:
            logger.info(f'{pretrain_path} is not found! Please check this path.')
            raise e
        except KeyError as e:
            logger.info("State dict keys error! Please check the state dict.")
            raise e

    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if missing:
        logger.info(get_missing_parameters_message(missing))
    if unexpected:
        logger.info(get_unexpected_parameters_message(unexpected))
    return model
コード例 #4
0
def build_resnext_backbone(cfg):
    """
    Create a ResNeXt instance from config.

    ``DEPTH`` (``'50x'``, ``'101x'`` or ``'152x'``) selects the per-stage
    block counts. When ``PRETRAIN`` is true, weights come from the ``'model'``
    entry of the checkpoint at ``PRETRAIN_PATH`` if it is non-empty (the first
    two name components, e.g. ``module.encoder.``, are removed and only
    tensors matching this model by name and shape are kept). Otherwise they
    come from ``init_pretrained_weights`` keyed by ``[ibn_]<depth>``.
    Loading is non-strict and missing/unexpected keys are logged.

    Returns:
        ResNeXt: a :class:`ResNeXt` instance.
    Raises:
        KeyError: if ``DEPTH`` is unknown or the checkpoint has no ``'model'``.
        FileNotFoundError: if ``PRETRAIN_PATH`` does not exist.
    """

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    num_splits    = cfg.MODEL.BACKBONE.NORM_SPLIT
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = {
        '50x': [3, 4, 6, 3],
        '101x': [3, 4, 23, 3],
        '152x': [3, 8, 36, 3],
    }[depth]
    # The ResNeXt constructor takes no non-local arguments, so WITH_NL and a
    # per-stage non-local table are not consulted here. A table lacking
    # '152x' would reject a depth that is otherwise supported.
    model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, Bottleneck, num_blocks_per_stage)

    if pretrain:
        if pretrain_path:
            try:
                checkpoint = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
                # Remove module.encoder in name; keep only name/shape matches.
                # state_dict() is built once instead of twice per key.
                model_state = model.state_dict()
                new_state_dict = {}
                for k, v in checkpoint.items():
                    new_k = '.'.join(k.split('.')[2:])
                    if new_k in model_state and model_state[new_k].shape == v.shape:
                        new_state_dict[new_k] = v
                state_dict = new_state_dict
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e
        else:
            key = 'ibn_' + depth if with_ibn else depth
            state_dict = init_pretrained_weights(key)

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model
コード例 #5
0
ファイル: repvgg.py プロジェクト: 15575432921/fast-reid-group
def build_repvgg_backbone(cfg):
    """
    Build a RepVGG backbone from ``cfg.MODEL.BACKBONE``.

    ``DEPTH`` names the variant (``'A0'`` through ``'B3g4'``) and picks the
    matching ``create_RepVGG_*`` factory, which is called with
    ``LAST_STRIDE`` and ``NORM``. When ``PRETRAIN`` is true, the checkpoint
    at ``PRETRAIN_PATH`` is loaded onto the CPU and applied with
    ``strict=False``; missing and unexpected keys are logged.

    Returns:
        RepVGG: a :class:`RepVGG` instance.
    """
    backbone_cfg = cfg.MODEL.BACKBONE
    # fmt: off
    pretrain      = backbone_cfg.PRETRAIN
    pretrain_path = backbone_cfg.PRETRAIN_PATH
    last_stride   = backbone_cfg.LAST_STRIDE
    bn_norm       = backbone_cfg.NORM
    depth         = backbone_cfg.DEPTH
    # fmt: on

    factories = {
        'A0': create_RepVGG_A0,
        'A1': create_RepVGG_A1,
        'A2': create_RepVGG_A2,
        'B0': create_RepVGG_B0,
        'B1': create_RepVGG_B1,
        'B1g2': create_RepVGG_B1g2,
        'B1g4': create_RepVGG_B1g4,
        'B2': create_RepVGG_B2,
        'B2g2': create_RepVGG_B2g2,
        'B2g4': create_RepVGG_B2g4,
        'B3': create_RepVGG_B3,
        'B3g2': create_RepVGG_B3g2,
        'B3g4': create_RepVGG_B3g4,
    }
    create = factories[depth]
    model = create(last_stride, bn_norm)

    if not pretrain:
        return model

    try:
        state_dict = torch.load(pretrain_path, map_location=torch.device("cpu"))
        logger.info(f"Loading pretrained model from {pretrain_path}")
    except FileNotFoundError as e:
        logger.info(f'{pretrain_path} is not found! Please check this path.')
        raise e
    except KeyError as e:
        logger.info("State dict keys error! Please check the state dict.")
        raise e

    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if missing:
        logger.info(get_missing_parameters_message(missing))
    if unexpected:
        logger.info(get_unexpected_parameters_message(unexpected))
    return model
コード例 #6
0
ファイル: resnet.py プロジェクト: daip13/LPC_MOT
def build_resnet_backbone(cfg):
    """
    Build a ResNet backbone from ``cfg.MODEL.BACKBONE``.

    ``DEPTH`` is one of ``'18x'``, ``'34x'``, ``'50x'``, ``'101x'``; it picks
    the per-stage block counts, the non-local layer layout and the residual
    block type. ``CD`` is forwarded as ``use_cd``. With ``PRETRAIN`` set,
    weights come from ``PRETRAIN_PATH`` when it is non-empty, otherwise from
    ``init_pretrained_weights`` with a key of the form ``[se_][ibn_]<depth>``.
    Loading is non-strict; missing/unexpected keys are logged.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    bb = cfg.MODEL.BACKBONE
    # fmt: off
    pretrain      = bb.PRETRAIN
    pretrain_path = bb.PRETRAIN_PATH
    last_stride   = bb.LAST_STRIDE
    bn_norm       = bb.NORM
    num_splits    = bb.NORM_SPLIT
    with_ibn      = bb.WITH_IBN
    with_se       = bb.WITH_SE
    with_nl       = bb.WITH_NL
    depth         = bb.DEPTH
    use_cd        = bb.CD
    # fmt: on

    # depth -> (blocks per stage, non-local layers per stage)
    stage_layout = {
        '18x': ([2, 2, 2, 2], [0, 0, 0, 0]),
        '34x': ([3, 4, 6, 3], [0, 0, 0, 0]),
        '50x': ([3, 4, 6, 3], [0, 2, 3, 0]),
        '101x': ([3, 4, 23, 3], [0, 2, 9, 0]),
    }
    num_blocks_per_stage, nl_layers_per_stage = stage_layout[depth]
    block = {'18x': BasicBlock, '34x': BasicBlock, '50x': Bottleneck, '101x': Bottleneck}[depth]
    model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
                   num_blocks_per_stage, nl_layers_per_stage, use_cd=use_cd)

    if not pretrain:
        return model

    if not pretrain_path:
        # e.g. with both SE and IBN enabled the key becomes 'se_ibn_<depth>'.
        key = ('se_' if with_se else '') + ('ibn_' if with_ibn else '') + depth
        state_dict = init_pretrained_weights(key)
    else:
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
            logger.info(f"Loading pretrained model from {pretrain_path}")
        except FileNotFoundError as e:
            logger.info(f'{pretrain_path} is not found! Please check this path.')
            raise e
        except KeyError as e:
            logger.info("State dict keys error! Please check the state dict.")
            raise e

    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if missing:
        logger.info(get_missing_parameters_message(missing))
    if unexpected:
        logger.info(get_unexpected_parameters_message(unexpected))
    return model
コード例 #7
0
ファイル: resnet.py プロジェクト: zymale/fast-reid
def build_resnet_backbone(cfg):
    """
    Create a ResNet instance from config.

    ``DEPTH`` (50, 101 or 152) selects the per-stage block counts and the
    non-local layer layout; all stages use ``Bottleneck`` blocks. When
    ``PRETRAIN`` is true:

    * without IBN, weights are fetched with ``model_zoo.load_url`` from
      ``model_urls[depth]``;
    * with IBN, the ``'state_dict'`` entry of the checkpoint at
      ``PRETRAIN_PATH`` is used, with the first name component (e.g. a
      ``module.`` prefix) removed and only name/shape-matching tensors kept.

    Loading is non-strict; missing/unexpected keys are logged.

    Returns:
        ResNet: a :class:`ResNet` instance.
    Raises:
        KeyError: if ``DEPTH`` is not one of 50, 101, 152.
    """

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    num_splits    = cfg.MODEL.BACKBONE.NORM_SPLIT
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    with_se       = cfg.MODEL.BACKBONE.WITH_SE
    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]
    # Every depth accepted above also needs a non-local layout; 152 shares
    # the 101-layer one so it no longer fails with KeyError.
    nl_layers_per_stage = {
        50: [0, 2, 3, 0],
        101: [0, 2, 9, 0],
        152: [0, 2, 9, 0],
    }[depth]
    model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se,
                   with_nl, Bottleneck, num_blocks_per_stage,
                   nl_layers_per_stage)
    if pretrain:
        if not with_ibn:
            # original resnet
            state_dict = model_zoo.load_url(model_urls[depth])
        else:
            # ibn resnet; map to CPU so GPU-saved checkpoints load anywhere
            checkpoint = torch.load(
                pretrain_path, map_location=torch.device('cpu'))['state_dict']
            # Remove the leading name component (e.g. ``module.``) and keep
            # only tensors matching this model by name and shape. The model's
            # state_dict() is built once rather than twice per key.
            model_state = model.state_dict()
            state_dict = {}
            for k, v in checkpoint.items():
                new_k = '.'.join(k.split('.')[1:])
                if new_k in model_state and model_state[new_k].shape == v.shape:
                    state_dict[new_k] = v

        incompatible = model.load_state_dict(state_dict, strict=False)
        logger = logging.getLogger(__name__)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(
                    incompatible.unexpected_keys))
    return model
コード例 #8
0
ファイル: regnet.py プロジェクト: zyg11/fast-reid
def build_regnet_backbone(cfg):
    """
    Create a RegNet instance from config.

    ``DEPTH`` selects one of the bundled RegNetX/RegNetY YAML files, which is
    merged into ``regnet_cfg`` before the model is built. When ``PRETRAIN`` is
    true, weights come from ``PRETRAIN_PATH`` if it is non-empty, otherwise
    from ``init_pretrained_weights(depth)``. Loading is non-strict and
    missing/unexpected keys are logged.

    Returns:
        RegNet: a :class:`RegNet` instance.
    Raises:
        KeyError: if ``DEPTH`` has no config file entry.
    """
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    # Every X variant is resolved under regnetx/ and every Y variant under
    # regnety/ (the 4000x entry had pointed into regnety/).
    cfg_files = {
        '800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
        '800y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml',
        '1600x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml',
        '1600y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml',
        '3200x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml',
        '3200y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml',
        '4000x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml',
        '4000y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml',
        '6400x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml',
        '6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
    }[depth]

    regnet_cfg.merge_from_file(cfg_files)
    model = RegNet(last_stride, bn_norm)

    if pretrain:
        # Load pretrain path if specifically
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e
        else:
            key = depth
            state_dict = init_pretrained_weights(key)

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    return model
コード例 #9
0
ファイル: uda_base.py プロジェクト: X-funbean/fast-reid
    def build_model(cls,
                    cfg,
                    load_model=True,
                    show_model=True,
                    use_dsbn=False):
        """
        Build a model from a private copy of ``cfg`` and move it to CUDA.

        The config is cloned and defrosted so ``MODEL.DEVICE`` can be forced
        to ``"cpu"`` during construction without touching the caller's config.

        Args:
            cfg: config node; ``MODEL.PRETRAIN_PATH`` and
                ``MODEL.IGNORE_LAYERS`` are read when ``load_model`` is true.
            load_model: load the ``'model'`` entry of the checkpoint, dropping
                every key listed in ``MODEL.IGNORE_LAYERS``, with
                ``strict=False``.
            show_model: log the model's repr.
            use_dsbn: convert BN layers with ``convert_dsbn``.

        Returns:
            The built model after ``model.to(torch.device("cuda"))``.
        """
        # Private, mutable copy: the model may modify cfg.
        model_cfg = cfg.clone()
        model_cfg.defrost()
        model_cfg.MODEL.DEVICE = "cpu"

        model = build_model(model_cfg)
        logger = logging.getLogger('fastreid')

        if load_model:
            pretrain_path = model_cfg.MODEL.PRETRAIN_PATH
            try:
                checkpoint = torch.load(pretrain_path, map_location=torch.device("cpu"))
                state_dict = checkpoint['model']
                for layer in model_cfg.MODEL.IGNORE_LAYERS:
                    # Drop the layer if present; absent names are ignored.
                    state_dict.pop(layer, None)
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f"{pretrain_path} is not found! Please check this path.")
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e

            incompatible = model.load_state_dict(state_dict, strict=False)
            missing = incompatible.missing_keys
            unexpected = incompatible.unexpected_keys
            if missing:
                logger.info(get_missing_parameters_message(missing))
            if unexpected:
                logger.info(get_unexpected_parameters_message(unexpected))

        if use_dsbn:
            logger.info("==> Convert BN to Domain Specific BN")
            convert_dsbn(model)

        if show_model:
            logger.info("Model:\n{}".format(model))

        model.to(torch.device("cuda"))
        return model
コード例 #10
0
ファイル: effnet.py プロジェクト: 15575432921/fast-reid-group
def build_effnet_backbone(cfg):
    """
    Build an EfficientNet backbone from ``cfg.MODEL.BACKBONE``.

    ``DEPTH`` (``'b0'`` through ``'b5'``) selects the YAML file merged into
    ``effnet_cfg`` before the model is built. When ``PRETRAIN`` is true,
    weights come from the ``"model_state"`` entry of the checkpoint at
    ``PRETRAIN_PATH`` if it is non-empty, otherwise from
    ``init_pretrained_weights(depth)``. Loading is non-strict and
    missing/unexpected keys are logged.

    Returns:
        EffNet: an :class:`EffNet` instance.
    """
    backbone_cfg = cfg.MODEL.BACKBONE
    # fmt: off
    pretrain      = backbone_cfg.PRETRAIN
    pretrain_path = backbone_cfg.PRETRAIN_PATH
    last_stride   = backbone_cfg.LAST_STRIDE
    bn_norm       = backbone_cfg.NORM
    depth         = backbone_cfg.DEPTH
    # fmt: on

    yaml_by_depth = {
        'b0': 'fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml',
        'b1': 'fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml',
        'b2': 'fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml',
        'b3': 'fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml',
        'b4': 'fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml',
        'b5': 'fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml',
    }
    yaml_file = yaml_by_depth[depth]

    effnet_cfg.merge_from_file(yaml_file)
    model = EffNet(last_stride, bn_norm)

    if not pretrain:
        return model

    if not pretrain_path:
        state_dict = init_pretrained_weights(depth)
    else:
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))["model_state"]
            logger.info(f"Loading pretrained model from {pretrain_path}")
        except FileNotFoundError as e:
            logger.info(f'{pretrain_path} is not found! Please check this path.')
            raise e
        except KeyError as e:
            logger.info("State dict keys error! Please check the state dict.")
            raise e

    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if missing:
        logger.info(get_missing_parameters_message(missing))
    if unexpected:
        logger.info(get_unexpected_parameters_message(unexpected))
    return model
コード例 #11
0
def build_mobilenetv2_backbone(cfg):
    """
    Build a MobileNetV2 backbone from ``cfg.MODEL.BACKBONE``.

    ``DEPTH`` is a width tag (``"1.0x"``, ``"0.75x"``, ``"0.5x"``, ``"0.35x"``,
    ``"0.25x"`` or ``"0.1x"``) mapped to the network's width multiplier.
    When ``PRETRAIN`` is true, the checkpoint at ``PRETRAIN_PATH`` is loaded
    onto the CPU and applied with ``strict=False``; missing and unexpected
    keys are logged.

    Returns:
        MobileNetV2: a :class:`MobileNetV2` instance.
    """
    backbone_cfg = cfg.MODEL.BACKBONE
    # fmt: off
    pretrain      = backbone_cfg.PRETRAIN
    pretrain_path = backbone_cfg.PRETRAIN_PATH
    bn_norm       = backbone_cfg.NORM
    depth         = backbone_cfg.DEPTH
    # fmt: on

    multipliers = {
        "1.0x": 1.0,
        "0.75x": 0.75,
        "0.5x": 0.5,
        "0.35x": 0.35,
        "0.25x": 0.25,
        "0.1x": 0.1,
    }
    width_mult = multipliers[depth]
    model = MobileNetV2(bn_norm, width_mult)

    if not pretrain:
        return model

    try:
        state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
        logger.info(f"Loading pretrained model from {pretrain_path}")
    except FileNotFoundError as e:
        logger.info(f'{pretrain_path} is not found! Please check this path.')
        raise e
    except KeyError as e:
        logger.info("State dict keys error! Please check the state dict.")
        raise e

    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if missing:
        logger.info(get_missing_parameters_message(missing))
    if unexpected:
        logger.info(get_unexpected_parameters_message(unexpected))
    return model
コード例 #12
0
ファイル: regnet.py プロジェクト: wangqi12332155/fast-reid
def build_regnet_backbone(cfg):
    """
    Create a RegNet instance from config.

    ``DEPTH`` selects one of the bundled RegNetX/RegNetY YAML files, which is
    merged into ``regnet_cfg`` before the model is built. When ``PRETRAIN``
    is true, weights come from ``init_pretrained_weights(depth)`` and are
    loaded non-strictly; missing/unexpected keys are logged.

    Returns:
        RegNet: a :class:`RegNet` instance.
    Raises:
        KeyError: if ``DEPTH`` has no config file entry.
    """
    # fmt: off
    pretrain    = cfg.MODEL.BACKBONE.PRETRAIN
    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm     = cfg.MODEL.BACKBONE.NORM
    depth       = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    # Every X variant is resolved under regnetx/ and every Y variant under
    # regnety/ (the 4000x entry had pointed into regnety/).
    cfg_files = {
        '800x':
        'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
        '800y':
        'fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml',
        '1600x':
        'fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml',
        '1600y':
        'fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml',
        '3200x':
        'fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml',
        '3200y':
        'fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml',
        '4000x':
        'fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml',
        '4000y':
        'fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml',
        '6400x':
        'fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml',
        '6400y':
        'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
    }[depth]

    regnet_cfg.merge_from_file(cfg_files)
    model = RegNet(last_stride, bn_norm)

    if pretrain:
        key = depth
        state_dict = init_pretrained_weights(key)

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(
                    incompatible.unexpected_keys))
    return model
コード例 #13
0
def build_regnet_backbone(cfg):
    """
    Create a RegNet instance from config.

    ``VOLUME`` selects the RegNetX/RegNetY YAML file merged into
    ``regnet_cfg`` before the model is built. When ``PRETRAIN`` is true, the
    ``'model_state'`` entry of the checkpoint at ``PRETRAIN_PATH`` is loaded
    non-strictly and missing/unexpected keys are logged.

    Returns:
        RegNet: a :class:`RegNet` instance.
    Raises:
        FileNotFoundError: if ``PRETRAIN_PATH`` cannot be found (logged first).
        KeyError: if ``VOLUME`` is unknown, or if the checkpoint has no
            ``'model_state'`` entry (logged first).
    """
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    volume        = cfg.MODEL.BACKBONE.VOLUME
    # fmt: on

    cfg_files = {
        '800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
        '800y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml',
        '1600x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml',
        '1600y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml',
        '3200x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml',
        '3200y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml',
        '6400x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml',
        '6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
    }[volume]

    regnet_cfg.merge_from_file(cfg_files)
    model = RegNet(last_stride, bn_norm)

    if pretrain:
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model_state']
        except FileNotFoundError as e:
            logger.info(f'{pretrain_path} is not found! Please check this path.')
            raise e
        except KeyError as e:
            # Same diagnostic as the sibling backbone loaders when the
            # checkpoint lacks the expected 'model_state' entry.
            logger.info("State dict keys error! Please check the state dict.")
            raise e

        logger.info(f"Loading pretrained model from {pretrain_path}")

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    return model
コード例 #14
0
def build_resnest_backbone(cfg):
    """
    Build a ResNeSt backbone from ``cfg.MODEL.BACKBONE``.

    ``DEPTH`` (``"50x"``, ``"101x"``, ``"200x"`` or ``"269x"``) selects the
    per-stage block counts, the non-local layer layout and the deep-stem
    width. When ``PRETRAIN`` is true, weights are downloaded from
    ``model_urls['resnest<N>']`` onto the CPU and loaded non-strictly;
    missing/unexpected keys are logged.

    Returns:
        ResNest: a :class:`ResNest` instance.
    """
    bb = cfg.MODEL.BACKBONE
    # fmt: off
    pretrain    = bb.PRETRAIN
    last_stride = bb.LAST_STRIDE
    bn_norm     = bb.NORM
    num_splits  = bb.NORM_SPLIT
    with_ibn    = bb.WITH_IBN
    with_se     = bb.WITH_SE
    with_nl     = bb.WITH_NL
    depth       = bb.DEPTH
    # fmt: on

    # depth -> (blocks per stage, non-local layers per stage, stem width)
    layout = {
        "50x": ([3, 4, 6, 3], [0, 2, 3, 0], 32),
        "101x": ([3, 4, 23, 3], [0, 2, 3, 0], 64),
        "200x": ([3, 24, 36, 3], [0, 2, 3, 0], 64),
        "269x": ([3, 30, 48, 8], [0, 2, 3, 0], 64),
    }
    num_blocks_per_stage, nl_layers_per_stage, stem_width = layout[depth]

    model = ResNest(
        last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck,
        num_blocks_per_stage, nl_layers_per_stage,
        radix=2, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=stem_width, avg_down=True,
        avd=True, avd_first=False,
    )
    if not pretrain:
        return model

    # The URL key is "resnest" plus the depth tag without its trailing "x".
    url = model_urls['resnest' + depth[:-1]]
    state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True, check_hash=True, map_location=torch.device('cpu'))

    incompatible = model.load_state_dict(state_dict, strict=False)
    logger = logging.getLogger(__name__)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if missing:
        logger.info(get_missing_parameters_message(missing))
    if unexpected:
        logger.info(get_unexpected_parameters_message(unexpected))
    return model
コード例 #15
0
def build_resnet_backbone(cfg):
    """
    Create a ResNet instance from config.

    ``DEPTH`` (34, 50, 101 or 152) selects the per-stage block counts, the
    non-local layer layout and the residual block type. When ``PRETRAIN`` is
    true:

    * without IBN, the ``'model'`` entry of the checkpoint at
      ``PRETRAIN_PATH`` is tried first, with the first two name components
      (e.g. ``module.encoder.``) removed and only name/shape-matching tensors
      kept. If the file is missing or has no ``'model'`` entry, weights from
      ``model_urls[depth]`` are loaded via ``model_zoo`` instead.
    * with IBN, the ``'state_dict'`` entry of the checkpoint at
      ``PRETRAIN_PATH`` is used, with the first name component removed and
      only name/shape-matching tensors kept.

    Loading is non-strict; missing/unexpected keys are logged.

    Returns:
        ResNet: a :class:`ResNet` instance.
    Raises:
        KeyError: if ``DEPTH`` is not one of 34, 50, 101, 152.
    """

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    num_splits    = cfg.MODEL.BACKBONE.NORM_SPLIT
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    with_se       = cfg.MODEL.BACKBONE.WITH_SE
    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = {
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]
    # Every depth accepted above needs entries in the two tables below too;
    # 152 shares the 101-layer non-local layout.
    nl_layers_per_stage = {
        34: [0, 2, 3, 0],
        50: [0, 2, 3, 0],
        101: [0, 2, 9, 0],
        152: [0, 2, 9, 0],
    }[depth]
    block = {34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck}[depth]
    model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se,
                   with_nl, block, num_blocks_per_stage, nl_layers_per_stage)
    if pretrain:
        # Built once and reused for every name/shape check below.
        model_state = model.state_dict()
        if not with_ibn:
            try:
                checkpoint = torch.load(
                    pretrain_path, map_location=torch.device('cpu'))['model']
                # Remove module.encoder in name
                state_dict = {}
                for k, v in checkpoint.items():
                    new_k = '.'.join(k.split('.')[2:])
                    if new_k in model_state and model_state[new_k].shape == v.shape:
                        state_dict[new_k] = v
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except (FileNotFoundError, KeyError):
                # A tuple is required here: ``except A or B`` evaluates to
                # ``except A`` and would let a missing 'model' key escape.
                # original resnet
                state_dict = model_zoo.load_url(model_urls[depth])
                logger.info("Loading pretrained model from torchvision")
        else:
            checkpoint = torch.load(
                pretrain_path,
                map_location=torch.device('cpu'))['state_dict']  # ibn-net
            # Remove module in name
            state_dict = {}
            for k, v in checkpoint.items():
                new_k = '.'.join(k.split('.')[1:])
                if new_k in model_state and model_state[new_k].shape == v.shape:
                    state_dict[new_k] = v
            logger.info(f"Loading pretrained model from {pretrain_path}")
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(
                    incompatible.unexpected_keys))
    return model
コード例 #16
0
ファイル: resnest.py プロジェクト: hzphzp/MetaBIN
def build_resnest_backbone(cfg):
    """
    Create a ResNeSt instance from config.

    ``DEPTH`` (50, 101, 200 or 269) selects the per-stage block counts, the
    non-local layer layout and the deep-stem width. When ``PRETRAIN`` is
    true, weights are downloaded from ``model_urls['resnest<depth>']`` onto
    the CPU (regardless of ``WITH_IBN``) and loaded non-strictly;
    missing/unexpected keys are logged.

    Returns:
        ResNest: a :class:`ResNest` instance.
    Raises:
        KeyError: if ``DEPTH`` is not one of 50, 101, 200, 269.
    """

    # fmt: off
    pretrain    = cfg.MODEL.BACKBONE.PRETRAIN
    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm     = cfg.MODEL.BACKBONE.NORM
    num_splits  = cfg.MODEL.BACKBONE.NORM_SPLIT
    with_ibn    = cfg.MODEL.BACKBONE.WITH_IBN
    with_nl     = cfg.MODEL.BACKBONE.WITH_NL
    depth       = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        200: [3, 24, 36, 3],
        269: [3, 30, 48, 8]
    }[depth]
    # Every depth accepted above needs a non-local layout; all share
    # [0, 2, 3, 0], so 200 and 269 no longer fail with KeyError.
    nl_layers_per_stage = {
        50: [0, 2, 3, 0],
        101: [0, 2, 3, 0],
        200: [0, 2, 3, 0],
        269: [0, 2, 3, 0],
    }[depth]
    stem_width = {50: 32, 101: 64, 200: 64, 269: 64}[depth]
    model = ResNest(last_stride,
                    bn_norm,
                    num_splits,
                    with_ibn,
                    with_nl,
                    Bottleneck,
                    num_blocks_per_stage,
                    nl_layers_per_stage,
                    radix=2,
                    groups=1,
                    bottleneck_width=64,
                    deep_stem=True,
                    stem_width=stem_width,
                    avg_down=True,
                    avd=True,
                    avd_first=False)
    if pretrain:
        # Load onto the CPU so the download works on machines without a GPU.
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls['resnest' + str(depth)],
            progress=True,
            check_hash=True,
            map_location=torch.device('cpu'))
        incompatible = model.load_state_dict(state_dict, strict=False)
        logger = logging.getLogger(__name__)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(
                    incompatible.unexpected_keys))
    return model
コード例 #17
0
def build_resnet_backbone(cfg):
    """
    Create a ResNet instance from config.

    If ``cfg.MODEL.BACKBONE.PRETRAIN`` is set, pretrained weights are loaded
    non-strictly. Their keys are first remapped to this ResNet's naming
    (``downsample.0`` / ``downsample.1`` -> ``downsample.conv`` /
    ``downsample.bn``), then renamed or filtered according to
    ``cfg.MODEL.NORM`` so they fit the normalization layers selected by
    ``cfg.MODEL.NORM.TYPE_BACKBONE``. Missing and unexpected keys are logged.

    Args:
        cfg: config node; reads ``cfg.MODEL.BACKBONE.*`` and ``cfg.MODEL.NORM.*``.

    Returns:
        ResNet: a :class:`ResNet` instance.

    Raises:
        KeyError: if ``cfg.MODEL.BACKBONE.DEPTH`` is not 18, 34, 50, 101 or 152.
    """

    def _insert_after_bn(name, tag):
        # Insert `tag` right after the first dotted component containing 'bn',
        # e.g. 'layer1.0.bn1.weight' -> 'layer1.0.bn1.bat_n.weight'.
        parts = name.split('.')
        for i, part in enumerate(parts):
            if 'bn' in part:
                parts.insert(i + 1, tag)
                break
        return '.'.join(parts)

    def _is_norm_key(name):
        # Name heuristic used below to recognise normalization-layer entries.
        return ('bn' in name) or ('norm' in name)

    def _strip_and_filter(src, model_state, num_prefix):
        # Drop the first `num_prefix` dotted components of every key and keep
        # only tensors whose stripped name and shape both match the model.
        out = {}
        for k, v in src.items():
            new_k = '.'.join(k.split('.')[num_prefix:])
            if new_k in model_state and model_state[new_k].shape == v.shape:
                out[new_k] = v
        return out

    # fmt: off
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm = cfg.MODEL.NORM.TYPE_BACKBONE
    norm_opt = dict()
    norm_opt['BN_AFFINE'] = cfg.MODEL.NORM.BN_AFFINE
    norm_opt['BN_RUNNING'] = cfg.MODEL.NORM.BN_RUNNING
    norm_opt['IN_AFFINE'] = cfg.MODEL.NORM.IN_AFFINE
    norm_opt['IN_RUNNING'] = cfg.MODEL.NORM.IN_RUNNING

    norm_opt['BN_W_FREEZE'] = cfg.MODEL.NORM.BN_W_FREEZE
    norm_opt['BN_B_FREEZE'] = cfg.MODEL.NORM.BN_B_FREEZE
    norm_opt['IN_W_FREEZE'] = cfg.MODEL.NORM.IN_W_FREEZE
    norm_opt['IN_B_FREEZE'] = cfg.MODEL.NORM.IN_B_FREEZE

    norm_opt['BIN_INIT'] = cfg.MODEL.NORM.BIN_INIT
    norm_opt['IN_FC_MULTIPLY'] = cfg.MODEL.NORM.IN_FC_MULTIPLY
    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
    with_se = cfg.MODEL.BACKBONE.WITH_SE
    with_nl = cfg.MODEL.BACKBONE.WITH_NL
    depth = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]
    nl_layers_per_stage = {
        18: [0, 0, 0, 0],
        34: [0, 2, 3, 0],
        50: [0, 2, 3, 0],
        101: [0, 2, 9, 0],
        152: [0, 2, 9, 0],
    }[depth]
    block = {
        18: BasicBlock,
        34: BasicBlock,
        50: Bottleneck,
        101: Bottleneck,
        152: Bottleneck,
    }[depth]
    model = ResNet(last_stride, bn_norm, norm_opt, num_splits, with_ibn, with_se, with_nl, block,
                   num_blocks_per_stage, nl_layers_per_stage)

    if pretrain:
        if not with_ibn:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
                # Remove the 'module.encoder.' prefix from every key.
                state_dict = _strip_and_filter(state_dict, model.state_dict(), 2)
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except (FileNotFoundError, KeyError):
                # A tuple is required here: `except A or B` only catches A.
                # Fall back to the original torchvision ResNet weights.
                state_dict = model_zoo.load_url(model_urls[depth])
                logger.info("Loading pretrained model from torchvision")
        else:
            # IBN-Net checkpoint: remove the 'module.' prefix from every key.
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['state_dict']
            state_dict = _strip_and_filter(state_dict, model.state_dict(), 1)
            logger.info(f"Loading pretrained model from {pretrain_path}")

        # Rename shortcut entries: 'downsample.0.*' -> 'downsample.conv.*' and
        # 'downsample.1.*' -> 'downsample.bn.*' (e.g. layer1.0.downsample.0.weight).
        # Keys already in the target form are left alone; re-assigning and then
        # deleting the same key would otherwise drop them from the state dict.
        # NOTE(review): both local-checkpoint branches filter against the model
        # before this runs, so 'downsample.0/1' keys from those checkpoints are
        # already gone here -- confirm that is intended.
        shortcut_names = {'0': 'conv', '1': 'bn'}
        for name in list(state_dict):
            if 'downsample' not in name:
                continue
            parts = name.split('.')
            parts[-2] = shortcut_names.get(parts[-2], parts[-2])
            new_name = '.'.join(parts)
            if new_name != name:
                state_dict[new_name] = copy.copy(state_dict[name])
                del state_dict[name]

        if cfg.MODEL.NORM.TYPE_BACKBONE == 'BIN_gate2':
            # Each 'bn' entry is duplicated into the batch-norm ('bat_n') and/or
            # instance-norm ('ins_n') sub-module per the LOAD_* flags, and the
            # original key is removed.
            for name, values in list(state_dict.items()):
                if 'bn' not in name:
                    continue
                if ('weight' in name) or ('bias' in name):
                    load_bn = cfg.MODEL.NORM.LOAD_BN_AFFINE
                    load_in = cfg.MODEL.NORM.LOAD_IN_AFFINE
                elif ('running_mean' in name) or ('running_var' in name):
                    load_bn = cfg.MODEL.NORM.LOAD_BN_RUNNING
                    load_in = cfg.MODEL.NORM.LOAD_IN_RUNNING
                else:
                    continue
                if load_bn:
                    state_dict[_insert_after_bn(name, 'bat_n')] = values
                if load_in:
                    state_dict[_insert_after_bn(name, 'ins_n')] = values
                del state_dict[name]
        else:
            # Drop norm affine params and/or running statistics that should not
            # be loaded for this configuration.
            drop_affine = not cfg.MODEL.NORM.LOAD_BN_AFFINE
            drop_running = (not cfg.MODEL.NORM.LOAD_BN_RUNNING) or (
                not cfg.MODEL.NORM.IN_RUNNING and cfg.MODEL.NORM.TYPE_BACKBONE == "IN")
            for name in list(state_dict):
                if not _is_norm_key(name):
                    continue
                is_affine = ('weight' in name) or ('bias' in name)
                is_running = ('running_mean' in name) or ('running_var' in name)
                if (drop_affine and is_affine) or (drop_running and is_running):
                    del state_dict[name]

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    return model
コード例 #18
0
ファイル: mobilenet_v2.py プロジェクト: hzphzp/MetaBIN
def build_mobilenet_v2_backbone(cfg):
    """
    Create a MobileNetV2 instance from config.

    ``cfg.MODEL.BACKBONE.DEPTH`` is ten times the width multiplier
    (e.g. 10 -> 1.0, 14 -> 1.4). When ``cfg.MODEL.BACKBONE.PRETRAIN`` is set
    and ``cfg.MODEL.BACKBONE.PRETRAIN_PATH`` is non-empty, the checkpoint is
    loaded non-strictly. Before loading, its keys are renamed to this model's
    naming and filtered according to ``cfg.MODEL.NORM``. Missing and unexpected
    keys are logged.

    Args:
        cfg: config node; reads ``cfg.MODEL.BACKBONE.*`` and ``cfg.MODEL.NORM.*``.

    Returns:
        MobileNetV2: a :class:`MobileNetV2` instance.
    """
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm = cfg.MODEL.NORM.TYPE_BACKBONE
    norm_opt = dict()
    norm_opt['BN_AFFINE'] = cfg.MODEL.NORM.BN_AFFINE
    norm_opt['BN_RUNNING'] = cfg.MODEL.NORM.BN_RUNNING
    norm_opt['IN_AFFINE'] = cfg.MODEL.NORM.IN_AFFINE
    norm_opt['IN_RUNNING'] = cfg.MODEL.NORM.IN_RUNNING

    norm_opt['BN_W_FREEZE'] = cfg.MODEL.NORM.BN_W_FREEZE
    norm_opt['BN_B_FREEZE'] = cfg.MODEL.NORM.BN_B_FREEZE
    norm_opt['IN_W_FREEZE'] = cfg.MODEL.NORM.IN_W_FREEZE
    norm_opt['IN_B_FREEZE'] = cfg.MODEL.NORM.IN_B_FREEZE

    norm_opt['BIN_INIT'] = cfg.MODEL.NORM.BIN_INIT
    norm_opt['IN_FC_MULTIPLY'] = cfg.MODEL.NORM.IN_FC_MULTIPLY
    depth = cfg.MODEL.BACKBONE.DEPTH / 10.0

    model = MobileNetV2(width_mult=depth,
                        bn_norm=bn_norm,
                        norm_opt=norm_opt,
                        last_stride=last_stride)

    # Truthiness, not `is not ""`: identity comparison with a str literal is
    # implementation-dependent (and a SyntaxWarning). An empty path means
    # there is nothing to load.
    if pretrain and pretrain_path:
        requires_dict = OrderedDict(
            (name, param.requires_grad) for name, param in model.named_parameters())
        pretrained_dict = torch.load(pretrain_path,
                                     map_location=torch.device('cpu'))

        # Rename a leading 'convN' component to 'layerN' (compatibility with
        # ResNet's layer names).
        state_dict_new = OrderedDict()
        for name, values in pretrained_dict.items():
            name_split = name.split('.')
            if 'conv' in name_split[0]:
                name_split[0] = name_split[0].replace('conv', 'layer')
                name = '.'.join(name_split)
            state_dict_new[name] = copy.copy(values)

        # 'conv3.0.*' -> 'conv3.*' and 'conv3.1.*' -> 'bn.*'.
        state_dict = OrderedDict()
        for name, values in state_dict_new.items():
            if 'conv3.0' in name:
                name = name.replace('conv3.0', 'conv3')
            elif 'conv3.1' in name:
                name = name.replace('conv3.1', 'bn')
            state_dict[name] = values

        if cfg.MODEL.NORM.TYPE_BACKBONE == 'BIN_gate2':
            # Each 'bn' entry is duplicated into the batch-norm ('bat_n') and/or
            # instance-norm ('ins_n') sub-module per the LOAD_* flags, and the
            # original key is removed. str.replace rewrites every 'bn' in the key.
            for name, values in list(state_dict.items()):
                if 'bn' not in name:
                    continue
                if ('weight' in name) or ('bias' in name):
                    load_bn = cfg.MODEL.NORM.LOAD_BN_AFFINE
                    load_in = cfg.MODEL.NORM.LOAD_IN_AFFINE
                elif ('running_mean' in name) or ('running_var' in name):
                    load_bn = cfg.MODEL.NORM.LOAD_BN_RUNNING
                    load_in = cfg.MODEL.NORM.LOAD_IN_RUNNING
                else:
                    continue
                if load_bn:
                    state_dict[name.replace('bn', 'bn.bat_n')] = values
                if load_in:
                    state_dict[name.replace('bn', 'bn.ins_n')] = values
                del state_dict[name]
        else:
            if not cfg.MODEL.NORM.LOAD_BN_AFFINE:
                for name in list(state_dict):
                    if ('bn' in name) or ('norm' in name):
                        if ('weight' in name) or ('bias' in name):
                            del state_dict[name]
                            logger.debug(f"Skipping pretrained norm affine param {name}")
            if not cfg.MODEL.NORM.LOAD_BN_RUNNING:
                for name in list(state_dict):
                    if ('bn' in name) or ('norm' in name):
                        if ('running_mean' in name) or ('running_var' in name):
                            del state_dict[name]
            if not cfg.MODEL.NORM.IN_RUNNING and cfg.MODEL.NORM.TYPE_BACKBONE == "IN":
                for name in list(state_dict):
                    if ('bn' in name) or ('norm' in name):
                        if ('running_mean' in name) or ('running_var' in name):
                            del state_dict[name]

        if not cfg.MODEL.BACKBONE.NUM_BATCH_TRACKED:
            for name in list(state_dict):
                if 'num_batches_tracked' in name:
                    del state_dict[name]

        # Copy each model parameter's requires_grad flag onto the matching
        # checkpoint tensor.
        # NOTE(review): load_state_dict copies values into the existing
        # parameters, so this presumably does not affect the model -- confirm.
        for name, requires_grad in requires_dict.items():
            if name in state_dict:
                state_dict[name].requires_grad = copy.copy(requires_grad)

        if cfg.MODEL.NORM.TYPE_BACKBONE == 'DualNorm':
            # Do not load any 'bn' entries inside layer1..layer6.
            for name in list(state_dict):
                if 'bn' in name and any(f'layer{i}' in name for i in range(1, 7)):
                    del state_dict[name]

        incompatible = model.load_state_dict(state_dict, strict=False)

        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(
                    incompatible.unexpected_keys))

    return model
コード例 #19
0
def build_vit_backbone(cfg):
    """
    Create a Vision Transformer instance from config.

    ``cfg.MODEL.BACKBONE.DEPTH`` selects a preset ('small' or 'base'). When
    ``cfg.MODEL.BACKBONE.PRETRAIN`` is set, the checkpoint at ``PRETRAIN_PATH``
    is loaded non-strictly. Before loading, old linear patch-embedding weights
    are reshaped to conv form and position embeddings are resized to fit this
    model. Missing and unexpected keys are logged.

    Args:
        cfg: config node; reads ``cfg.INPUT.SIZE_TRAIN`` and ``cfg.MODEL.BACKBONE.*``.

    Returns:
        VisionTransformer: a :class:`VisionTransformer` instance.

    Raises:
        KeyError: if the depth preset is unknown; also re-raised (after
            logging) from checkpoint processing.
        FileNotFoundError: re-raised (after logging) when the checkpoint path
            cannot be found.
    """
    # fmt: off
    input_size      = cfg.INPUT.SIZE_TRAIN
    pretrain        = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path   = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    depth           = cfg.MODEL.BACKBONE.DEPTH
    sie_xishu       = cfg.MODEL.BACKBONE.SIE_COE
    stride_size     = cfg.MODEL.BACKBONE.STRIDE_SIZE
    drop_ratio      = cfg.MODEL.BACKBONE.DROP_RATIO
    drop_path_ratio = cfg.MODEL.BACKBONE.DROP_PATH_RATIO
    attn_drop_rate  = cfg.MODEL.BACKBONE.ATT_DROP_RATE
    # fmt: on

    # preset -> (num blocks, num heads, mlp ratio, qkv bias, qk scale)
    num_depth, num_heads, mlp_ratio, qkv_bias, qk_scale = {
        'small': (8, 8, 3., False, 768 ** -0.5),
        'base': (12, 12, 4., True, None),
    }[depth]

    model = VisionTransformer(img_size=input_size, sie_xishu=sie_xishu, stride_size=stride_size, depth=num_depth,
                              num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              drop_path_rate=drop_path_ratio, drop_rate=drop_ratio, attn_drop_rate=attn_drop_rate)

    if pretrain:
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
            logger.info(f"Loading pretrained model from {pretrain_path}")

            # Unwrap checkpoints that nest the weights under 'model' / 'state_dict'.
            if 'model' in state_dict:
                state_dict = state_dict.pop('model')
            if 'state_dict' in state_dict:
                state_dict = state_dict.pop('state_dict')
            for k, v in state_dict.items():
                # NOTE(review): 'head'/'dist' entries are only left unconverted,
                # not removed; they still reach load_state_dict below.
                if 'head' in k or 'dist' in k:
                    continue
                if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
                    # Older checkpoints predate conv-based patchification:
                    # reshape the projection to the model's conv weight layout.
                    O, _, H, W = model.patch_embed.proj.weight.shape
                    v = v.reshape(O, -1, H, W)
                elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
                    # Resize the position embedding when the model size differs
                    # from the pretrained weights.
                    if 'distilled' in pretrain_path:
                        logger.info("distill need to choose right cls token in the pth.")
                        v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1)
                    v = resize_pos_embed(v, model.pos_embed.data, model.patch_embed.num_y, model.patch_embed.num_x)
                # Replacing the value of an existing key is safe while iterating.
                state_dict[k] = v
        except FileNotFoundError:
            logger.error(f'{pretrain_path} is not found! Please check this path.')
            raise
        except KeyError:
            logger.error("State dict keys error! Please check the state dict.")
            raise

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model