# Feature-extractor architectures accepted by ``load_features``; each name is
# also the constructor's attribute name in the ``iresnet`` module.
# NOTE(review): this list uses iresnet100 while build_model uses iresnet101;
# kept as in the original — confirm both exist in the iresnet module.
_FEATURE_ARCHS = ('iresnet18', 'iresnet34', 'iresnet50', 'iresnet100')


def load_features(args):
    """Build a pretrained iResNet feature extractor selected by ``args.arch``.

    Args:
        args: Namespace-like object providing ``arch`` (one of
            ``_FEATURE_ARCHS``) and ``embedding_size``, which is forwarded to
            the constructor as ``num_classes``.

    Returns:
        The result of ``iresnet.<arch>(pretrained=True,
        num_classes=args.embedding_size)``.

    Raises:
        ValueError: if ``args.arch`` is not a supported feature architecture.
    """
    # Validate before touching the iresnet module so a bad arch gets a clear
    # error instead of an AttributeError from getattr.
    if args.arch not in _FEATURE_ARCHS:
        raise ValueError('Unsupported feature arch %r; expected one of %s'
                         % (args.arch, list(_FEATURE_ARCHS)))
    factory = getattr(iresnet, args.arch)
    return factory(pretrained=True, num_classes=args.embedding_size)
_GROUP_DEPTHS = (50, 101, 152)
_IRES_DEPTHS = (18, 34, 50, 101, 152, 200, 302, 404, 1001)
_RES_DEPTHS = (18, 34, 50, 101, 152, 200)

# arch -> (zero-arg module getter, constructor-name prefix, supported depths,
#          whether the constructor takes ``groups``).
# The getters defer the module-name lookup until that arch is requested, so
# only the selected arch's module has to be importable.
_MODEL_SPECS = {
    'iresgroup': (lambda: iresgroup, 'iresgroup', _GROUP_DEPTHS, True),
    'iresgroupfix': (lambda: iresgroupfix, 'iresgroupfix', _GROUP_DEPTHS,
                     True),
    'resgroupfix': (lambda: resgroupfix, 'resgroupfix', _GROUP_DEPTHS, True),
    'resgroup': (lambda: resgroup, 'resgroup', _GROUP_DEPTHS, True),
    'iresnet': (lambda: iresnet, 'iresnet', _IRES_DEPTHS, False),
    'seiresnet': (lambda: seiresnet, 'seiresnet', _IRES_DEPTHS, False),
    # NOTE(review): the original calls seiresgroup.iresgroupNN (not
    # seiresgroupNN); preserved as-is — confirm the seiresgroup module's names.
    'seiresgroup': (lambda: seiresgroup, 'iresgroup', _GROUP_DEPTHS, True),
    'resstage': (lambda: resstage, 'resstage', _RES_DEPTHS, False),
    'resnet': (lambda: resnet, 'resnet', _RES_DEPTHS, False),
}


def build_model(args):
    """Construct the model selected by ``args.arch`` and ``args.model_depth``.

    The constructor called is ``<module>.<prefix><depth>`` from
    ``_MODEL_SPECS``. For example, ``arch='resnet', model_depth=50`` calls
    ``resnet.resnet50``.

    Args:
        args: Namespace-like object providing ``arch``, ``model_depth``,
            ``pretrained``, ``n_classes`` and ``zero_init_residual``. The
            grouped architectures (``*group*``) also read ``groups``.

    Returns:
        Whatever the selected constructor returns.

    Raises:
        ValueError: if ``args.arch`` is unknown or ``args.model_depth`` is not
            supported for that arch. Previously an unknown arch raised an
            opaque UnboundLocalError and a bad depth raised AssertionError
            (or nothing at all under ``python -O``).
    """
    spec = _MODEL_SPECS.get(args.arch)
    if spec is None:
        raise ValueError('Unsupported arch %r; expected one of %s'
                         % (args.arch, sorted(_MODEL_SPECS)))
    get_module, prefix, depths, takes_groups = spec

    # Check the depth before resolving the module, so the error is clear.
    if args.model_depth not in depths:
        raise ValueError('Unsupported model_depth %r for arch %r; expected '
                         'one of %s' % (args.model_depth, args.arch,
                                        list(depths)))

    kwargs = dict(pretrained=args.pretrained,
                  num_classes=args.n_classes,
                  zero_init_residual=args.zero_init_residual)
    if takes_groups:
        kwargs['groups'] = args.groups

    # %d renders integral numbers (e.g. 50 or 50.0) as "50".
    factory = getattr(get_module(), '%s%d' % (prefix, args.model_depth))
    return factory(**kwargs)
# backbone name -> (zero-arg module getter, supported layer counts).
# The constructor name is "<backbone><layers>", e.g. resnet.resnet50. The
# getters defer the module-name lookup until that backbone is requested.
_BACKBONE_SPECS = {
    'resnet': (lambda: resnet, (50, 101, 152, 200)),
    'resgroup': (lambda: resgroup, (50, 101, 152)),
    'iresnet': (lambda: iresnet, (50, 101, 152, 200, 302, 404, 1001)),
    'iresgroup': (lambda: iresgroup, (50, 101, 152)),
}


def build_backbone(backbone='resnet-50', layers=50, output_stride=16,
                   norm_layer=None):
    """Construct a backbone network by family name and depth.

    Args:
        backbone: One of ``'resnet'``, ``'resgroup'``, ``'iresnet'``,
            ``'iresgroup'`` or ``'xception'``. Matched with ``==``, so names
            built at runtime (argparse, config files) work.
            NOTE(review): the default ``'resnet-50'`` is not a recognised
            name, so calling with the default raises ValueError. Previously
            it silently returned None. Confirm the intended default.
        layers: Depth variant for the selected family. Ignored for
            ``'xception'``.
        output_stride: Forwarded only to ``xception.xception``.
        norm_layer: Forwarded unchanged to the constructor. ``None`` leaves
            the choice to the constructor.

    Returns:
        Whatever the selected constructor returns.

    Raises:
        ValueError: if the backbone name is unknown or ``layers`` is not
            supported for it. Previously these cases returned None.
    """
    if backbone == 'xception':
        # Single variant; ``layers`` does not select anything here.
        return xception.xception(output_stride=output_stride,
                                 norm_layer=norm_layer)

    spec = _BACKBONE_SPECS.get(backbone)
    if spec is None:
        raise ValueError('Unsupported backbone %r; expected one of %s'
                         % (backbone,
                            sorted(list(_BACKBONE_SPECS) + ['xception'])))
    get_module, supported = spec

    if layers not in supported:
        raise ValueError('Unsupported layers %r for backbone %r; expected '
                         'one of %s' % (layers, backbone, list(supported)))

    factory = getattr(get_module(), '%s%d' % (backbone, layers))
    return factory(norm_layer=norm_layer)