def test_densenet():
    """Print the parameter counts of the project DenseNet-121 encoder and torchvision's.

    Both models are built with ``pretrained=False``. Before counting, the
    torchvision model's ``classifier`` attribute is set to ``None``, presumably
    so the classification head is left out of the reference count.
    """
    from torchvision.models import densenet121

    ours = E.DenseNet121Encoder(pretrained=False)
    reference = densenet121(pretrained=False)
    # Drop the torchvision classification head before counting.
    reference.classifier = None

    ours_count = count_parameters(ours)
    reference_count = count_parameters(reference)
    print(ours_count, reference_count)
def test_densenet_encoder():
    """Smoke-test the DenseNet-121/169/201 encoders on a dummy 512x512 batch.

    For each encoder variant (in the order 121, 169, 201), this:

    1. builds the encoder with ``layers=[0, 1, 2, 3, 4]``;
    2. runs one forward pass on a random ``(2, 3, 512, 512)`` tensor;
    3. prints the ``size()`` of every returned feature map.
    """
    encoder_classes = (
        E.DenseNet121Encoder,
        E.DenseNet169Encoder,
        E.DenseNet201Encoder,
    )
    for encoder_cls in encoder_classes:
        encoder = encoder_cls(layers=[0, 1, 2, 3, 4])
        # Only output shapes are inspected, so don't record an autograd graph.
        # This keeps peak memory of these large forward passes much lower.
        with torch.no_grad():
            features = encoder(torch.randn(2, 3, 512, 512))
        print([feature.size() for feature in features])
def densenet121_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True):
    """Build a U-Net segmentation model on a DenseNet-121 encoder (decoder widths 32..256).

    Args:
        input_channels: Number of input channels. When it differs from 3,
            ``encoder.change_input_channels`` is called to adapt the encoder.
        num_classes: Number of output classes passed to the segmentation model.
        dropout: Dropout value forwarded to ``UnetSegmentationModel``.
        pretrained: Forwarded to ``E.DenseNet121Encoder``.

    Returns:
        An ``UnetSegmentationModel`` using encoder layers ``[0, 1, 2, 3, 4]`` and
        ``unet_channels=[32, 64, 128, 256]``.
    """
    backbone = E.DenseNet121Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4])

    needs_channel_adapt = input_channels != 3
    if needs_channel_adapt:
        backbone.change_input_channels(input_channels)

    model = UnetSegmentationModel(
        backbone,
        num_classes=num_classes,
        unet_channels=[32, 64, 128, 256],
        dropout=dropout,
    )
    return model
def densenet121_unet128(input_channels=3, num_classes=1, dropout=0.0, pretrained=True):
    """Build a U-Net (V3) segmentation model on a DenseNet-121 encoder.

    The encoder uses layers ``[1, 2, 3, 4]``. The decoder is configured with
    ``unet_channels=[128, 128, 256]``, ``last_upsample_filters=128`` and an
    ``ABN`` block using ``ACT_RELU`` activation.

    Args:
        input_channels: Number of input channels. When it differs from 3,
            ``encoder.change_input_channels`` is called to adapt the encoder.
        num_classes: Number of output classes passed to the segmentation model.
        dropout: Dropout value forwarded to ``UnetV3SegmentationModel``.
        pretrained: Forwarded to ``E.DenseNet121Encoder``.

    Returns:
        An ``UnetV3SegmentationModel`` instance.
    """
    backbone = E.DenseNet121Encoder(pretrained=pretrained, layers=[1, 2, 3, 4])
    if input_channels != 3:
        backbone.change_input_channels(input_channels)

    # Normalization+activation block factory used by the decoder.
    relu_abn = partial(ABN, activation=ACT_RELU)

    return UnetV3SegmentationModel(
        backbone,
        num_classes=num_classes,
        unet_channels=[128, 128, 256],
        last_upsample_filters=128,
        dropout=dropout,
        abn_block=relu_abn,
    )
def densenet121_unet_v2(input_channels=6, num_classes=5, dropout=0.0, pretrained=True, classifiers=True):
    """Build a U-Net (V2) segmentation model on a DenseNet-121 encoder.

    The encoder uses layers ``[0, 1, 2, 3, 4]``. The decoder is configured with
    ``unet_channels=[64, 128, 256, 256]`` and an ``ABN`` block using
    ``ACT_RELU`` activation.

    Args:
        input_channels: Accepted but NOT used by this factory. Unlike the other
            DenseNet-121 factories, the encoder's input channels are never changed.
            NOTE(review): presumably kept for a uniform factory signature, with
            the 6-channel input handled inside ``UnetV2SegmentationModel``;
            TODO confirm.
        num_classes: Number of segmentation output classes.
        dropout: Dropout value forwarded to ``UnetV2SegmentationModel``.
        pretrained: Forwarded to ``E.DenseNet121Encoder``.
        classifiers: When true, ``disaster_type_classes`` and
            ``damage_type_classes`` are set to ``len(DISASTER_TYPES)`` and
            ``len(DAMAGE_TYPES)``. When false, both are passed as ``None``.

    Returns:
        An ``UnetV2SegmentationModel`` instance.
    """
    backbone = E.DenseNet121Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4])

    if classifiers:
        disaster_classes = len(DISASTER_TYPES)
        damage_classes = len(DAMAGE_TYPES)
    else:
        disaster_classes = None
        damage_classes = None

    return UnetV2SegmentationModel(
        backbone,
        num_classes=num_classes,
        disaster_type_classes=disaster_classes,
        damage_type_classes=damage_classes,
        unet_channels=[64, 128, 256, 256],
        dropout=dropout,
        abn_block=partial(ABN, activation=ACT_RELU),
    )