示例#1
0
def load_features(args):
    """Build a pretrained iResNet feature extractor.

    Args:
        args: Namespace-like object. Reads ``args.arch`` (one of
            ``'iresnet18'``, ``'iresnet34'``, ``'iresnet50'``,
            ``'iresnet100'``) and ``args.embedding_size``, which is forwarded
            as ``num_classes``.

    Returns:
        Whatever the matching ``iresnet.<arch>`` factory returns, built with
        ``pretrained=True``.

    Raises:
        ValueError: If ``args.arch`` is not one of the supported names.
    """
    supported = ('iresnet18', 'iresnet34', 'iresnet50', 'iresnet100')
    if args.arch not in supported:
        raise ValueError(
            'Unsupported arch {!r} for load_features; expected one of {}'
            .format(args.arch, list(supported)))
    # Every supported arch name is also the name of its factory in `iresnet`,
    # so the validated name can be used directly for the lookup.
    factory = getattr(iresnet, args.arch)
    return factory(pretrained=True, num_classes=args.embedding_size)
示例#2
0
# Recipe table for build_model():
#   arch -> (module getter, factory-name prefix, supported depths, takes groups)
# The factory called is ``<module>.<prefix><depth>``. Module getters are
# lambdas so a module name is only resolved when its arch is requested.
_BUILD_MODEL_SPECS = {
    'iresgroup': (lambda: iresgroup, 'iresgroup', (50, 101, 152), True),
    'iresgroupfix': (lambda: iresgroupfix, 'iresgroupfix',
                     (50, 101, 152), True),
    'resgroupfix': (lambda: resgroupfix, 'resgroupfix',
                    (50, 101, 152), True),
    'resgroup': (lambda: resgroup, 'resgroup', (50, 101, 152), True),
    'iresnet': (lambda: iresnet, 'iresnet',
                (18, 34, 50, 101, 152, 200, 302, 404, 1001), False),
    'seiresnet': (lambda: seiresnet, 'seiresnet',
                  (18, 34, 50, 101, 152, 200, 302, 404, 1001), False),
    # NOTE(review): factories are looked up as seiresgroup.iresgroup<depth>,
    # not seiresgroup.seiresgroup<depth>; confirm that is what the
    # seiresgroup module actually exports.
    'seiresgroup': (lambda: seiresgroup, 'iresgroup', (50, 101, 152), True),
    'resstage': (lambda: resstage, 'resstage',
                 (18, 34, 50, 101, 152, 200), False),
    'resnet': (lambda: resnet, 'resnet',
               (18, 34, 50, 101, 152, 200), False),
}


def build_model(args):
    """Instantiate a model selected by ``args.arch`` and ``args.model_depth``.

    Args:
        args: Namespace-like object. Reads ``arch`` (a key of
            ``_BUILD_MODEL_SPECS``), ``model_depth``, ``pretrained``,
            ``n_classes`` and ``zero_init_residual``. ``groups`` is read only
            for the grouped architectures (``iresgroup``, ``iresgroupfix``,
            ``resgroupfix``, ``resgroup``, ``seiresgroup``).

    Returns:
        The object returned by ``<module>.<prefix><model_depth>(...)`` called
        with ``pretrained``, ``num_classes`` (from ``n_classes``),
        ``zero_init_residual`` and, for grouped architectures, ``groups``.

    Raises:
        ValueError: If ``arch`` is unknown, or ``model_depth`` is not
            supported for that ``arch``. Explicit exceptions are used instead
            of ``assert`` so validation survives ``python -O``.
    """
    try:
        get_module, prefix, depths, takes_groups = _BUILD_MODEL_SPECS[args.arch]
    except KeyError:
        raise ValueError(
            'Unsupported arch {!r}; expected one of {}'.format(
                args.arch, sorted(_BUILD_MODEL_SPECS))) from None

    # Take the canonical depth from the table so the factory name is built
    # from a plain int even if model_depth merely compares equal to it.
    depth = next((d for d in depths if d == args.model_depth), None)
    if depth is None:
        raise ValueError(
            'Unsupported model_depth {!r} for arch {!r}; expected one of {}'
            .format(args.model_depth, args.arch, list(depths)))

    kwargs = dict(pretrained=args.pretrained,
                  num_classes=args.n_classes,
                  zero_init_residual=args.zero_init_residual)
    if takes_groups:
        kwargs['groups'] = args.groups

    factory = getattr(get_module(), '{}{}'.format(prefix, depth))
    return factory(**kwargs)
# Depth-suffixed backbones supported by build_backbone():
#   backbone -> (module getter, supported layer counts)
# The factory called is ``<module>.<backbone><layers>``. Module getters are
# lambdas so a module name is only resolved when that backbone is requested.
_BACKBONE_SPECS = {
    'resnet': (lambda: resnet, (50, 101, 152, 200)),
    'resgroup': (lambda: resgroup, (50, 101, 152)),
    'iresnet': (lambda: iresnet, (50, 101, 152, 200, 302, 404, 1001)),
    'iresgroup': (lambda: iresgroup, (50, 101, 152)),
}


def build_backbone(backbone='resnet-50',
                   layers=50,
                   output_stride=16,
                   norm_layer=None):
    """Build a backbone network by family name and depth.

    Args:
        backbone: One of ``'resnet'``, ``'resgroup'``, ``'iresnet'``,
            ``'iresgroup'`` or ``'xception'``. Compared by value (``==``).
            NOTE(review): the default ``'resnet-50'`` is not a supported name,
            so calling with defaults raises ``ValueError``. Confirm the
            intended default.
        layers: Depth suffix selecting the factory, e.g. ``50`` ->
            ``resnet.resnet50``. Ignored for ``'xception'``.
        output_stride: Forwarded to ``xception.xception`` only.
        norm_layer: Forwarded unchanged (``None`` included) to the factory.

    Returns:
        The object returned by the selected factory.

    Raises:
        ValueError: If ``backbone`` is unknown or ``layers`` is not supported
            for it.
    """
    if backbone == 'xception':
        return xception.xception(output_stride=output_stride,
                                 norm_layer=norm_layer)

    spec = _BACKBONE_SPECS.get(backbone)
    if spec is None:
        raise ValueError(
            'Unsupported backbone {!r}; expected one of {}'.format(
                backbone, sorted(_BACKBONE_SPECS) + ['xception']))
    get_module, supported_layers = spec

    # Take the canonical depth from the table so the factory name is built
    # from a plain int even if `layers` merely compares equal to it.
    depth = next((n for n in supported_layers if n == layers), None)
    if depth is None:
        raise ValueError(
            'Unsupported layers {!r} for backbone {!r}; expected one of {}'
            .format(layers, backbone, list(supported_layers)))

    factory = getattr(get_module(), '{}{}'.format(backbone, depth))
    return factory(norm_layer=norm_layer)