예제 #1
0
def fpn256_resnext50(num_classes=1, num_channels=3):
    """Build a ``SegmentationModel`` pairing ``E.SEResNeXt50Encoder`` with an FPN decoder.

    Args:
        num_classes: Forwarded as the third positional argument of
            ``SegmentationModel``.
        num_channels: Must equal 3; any other value trips the assertion below.

    Returns:
        The constructed ``SegmentationModel`` instance.

    Raises:
        AssertionError: If ``num_channels`` is not 3 (only when assertions
            are enabled, i.e. Python is not run with ``-O``).
    """
    assert num_channels == 3

    backbone = E.SEResNeXt50Encoder()

    # The decoder is sized from the encoder's reported output filters and
    # uses a fixed FPN width of 256, matching the function name.
    decoder_options = {
        "features": backbone.output_filters,
        "prediction_block": DoubleConvRelu,
        "bottleneck": FPNBottleneckBlockBN,
        "fpn_features": 256,
    }
    fpn_head = D.FPNDecoder(**decoder_options)

    return SegmentationModel(backbone, fpn_head, num_classes)
예제 #2
0
def seresnext50_unet64(input_channels=3, num_classes=1, dropout=0.0, pretrained=True):
    """Build a ``UnetV3SegmentationModel`` on top of ``E.SEResNeXt50Encoder``.

    Args:
        input_channels: When different from 3, the encoder's
            ``change_input_channels`` method is called with this value.
        num_classes: Forwarded to the model as ``num_classes``.
        dropout: Forwarded to the model as ``dropout``.
        pretrained: Forwarded to the encoder constructor as ``pretrained``.

    Returns:
        The constructed ``UnetV3SegmentationModel`` instance.
    """
    backbone = E.SEResNeXt50Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4])

    # The encoder is only touched when the caller asks for something other
    # than 3 input channels.
    if input_channels != 3:
        backbone.change_input_channels(input_channels)

    relu_abn_factory = partial(ABN, activation=ACT_RELU)
    model_options = {
        "num_classes": num_classes,
        "unet_channels": [64, 128, 256, 256],
        "dropout": dropout,
        "abn_block": relu_abn_factory,
    }
    return UnetV3SegmentationModel(backbone, **model_options)
예제 #3
0
def seresnext50_unet_v2(input_channels=6,
                        num_classes=5,
                        dropout=0.0,
                        pretrained=True,
                        classifiers=True):
    """Build a ``UnetV2SegmentationModel`` on top of ``E.SEResNeXt50Encoder``.

    Args:
        input_channels: Accepted for interface purposes; this function body
            never reads it.
            NOTE(review): presumably the model handles the channel layout
            itself -- confirm against ``UnetV2SegmentationModel``.
        num_classes: Forwarded to the model as ``num_classes``.
        dropout: Forwarded to the model as ``dropout``.
        pretrained: Forwarded to the encoder constructor as ``pretrained``.
        classifiers: When truthy, the model receives
            ``len(DISASTER_TYPES)`` and ``len(DAMAGE_TYPES)`` as its
            classifier class counts; otherwise both are ``None``.

    Returns:
        The constructed ``UnetV2SegmentationModel`` instance.
    """
    backbone = E.SEResNeXt50Encoder(pretrained=pretrained,
                                    layers=[0, 1, 2, 3, 4])

    # Resolve both classifier sizes in one place; the type tables are only
    # consulted when classifiers are requested.
    if classifiers:
        disaster_count = len(DISASTER_TYPES)
        damage_count = len(DAMAGE_TYPES)
    else:
        disaster_count = None
        damage_count = None

    relu_abn_factory = partial(ABN, activation=ACT_RELU)
    return UnetV2SegmentationModel(
        backbone,
        num_classes=num_classes,
        disaster_type_classes=disaster_count,
        damage_type_classes=damage_count,
        unet_channels=[64, 128, 256, 256],
        dropout=dropout,
        abn_block=relu_abn_factory,
    )
예제 #4
0
def resnext50_fpn(num_classes=1, fpn_features=128, dropout=0.5):
    """Build an ``FPNSegmentationModel`` on top of ``E.SEResNeXt50Encoder``.

    Args:
        num_classes: Forwarded as the second positional argument of
            ``FPNSegmentationModel``.
        fpn_features: Forwarded as the third positional argument of
            ``FPNSegmentationModel``.
        dropout: Forwarded to the model as ``dropout``. Defaults to 0.5,
            the value this factory previously hard-coded, so existing
            callers get the same model as before.

    Returns:
        The constructed ``FPNSegmentationModel`` instance.
    """
    encoder = E.SEResNeXt50Encoder()
    return FPNSegmentationModel(encoder, num_classes, fpn_features, dropout=dropout)
예제 #5
0
def seresnext50_fpncat128(num_classes=5, dropout=0.0, pretrained=True):
    """Build an ``FPNCatSegmentationModel`` on top of ``E.SEResNeXt50Encoder``.

    Args:
        num_classes: Forwarded to the model as ``num_classes``.
        dropout: Forwarded to the model as ``dropout``.
        pretrained: Forwarded to the encoder constructor as ``pretrained``.

    Returns:
        The constructed ``FPNCatSegmentationModel`` instance, created with
        ``fpn_channels=128``.
    """
    backbone = E.SEResNeXt50Encoder(pretrained=pretrained)
    model = FPNCatSegmentationModel(
        backbone,
        num_classes=num_classes,
        fpn_channels=128,
        dropout=dropout,
    )
    return model
예제 #6
0
def hd_fpn_resnext50(num_classes=1, num_channels=3, fpn_features=128):
    """Build a ``HiResSegmentationModel`` on top of ``E.SEResNeXt50Encoder``.

    Args:
        num_classes: Forwarded as the second positional argument of
            ``HiResSegmentationModel``.
        num_channels: Must equal 3; any other value trips the assertion below.
        fpn_features: Forwarded as the third positional argument of
            ``HiResSegmentationModel``.

    Returns:
        The constructed ``HiResSegmentationModel`` instance.

    Raises:
        AssertionError: If ``num_channels`` is not 3 (only when assertions
            are enabled, i.e. Python is not run with ``-O``).
    """
    assert num_channels == 3

    # The encoder is asked for layers 0 through 4.
    encoder_layers = [0, 1, 2, 3, 4]
    backbone = E.SEResNeXt50Encoder(layers=encoder_layers)

    model = HiResSegmentationModel(backbone, num_classes, fpn_features)
    return model