Exemplo n.º 1
0
def se_resnet152(num_classes=1_000):
    """Build the ResNet-152 configuration out of ``SEBottleneck`` blocks.

    Args:
        num_classes (int): Forwarded to ``ResNet`` as its ``num_classes``
            keyword argument.

    Returns:
        The constructed ``ResNet`` instance, with its ``avgpool`` attribute
        replaced by ``nn.AdaptiveAvgPool2d(1)``.
    """
    layer_counts = [3, 8, 36, 3]
    net = ResNet(SEBottleneck, layer_counts, num_classes=num_classes)
    # Adaptive pooling reduces each feature map to 1x1 whatever its
    # incoming spatial size is.
    net.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
    return net


if __name__ == "__main__":

    import torch

    net = se_resnet50()

    # Print every registered submodule together with its qualified name.
    for layer_name, layer in net.named_modules():
        print("layer name:{}, layer instance:{}".format(layer_name, layer))

    # Replace the network's final layer: keep its input width, emit 102 outputs.
    net.fc = nn.Linear(net.fc.in_features, 102)

    # Forward pass on a random input to check the output shape.
    dummy_batch = torch.randn(1, 3, 224, 224)  # batch size * channels * height * width
    logits = net(dummy_batch)
    print(logits.shape)




Exemplo n.º 2
0
def se_resnet50(num_classes=1_000, pretrained=False):
    """Constructs the ResNet-50 configuration out of ``SEBottleneck`` blocks.

    Args:
        num_classes (int): Number of outputs of the final ``fc`` layer.
        pretrained (bool): If True, loads the checkpoint stored at ``PATH``
            into the network and then attaches a freshly initialised
            ``num_classes``-way ``fc`` layer.

    Returns:
        The constructed ``ResNet`` instance, with its ``avgpool`` attribute
        replaced by ``nn.AdaptiveAvgPool2d(1)``.
    """
    if not pretrained:
        # Honour ``num_classes`` when no checkpoint is loaded, the same way
        # se_resnet101 / se_resnet152 do. Previously the argument was
        # silently ignored on this path.
        model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)
        model.avgpool = nn.AdaptiveAvgPool2d(1)
        return model

    # Function-scope import so this branch does not depend on ``torch``
    # happening to be a module-level name (the visible file only imports it
    # under the ``__main__`` guard).
    import torch

    # Build with the constructor's default head so the checkpoint is loaded
    # into the same architecture as before; the head is swapped afterwards.
    model = ResNet(SEBottleneck, [3, 4, 6, 3])
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a
    # CPU-only machine; load_state_dict copies into the model's parameters.
    model.load_state_dict(torch.load(PATH, map_location="cpu"))
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)
    return model


def se_resnet101(num_classes=1_000):
    """Build the ResNet-101 configuration out of ``SEBottleneck`` blocks.

    Args:
        num_classes (int): Forwarded to ``ResNet`` as its ``num_classes``
            keyword argument.

    Returns:
        The constructed ``ResNet`` instance, with its ``avgpool`` attribute
        replaced by ``nn.AdaptiveAvgPool2d(1)``.
    """
    layer_counts = [3, 4, 23, 3]
    net = ResNet(SEBottleneck, layer_counts, num_classes=num_classes)
    # Adaptive pooling reduces each feature map to 1x1 whatever its
    # incoming spatial size is.
    net.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
    return net


def se_resnet152(num_classes=1_000):
Exemplo n.º 3
0
def cbamresnet10(n_classes: int) -> nn.Module:
    """Build a ``ResNet`` of ``PreActCbamBlock`` units with an ``n_classes`` head.

    Args:
        n_classes: Output width of the ``nn.Linear`` layer installed as ``fc``.

    Returns:
        The ``ResNet`` built with the layer list ``[1, 1, 1, 1]``, whose
        ``fc`` has been replaced by a new linear layer of the same input width.
    """
    layer_counts = [1, 1, 1, 1]
    net = ResNet(PreActCbamBlock, layer_counts)
    # Read the input width from the head the constructor created, then
    # replace that head with one sized for ``n_classes``.
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, n_classes)
    return net
Exemplo n.º 4
0
def prepare_resnet(resnet: ResNet, num_classes: int):
    """Replace the ``fc`` layer of ``resnet`` with a ``num_classes``-way linear layer.

    Args:
        resnet: Model whose ``fc`` attribute exposes ``in_features``; the
            attribute is reassigned on the object passed in.
        num_classes: Output width of the new ``nn.Linear`` layer.
    """
    # The new layer keeps the old layer's input width; its weights are newly created.
    resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)