Esempio n. 1
0
    def __init__(self, entropy_bottleneck_channels, init_weights=True):
        """Create two entropy bottlenecks with the same channel count.

        Args:
            entropy_bottleneck_channels: channel count passed to each
                ``EntropyBottleneck``.
            init_weights: when truthy, ``self._initialize_weights()`` (defined
                elsewhere in this class) is called after both bottlenecks
                have been created.
        """
        super().__init__()
        channels = entropy_bottleneck_channels
        # Created in this order: entropy_bottleneck1 first, then entropy_bottleneck2.
        self.entropy_bottleneck1 = EntropyBottleneck(channels)
        self.entropy_bottleneck2 = EntropyBottleneck(channels)

        if init_weights:
            self._initialize_weights()
Esempio n. 2
0
    def test_forward_inference(self):
        """In eval mode the outputs equal the rounded inputs; shapes are kept."""
        eb = EntropyBottleneck(128)
        eb.eval()
        inputs = torch.rand(1, 128, 32, 32)

        outputs, likelihoods = eb(inputs)

        for tensor in (outputs, likelihoods):
            assert tensor.shape == inputs.shape
        assert outputs.eq(torch.round(inputs)).all()
Esempio n. 3
0
    def test_forward_inference_ND(self):
        """Eval-mode forward rounds inputs with 1 to 5 trailing dimensions."""
        eb = EntropyBottleneck(128)
        eb.eval()

        # Trailing shapes (4,), (4, 4), ..., (4, 4, 4, 4, 4).
        for ndim in range(1, 6):
            trailing = [4] * ndim
            inputs = torch.rand(1, 128, *trailing)
            outputs, likelihoods = eb(inputs)

            assert outputs.shape == inputs.shape
            assert likelihoods.shape == inputs.shape
            assert outputs.eq(torch.round(inputs)).all()
Esempio n. 4
0
    def __init__(self, entropy_bottleneck_channels, init_weights=None):
        """Create a single entropy bottleneck.

        Args:
            entropy_bottleneck_channels: channel count passed to
                ``EntropyBottleneck``.
            init_weights: deprecated and otherwise ignored; passing any value
                other than ``None`` emits a ``DeprecationWarning``.
        """
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(
            entropy_bottleneck_channels)

        if init_weights is not None:
            # stacklevel=2 attributes the warning to the code that passed the
            # deprecated argument, not to this line inside __init__.
            warnings.warn(
                "init_weights was removed as it was never functional",
                DeprecationWarning,
                stacklevel=2,
            )
Esempio n. 5
0
    def __init__(self,
                 entropy_bottleneck_channels,
                 scale=8,
                 init_weights=True):
        """Register ``scale`` bottlenecks named ``entropy_bottleneck_<i>``.

        Args:
            entropy_bottleneck_channels: channel count for every bottleneck.
            scale: number of bottlenecks to register, with indices
                ``0 .. scale - 1``.
            init_weights: when truthy, ``self._initialize_weights()`` (defined
                elsewhere in this class) is called after registration.
        """
        super().__init__()
        names = [f"entropy_bottleneck_{idx}" for idx in range(scale)]
        for name in names:
            self.add_module(name,
                            EntropyBottleneck(entropy_bottleneck_channels))

        if init_weights:
            self._initialize_weights()
Esempio n. 6
0
    def test_forward_training(self):
        """A fresh (training-mode) bottleneck shifts each input by at most 0.5."""
        eb = EntropyBottleneck(128)
        inputs = torch.rand(1, 128, 32, 32)
        outputs, likelihoods = eb(inputs)

        assert isinstance(eb, EntropyModel)
        for tensor in (outputs, likelihoods):
            assert tensor.shape == inputs.shape

        delta = outputs - inputs
        assert (delta.abs() <= 0.5).all()
        # The outputs must not simply be the hard-rounded inputs.
        assert (outputs != torch.round(inputs)).any()
Esempio n. 7
0
    def test_compression_2D(self):
        """compress/decompress of a (1, 128, 32, 32) tensor yields its rounding."""
        inputs = torch.rand(1, 128, 32, 32)
        eb = EntropyBottleneck(128)
        eb.update()

        strings = eb.compress(inputs)
        decoded = eb.decompress(strings, inputs.size()[2:])

        expected = torch.round(inputs)
        assert torch.allclose(expected, decoded)
Esempio n. 8
0
    def test_scripting(self):
        """A TorchScript-compiled bottleneck reproduces eager outputs."""
        eb = EntropyBottleneck(128)
        inputs = torch.rand(1, 128, 32, 32)

        # Re-seed before each call so both runs draw the same random numbers.
        torch.manual_seed(32)
        eager_out = eb(inputs)

        scripted = torch.jit.script(eb)

        torch.manual_seed(32)
        scripted_out = scripted(inputs)

        assert torch.allclose(eager_out[0], scripted_out[0])
        # The scripted module is expected to return all-zero likelihoods
        # (the original note says this is "not yet supported").
        assert torch.all(scripted_out[1] == 0)
Esempio n. 9
0
    def test_compression_ND(self):
        """compress/decompress round-trips inputs with 0 to 5 trailing dims."""
        eb = EntropyBottleneck(128)
        eb.update()

        def roundtrip(inputs, size):
            strings = eb.compress(inputs)
            decoded = eb.decompress(strings, size)
            assert torch.allclose(torch.round(inputs), decoded)

        # 0D: only the leading (1, 128) dims; the size is an empty list.
        roundtrip(torch.rand(1, 128), [])

        # 1 to 5 trailing dimensions, each of length 4.
        for ndim in range(1, 6):
            inputs = torch.rand(1, 128, *([4] * ndim))
            roundtrip(inputs, inputs.size()[2:])
Esempio n. 10
0
    def test_loss(self):
        """loss() returns a zero-dimensional tensor holding one element."""
        loss = EntropyBottleneck(128).loss()

        assert loss.dim() == 0
        assert loss.numel() == 1
Esempio n. 11
0
 def test_script(self):
     """A scripted bottleneck returns all-zero likelihoods."""
     scripted = torch.jit.script(EntropyBottleneck(32))
     inputs = torch.rand(1, 32, 4, 4)
     quantized, likelihoods = scripted(inputs)
     assert likelihoods.eq(torch.zeros_like(quantized)).all()