def test_forward_inference(self):
    """Check an eval-mode forward pass of ``EntropyBottleneck`` on a 4-D input.

    In eval mode the module is expected to return an output equal to
    ``torch.round(x)`` element-wise. Both the output and the likelihoods
    must keep the input's shape.

    NOTE(review): ``torch.rand`` draws from [0, 1), so rounding only ever
    yields 0 or 1 here. That makes this a fairly weak check of quantization.
    """
    bottleneck = EntropyBottleneck(128)
    bottleneck.eval()

    inputs = torch.rand(1, 128, 32, 32)
    outputs, likelihoods = bottleneck(inputs)

    expected_shape = inputs.shape
    assert outputs.shape == expected_shape
    assert likelihoods.shape == expected_shape

    rounded = torch.round(inputs)
    assert (outputs == rounded).all()
def test_forward_inference_ND(self):
    """Check eval-mode ``EntropyBottleneck`` forward passes on N-D inputs.

    Each input has shape ``(1, 128, 4, ..., 4)`` with one to five trailing
    size-4 dimensions after the leading ``(1, 128)``. The inputs are
    therefore 3-D through 7-D tensors in total. For every input, the output
    and the likelihoods must match the input shape. The output must also
    equal ``torch.round(x)`` element-wise.
    """
    bottleneck = EntropyBottleneck(128)
    bottleneck.eval()

    # Add 1..5 trailing dimensions (range end is exclusive).
    for n_trailing in range(1, 6):
        trailing_shape = (4,) * n_trailing
        inputs = torch.rand(1, 128, *trailing_shape)
        outputs, likelihoods = bottleneck(inputs)

        assert outputs.shape == inputs.shape
        assert likelihoods.shape == inputs.shape
        assert (outputs == torch.round(inputs)).all()