Exemplo n.º 1
0
    def test_forward_inference(self):
        """Check the eval-mode forward pass on a 4-D input.

        Both returned tensors must keep the input's shape, and the first
        one must equal the element-wise rounded input.
        """
        bottleneck = EntropyBottleneck(128)
        bottleneck.eval()

        inputs = torch.rand(1, 128, 32, 32)
        outputs, likelihoods = bottleneck(inputs)

        # Neither returned tensor may change the input's shape.
        for tensor in (outputs, likelihoods):
            assert tensor.shape == inputs.shape

        # In eval mode the first output is expected to be the rounded input.
        assert torch.all(outputs == torch.round(inputs))
Exemplo n.º 2
0
    def test_forward_inference_ND(self):
        """Check the eval-mode forward pass for 1 to 5 trailing dimensions.

        For every input of shape ``(1, 128, 4, ..., 4)`` both returned
        tensors must keep the input's shape, and the first one must equal
        the element-wise rounded input.
        """
        bottleneck = EntropyBottleneck(128)
        bottleneck.eval()

        # Trailing dimension counts 1 through 5, each of size 4.
        for num_trailing_dims in range(1, 6):
            shape = (1, 128) + (4,) * num_trailing_dims
            inputs = torch.rand(shape)
            outputs, likelihoods = bottleneck(inputs)

            # Neither returned tensor may change the input's shape.
            for tensor in (outputs, likelihoods):
                assert tensor.shape == inputs.shape

            # In eval mode the first output is expected to be the rounded input.
            assert torch.all(outputs == torch.round(inputs))