Ejemplo n.º 1
0
def test_densenet():
    """Print parameter counts for the project DenseNet-121 encoder and torchvision's.

    Both models are created without pretrained weights. The torchvision
    model's ``classifier`` attribute is cleared before counting, so only the
    remaining submodules are measured. ``count_parameters`` is defined
    elsewhere in the project.
    """
    from torchvision.models import densenet121

    project_encoder = E.DenseNet121Encoder(pretrained=False)
    torchvision_model = densenet121(pretrained=False)
    # Remove the classification head so the comparison covers only the trunk.
    torchvision_model.classifier = None

    encoder_count = count_parameters(project_encoder)
    reference_count = count_parameters(torchvision_model)
    print(encoder_count, reference_count)
Ejemplo n.º 2
0
def test_densenet_encoder():
    """Smoke-test the DenseNet-121/169/201 encoders on a random batch.

    Each encoder is built with ``layers=[0, 1, 2, 3, 4]`` and called on a
    random tensor of shape (2, 3, 512, 512). The ``size()`` of every output
    it yields is printed, one list per variant, in the order 121, 169, 201.
    """
    encoder_factories = (
        E.DenseNet121Encoder,
        E.DenseNet169Encoder,
        E.DenseNet201Encoder,
    )
    for make_encoder in encoder_factories:
        encoder = make_encoder(layers=[0, 1, 2, 3, 4])
        outputs = encoder(torch.randn(2, 3, 512, 512))
        print([tensor.size() for tensor in outputs])
Ejemplo n.º 3
0
def densenet121_unet32(input_channels=3,
                       num_classes=1,
                       dropout=0.0,
                       pretrained=True):
    """Build a ``UnetSegmentationModel`` on a DenseNet-121 encoder.

    Args:
        input_channels: Channel count of the input. For any value other than
            3, ``encoder.change_input_channels`` is called to adapt the encoder.
        num_classes: Forwarded to ``UnetSegmentationModel``.
        dropout: Forwarded to ``UnetSegmentationModel``.
        pretrained: Forwarded to ``E.DenseNet121Encoder``.

    Returns:
        A ``UnetSegmentationModel`` with ``unet_channels=[32, 64, 128, 256]``.
    """
    feature_layers = [0, 1, 2, 3, 4]
    encoder = E.DenseNet121Encoder(pretrained=pretrained, layers=feature_layers)

    # Presumably the encoder defaults to 3-channel input; only rewire otherwise.
    needs_rewire = input_channels != 3
    if needs_rewire:
        encoder.change_input_channels(input_channels)

    channel_widths = [32, 64, 128, 256]
    model = UnetSegmentationModel(encoder,
                                  num_classes=num_classes,
                                  unet_channels=channel_widths,
                                  dropout=dropout)
    return model
Ejemplo n.º 4
0
def densenet121_unet128(input_channels=3, num_classes=1, dropout=0.0, pretrained=True):
    """Build a ``UnetV3SegmentationModel`` on a DenseNet-121 encoder.

    Unlike the 32-channel variant, the encoder here is asked for
    ``layers=[1, 2, 3, 4]``. The decoder is configured with
    ``unet_channels=[128, 128, 256]`` and ``last_upsample_filters=128``, and
    uses ``ABN`` blocks with ``ACT_RELU`` activation.

    Args:
        input_channels: Channel count of the input. For any value other than
            3, ``encoder.change_input_channels`` is called to adapt the encoder.
        num_classes: Forwarded to ``UnetV3SegmentationModel``.
        dropout: Forwarded to ``UnetV3SegmentationModel``.
        pretrained: Forwarded to ``E.DenseNet121Encoder``.

    Returns:
        The configured ``UnetV3SegmentationModel``.
    """
    encoder = E.DenseNet121Encoder(layers=[1, 2, 3, 4], pretrained=pretrained)
    if input_channels == 3:
        pass  # default input width, nothing to adapt
    else:
        encoder.change_input_channels(input_channels)

    relu_abn = partial(ABN, activation=ACT_RELU)
    return UnetV3SegmentationModel(
        encoder,
        num_classes=num_classes,
        unet_channels=[128, 128, 256],
        last_upsample_filters=128,
        dropout=dropout,
        abn_block=relu_abn,
    )
Ejemplo n.º 5
0
def densenet121_unet_v2(input_channels=6,
                        num_classes=5,
                        dropout=0.0,
                        pretrained=True,
                        classifiers=True):
    """Build a ``UnetV2SegmentationModel`` on a DenseNet-121 encoder.

    Args:
        input_channels: Accepted for signature compatibility but not
            referenced in this function. NOTE(review): unlike the sibling
            builders, the encoder is never adapted via
            ``change_input_channels``. Presumably ``UnetV2SegmentationModel``
            handles the 6-channel input itself; confirm against that class.
        num_classes: Forwarded to ``UnetV2SegmentationModel``.
        dropout: Forwarded to ``UnetV2SegmentationModel``.
        pretrained: Forwarded to ``E.DenseNet121Encoder``.
        classifiers: When true, ``disaster_type_classes`` and
            ``damage_type_classes`` are set to ``len(DISASTER_TYPES)`` and
            ``len(DAMAGE_TYPES)``. When false, both are passed as ``None``.

    Returns:
        A ``UnetV2SegmentationModel`` with ``unet_channels=[64, 128, 256, 256]``
        and ``ABN`` blocks using ``ACT_RELU``.
    """
    encoder = E.DenseNet121Encoder(pretrained=pretrained,
                                   layers=[0, 1, 2, 3, 4])

    if classifiers:
        disaster_classes = len(DISASTER_TYPES)
        damage_classes = len(DAMAGE_TYPES)
    else:
        disaster_classes = damage_classes = None

    return UnetV2SegmentationModel(
        encoder,
        num_classes=num_classes,
        disaster_type_classes=disaster_classes,
        damage_type_classes=damage_classes,
        unet_channels=[64, 128, 256, 256],
        dropout=dropout,
        abn_block=partial(ABN, activation=ACT_RELU),
    )