class UnknownModelError(ValueError, AssertionError):
    """Raised by ``create_model`` when ``model_name`` is not a known architecture.

    It is a ``ValueError`` because an unrecognised name is a bad argument
    value, and also an ``AssertionError`` so callers written against the
    previous ``assert False`` behaviour still catch it.
    """


def create_model(model_name, num_classes=1000, pretrained=False, **kwargs):
    """Build a classification model by architecture name.

    Args:
        model_name: One of ``dpn68``, ``dpn68b``, ``dpn92``, ``dpn98``,
            ``dpn131``, ``dpn107``, ``resnet18``, ``resnet34``, ``resnet50``,
            ``resnet101``, ``resnet152``, ``densenet121``, ``densenet161``,
            ``densenet169``, ``densenet201`` or ``inception_v3``.
        num_classes: Forwarded to the selected constructor as ``num_classes``.
        pretrained: Forwarded to the selected constructor as ``pretrained``.
        **kwargs: Extra constructor options.
            ``test_time_pool`` (default ``True``) and ``extra`` (default
            ``True``) are always removed from ``kwargs``. ``test_time_pool``
            is passed to every DPN variant; ``extra`` only to ``dpn92``.
            All remaining keys are forwarded to the ResNet, DenseNet and
            Inception constructors and are not passed to DPN variants.
            For ``inception_v3``, ``transform_input`` defaults to ``False``
            but may be overridden through ``kwargs``.

    Returns:
        Whatever the selected constructor returns.

    Raises:
        UnknownModelError: If ``model_name`` is not one of the names above.
    """
    # DPN-only options: pop them so they are never forwarded to the
    # ResNet/DenseNet/Inception constructors.
    test_time_pool = kwargs.pop('test_time_pool', True)
    extra = kwargs.pop('extra', True)
    common = {'num_classes': num_classes, 'pretrained': pretrained}

    # Lambdas defer the global-name lookup until a builder is chosen, so only
    # the selected constructor needs to be resolvable (as with an elif chain).
    builders = {
        'dpn68': lambda: dpn68(test_time_pool=test_time_pool, **common),
        'dpn68b': lambda: dpn68b(test_time_pool=test_time_pool, **common),
        'dpn92': lambda: dpn92(
            test_time_pool=test_time_pool, extra=extra, **common),
        'dpn98': lambda: dpn98(test_time_pool=test_time_pool, **common),
        'dpn131': lambda: dpn131(test_time_pool=test_time_pool, **common),
        'dpn107': lambda: dpn107(test_time_pool=test_time_pool, **common),
        'resnet18': lambda: resnet18(**common, **kwargs),
        'resnet34': lambda: resnet34(**common, **kwargs),
        'resnet50': lambda: resnet50(**common, **kwargs),
        'resnet101': lambda: resnet101(**common, **kwargs),
        'resnet152': lambda: resnet152(**common, **kwargs),
        'densenet121': lambda: densenet121(**common, **kwargs),
        'densenet161': lambda: densenet161(**common, **kwargs),
        'densenet169': lambda: densenet169(**common, **kwargs),
        'densenet201': lambda: densenet201(**common, **kwargs),
        # transform_input=False is a default the caller may override; the
        # old literal keyword next to **kwargs raised TypeError on override.
        'inception_v3': lambda: inception_v3(
            **common, **{'transform_input': False, **kwargs}),
    }

    try:
        build = builders[model_name]
    except (KeyError, TypeError):  # TypeError: unhashable model_name
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise UnknownModelError(
            "Unknown model architecture (%s)" % (model_name,)) from None
    return build()
from dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from sotabench.image_classification import ImageNet
import torchvision.transforms as transforms
import PIL
# ``import PIL`` alone does not guarantee the ``PIL.Image`` submodule is
# loaded; import it explicitly instead of relying on torchvision having
# imported it as a side effect.
import PIL.Image

# Benchmark the pretrained DPN-131 on ImageNet at 224x224 input resolution.

# NOTE(review): these are the usual torchvision ImageNet mean/std; confirm
# they match the normalization the DPN weights were trained/converted with.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Evaluation preprocessing: bicubic resize to 256, 224 centre crop,
# tensor conversion, then per-channel normalization.
input_transform = transforms.Compose([
    transforms.Resize(256, PIL.Image.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

ImageNet.benchmark(
    model=dpn131(pretrained=True),
    paper_model_name='DPN-131 x224',
    paper_arxiv_id='1707.01629',
    paper_pwc_id='dual-path-networks',
    input_transform=input_transform,
    batch_size=256,
    num_gpu=1,
)