def test_rsample_number(self, broadcast_all):
    """rsample from a plain-number parameter, with the randomness mocked out.

    ``broadcast_all`` is a mock injected from outside this view (presumably a
    ``patch`` decorator). It is made to return a log-rate tensor equal to
    log(0.5). Every tensor created through ``log_rate.new`` has its
    ``exponential_`` replaced so the "random" draw is all ones. The expected
    sample is therefore 1 / 0.5 = 2.0. This assumes rsample divides a unit
    exponential draw by the rate — confirm against SimpleExponential.
    """
    log_rate = (torch.ones(1) / 2).log()

    def new_mock_tensor(shape):
        # Deterministic stand-in for a unit exponential draw: all ones, and
        # the in-place sampler simply hands the same tensor back.
        draw = torch.ones(shape)
        draw.exponential_ = Mock(return_value=draw)
        return draw

    log_rate.new = Mock(side_effect=new_mock_tensor)
    broadcast_all.return_value = (log_rate,)
    # Pass the same log-rate the mock returns, as a Python number. The other
    # tests treat the constructor argument as a log-rate, so this keeps the
    # argument consistent with what the test actually exercises.
    dist = SimpleExponential(log_rate.item())
    sample = dist.rsample(sample_shape=torch.Size([2]))
    self.assertTrue(((sample - 2.0).abs() < 0.0001).all())
def test_log_prob_number(self):
    """log_prob for a scalar parameter, evaluated on a 2x2 batch of ones.

    The expected value -1.1931 equals log(0.5) - 0.5 * 1. That is the
    exponential log-density at x = 1 for rate 0.5, which is consistent with
    the argument ``math.log(0.5)`` being a log-rate (assumption; confirm
    against SimpleExponential).
    """
    log_rate = math.log(0.5)
    dist = SimpleExponential(log_rate)
    log_prob = dist.log_prob(torch.ones(2, 2))
    abs_error = (log_prob + 1.1931).abs()
    self.assertTrue((abs_error < 0.0001).all())
def test_log_prob_tensor(self):
    """log_prob for an elementwise 2x2 tensor parameter, on a 2x2 batch of ones.

    Each parameter entry is log(0.5), so every element should match the
    scalar case: log(0.5) - 0.5 ≈ -1.1931 (assuming a log-rate
    parameterisation; confirm against SimpleExponential).
    """
    half = torch.ones(2, 2) / 2
    dist = SimpleExponential(half.log())
    log_prob = dist.log_prob(torch.ones(2, 2))
    abs_error = (log_prob + 1.1931).abs()
    self.assertTrue((abs_error < 0.0001).all())
def test_divergence(self):
    """KL divergence between two SimpleExponential batches matches a closed form.

    ``source`` has parameter 1 - 1.6931 = -0.6931 (≈ log 0.5, i.e. rate ≈ 0.5).
    ``target`` has parameter 0 (rate 1).

    For exponentials, KL(Exp(a) || Exp(b)) = log(a / b) + b / a - 1. With
    a ≈ 0.5 and b = 1 this gives ≈ -0.6931 + 2 - 1 ≈ 0.3068. That is
    consistent with ``compute(source, target)`` returning KL(source || target)
    under a log-rate parameterisation (assumption; confirm against
    SimpleExponentialSimpleExponentialKL). ``key`` is defined outside this view.
    """
    kl = SimpleExponentialSimpleExponentialKL(key, key)
    # Not named ``input`` so the builtin is not shadowed.
    source = SimpleExponential(torch.ones(2, 2) - 1.6931)
    target = SimpleExponential(torch.zeros(2, 2))
    divergence = kl.compute(source, target)
    self.assertTrue(((divergence - 0.3068).abs() < 0.0001).all())