def __init__(self, entropy_bottleneck_channels, init_weights=True):
    super().__init__()
    self.entropy_bottleneck1 = EntropyBottleneck(entropy_bottleneck_channels)
    self.entropy_bottleneck2 = EntropyBottleneck(entropy_bottleneck_channels)
    if init_weights:
        self._initialize_weights()
def test_forward_inference(self):
    entropy_bottleneck = EntropyBottleneck(128)
    entropy_bottleneck.eval()

    x = torch.rand(1, 128, 32, 32)
    y, y_likelihoods = entropy_bottleneck(x)

    assert y.shape == x.shape
    assert y_likelihoods.shape == x.shape

    # In eval mode quantization is actual rounding, not noise.
    assert (y == torch.round(x)).all()
def test_forward_inference_ND(self):
    entropy_bottleneck = EntropyBottleneck(128)
    entropy_bottleneck.eval()

    # Test from 1 to 5 dimensions
    for i in range(1, 6):
        x = torch.rand(1, 128, *([4] * i))
        y, y_likelihoods = entropy_bottleneck(x)

        assert y.shape == x.shape
        assert y_likelihoods.shape == x.shape
        assert (y == torch.round(x)).all()
def __init__(self, entropy_bottleneck_channels, init_weights=None):
    super().__init__()
    self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels)

    if init_weights is not None:
        warnings.warn(
            "init_weights was removed as it was never functional",
            DeprecationWarning,
        )
def __init__(self, entropy_bottleneck_channels, scale=8, init_weights=True):
    super().__init__()
    # Register one EntropyBottleneck per scale.
    for m in range(scale):
        self.add_module(
            f"entropy_bottleneck_{m}",
            EntropyBottleneck(entropy_bottleneck_channels),
        )
    if init_weights:
        self._initialize_weights()
def test_forward_training(self):
    entropy_bottleneck = EntropyBottleneck(128)
    x = torch.rand(1, 128, 32, 32)
    y, y_likelihoods = entropy_bottleneck(x)

    assert isinstance(entropy_bottleneck, EntropyModel)
    assert y.shape == x.shape
    assert y_likelihoods.shape == x.shape

    # In training mode quantization is simulated with additive uniform noise:
    # outputs stay within +/-0.5 of the input but differ from plain rounding.
    assert ((y - x) <= 0.5).all()
    assert ((y - x) >= -0.5).all()
    assert (y != torch.round(x)).any()
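# A minimal sketch (an addition, not one of the original tests) of how the returned
# likelihoods are typically turned into a rate estimate: each element costs -log2(p)
# bits. The function name, variable names, and the 32x32 spatial size are assumptions
# for illustration only; EntropyBottleneck is assumed to be imported as in the tests.
import math

import torch


def example_rate_estimate():
    entropy_bottleneck = EntropyBottleneck(128)
    x = torch.rand(1, 128, 32, 32)
    _, y_likelihoods = entropy_bottleneck(x)

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    # bits-per-pixel = -sum(log2 p) / num_pixels
    bpp = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)
    return bpp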
def test_compression_2D(self):
    x = torch.rand(1, 128, 32, 32)
    eb = EntropyBottleneck(128)
    eb.update()

    s = eb.compress(x)
    x2 = eb.decompress(s, x.size()[2:])

    # The compress/decompress round trip should recover the rounded tensor.
    assert torch.allclose(torch.round(x), x2)
def test_scripting(self):
    entropy_bottleneck = EntropyBottleneck(128)
    x = torch.rand(1, 128, 32, 32)

    torch.manual_seed(32)
    y0 = entropy_bottleneck(x)

    m = torch.jit.script(entropy_bottleneck)

    torch.manual_seed(32)
    y1 = m(x)

    assert torch.allclose(y0[0], y1[0])
    assert torch.all(y1[1] == 0)  # not yet supported
def test_compression_ND(self):
    eb = EntropyBottleneck(128)
    eb.update()

    # Test 0D
    x = torch.rand(1, 128)
    s = eb.compress(x)
    x2 = eb.decompress(s, [])

    assert torch.allclose(torch.round(x), x2)

    # Test from 1 to 5 dimensions
    for i in range(1, 6):
        x = torch.rand(1, 128, *([4] * i))
        s = eb.compress(x)
        x2 = eb.decompress(s, x.size()[2:])

        assert torch.allclose(torch.round(x), x2)
def test_loss(self):
    entropy_bottleneck = EntropyBottleneck(128)
    loss = entropy_bottleneck.loss()

    assert len(loss.size()) == 0
    assert loss.numel() == 1
def test_script(self):
    eb = EntropyBottleneck(32)
    eb = torch.jit.script(eb)

    x = torch.rand(1, 32, 4, 4)
    x_q, likelihoods = eb(x)

    # Likelihood estimation is not yet supported in scripted mode, so zeros are returned.
    assert (likelihoods == torch.zeros_like(x_q)).all()
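# A minimal end-to-end sketch (an addition, not one of the tests above), assuming only
# the EntropyBottleneck API exercised here: forward, loss(), update(), compress(),
# decompress(). The import path and variable names are assumptions.
import torch

from compressai.entropy_models import EntropyBottleneck  # assumed import path

eb = EntropyBottleneck(128)

# Training-style forward pass: noisy "quantization" plus likelihoods for a rate term.
x = torch.rand(1, 128, 32, 32)
y, y_likelihoods = eb(x)
aux_loss = eb.loss()  # scalar auxiliary loss, see test_loss above

# Inference: build the CDF tables, then round-trip through compress/decompress.
eb.update()
strings = eb.compress(x)
x_hat = eb.decompress(strings, x.size()[2:])
assert torch.allclose(torch.round(x), x_hat)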