def test_UNet():
    """Smoke-test that a range of UNet configurations build and run forward.

    Every variant is constructed and applied to the same random
    (1, 1, 384, 384) tensor under ``torch.no_grad()``. Only the absence of
    errors is checked; outputs are discarded. The last case expects
    ``UNet.from_encoder`` to raise ``AttributeError`` when the supplied
    factory returns ``None``.
    """

    def encoder_of(model_factory):
        # Adapt a full-model factory into one that returns only `.encoder`.
        return lambda *args, **kwargs: model_factory(*args, **kwargs).encoder

    side = 32 * 12
    image = torch.rand((1, 1, side, side))

    def forward(model):
        # Run one forward pass; the result itself is not inspected.
        model(image)

    with torch.no_grad():
        # Default configuration.
        forward(UNet())
        # Encoder taken from a ResNet-26 model.
        forward(UNet(encoder=encoder_of(ResNet.resnet26)))
        # Decoder with explicit widths.
        forward(UNet(decoder=partial(UNetDecoder, widths=[256, 128, 64, 32, 16])))
        # Encoder taken from an EfficientNet-B2 model.
        forward(UNet(encoder=encoder_of(EfficientNet.efficientnet_b2)))
        # EfficientNet-B2 encoder together with the explicit-width decoder.
        forward(
            UNet(
                encoder=encoder_of(EfficientNet.efficientnet_b2),
                decoder=partial(UNetDecoder, widths=[256, 128, 64, 32, 16]),
            )
        )
        # Encoder taken from an EfficientNet-Lite3 model.
        forward(UNet(encoder=encoder_of(EfficientNetLite.efficientnet_lite3)))
        # ResNetEncoder built from explicit block types and depths.
        forward(
            UNet(
                encoder=partial(
                    ResNetEncoder, block=ResNetBasicBlock, depths=[1, 1, 2, 2]
                )
            )
        )
        forward(
            UNet(
                encoder=partial(
                    ResNetEncoder, block=ResNetBottleneckBlock, depths=[1, 1, 2, 2]
                )
            )
        )
        # UNetEncoder using SENetBasicBlock as its block.
        forward(UNet(encoder=partial(UNetEncoder, block=SENetBasicBlock)))

        # Alternate constructor: from_encoder with a plain model factory...
        forward(
            UNet.from_encoder(
                lambda *args, **kwargs: ResNet.resnet26(*args, **kwargs)
            )
        )
        # ...and with a factory resolved by name through AutoModel.
        forward(UNet.from_encoder(partial(AutoModel.from_name, "resnet18")))
        # A factory that yields None is expected to be rejected.
        with pytest.raises(AttributeError):
            UNet.from_encoder(lambda *args, **kwargs: None)
Esempio n. 2
0
def test_EfficientNetLite2():
    """EfficientNet-Lite2 on a (1, 3, 224, 224) input yields a last dim of 1000."""
    batch = torch.rand(1, 3, 224, 224)
    net = EfficientNetLite.efficientnet_lite2()
    assert net(batch).shape[-1] == 1000