Example #1
0
def get_model(pretrained=True, n_class=24,
              checkpoint='../input/birds-cp-1/resnest50_fast_1s1x64d_conf_1.pt',
              n_checkpoint_class=264):
    """Build a ResNeSt-50 (fast_1s1x64d config) with an ``n_class``-way head.

    Args:
        pretrained (bool): If True, load the weights stored in ``checkpoint``
            before the classifier head is replaced. If False, the backbone
            keeps the initialisation done by ``ResNet``.
        n_class (int): Number of outputs of the returned model's ``fc`` layer.
        checkpoint (str): Path to a state dict (loaded onto CPU) saved from a
            model of this architecture whose ``fc`` has ``n_checkpoint_class``
            outputs.
        n_checkpoint_class (int): Output size of ``fc`` in ``checkpoint``; the
            head must have this shape while the weights are loaded.

    Returns:
        The model, whose ``fc`` is a freshly created
        ``nn.Linear(in_features, n_class)``.
    """
    model = ResNet(**MODEL_CONFIGS["resnest50_fast_1s1x64d"])
    n_features = model.fc.in_features
    if pretrained:
        # The checkpoint was saved with an n_checkpoint_class-way head, so the
        # head must match that shape for a strict load_state_dict to succeed.
        model.fc = nn.Linear(n_features, n_checkpoint_class)
        state_dict = torch.load(checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)
    # Swap in an untrained head sized for the target task.
    model.fc = nn.Linear(n_features, n_class)
    return model
Example #2
0
def resnet101(pretrained=False, root='~/.encoding/models', **kwargs):
    """Construct a ResNet-101 (Bottleneck blocks, [3, 4, 23, 3]) with rectified convs.

    Args:
        pretrained (bool): If True, load the weights published at
            ``rectify_model_urls['resnet101']`` (documented as ImageNet
            pre-trained), mapped onto the CPU.
        root (str): Not used by this function; kept for API compatibility.
        **kwargs: Forwarded to ``ResNet``. ``radix`` is always forced to 0 and
            ``rectified_conv`` to True, overriding any caller-supplied value.

    Returns:
        The constructed ``ResNet`` instance.
    """
    kwargs.update(radix=0, rectified_conv=True)
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net

    weights = torch.hub.load_state_dict_from_url(
        rectify_model_urls['resnet101'],
        progress=True,
        check_hash=True,
        map_location=torch.device('cpu'),
    )
    net.load_state_dict(weights)
    return net
Example #3
0
def resnext50_32x4d(pretrained=False, root='~/.encoding/models', *,
                    progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model (with rectified convolutions) from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet,
            loaded from ``rectify_model_urls['resnext50_32x4d']`` onto the CPU.
        root (str): Not used by this function; kept for API compatibility.
        progress (bool): If True, displays a progress bar of the download to stderr.
        **kwargs: Forwarded to ``ResNet``. ``radix``, ``groups``,
            ``bottleneck_width`` and ``rectified_conv`` are always overridden
            (0, 32, 4 and True respectively).

    Returns:
        The constructed ``ResNet`` instance.
    """
    # progress is keyword-only and consumed here so it is no longer forwarded
    # to ResNet(**kwargs) by accident.
    kwargs['radix'] = 0
    kwargs['groups'] = 32
    kwargs['bottleneck_width'] = 4
    kwargs['rectified_conv'] = True
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            rectify_model_urls['resnext50_32x4d'], progress=progress, check_hash=True,
            map_location=torch.device('cpu')))
    return model
Example #4
0
def resnest50(num_classes, pretrained=True, **kwargs):
    """Construct a ResNeSt-50 (radix 2, deep stem, avg-down) classifier.

    Args:
        num_classes (int): Output size of the model's ``fc`` classifier.
        pretrained (bool): If True, load the weights published at
            ``resnest_model_urls['resnest50']`` (mapped onto the CPU). If the
            checkpoint's ``fc.*`` tensors do not match this model's shape
            (i.e. a different ``num_classes``), the model's own freshly
            initialised head is kept and only the rest is loaded.
        **kwargs: Extra options forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    # NOTE(review): dilation=2 is passed to ResNet here; confirm it is
    # intended for this use case.
    model = ResNet(Bottleneck, [3, 4, 6, 3],
                   radix=2,
                   groups=1,
                   bottleneck_width=64,
                   num_classes=num_classes,
                   deep_stem=True,
                   stem_width=32,
                   avg_down=True,
                   avd=True,
                   avd_first=False,
                   dilation=2,
                   **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            resnest_model_urls['resnest50'],
            progress=True,
            check_hash=True,
            map_location=torch.device('cpu'))
        # A strict load would fail on a size mismatch in the classifier when
        # num_classes differs from the checkpoint's head; keep our own head
        # tensors in that case so the backbone weights still load.
        own_state = model.state_dict()
        for key, own_tensor in own_state.items():
            if (key.startswith('fc.') and key in state_dict
                    and state_dict[key].shape != own_tensor.shape):
                state_dict[key] = own_tensor
        model.load_state_dict(state_dict)
    return model