def test_mean_single_gradient(
    mvn_dist, k, L21, omega1, L11, L22=0.8, L33=0.9, omega2=0.75, n_samples=20000
):
    """Check single-sample reparameterized gradients w.r.t. the ``(1, 0)`` entry of ``L``.

    Calls ``dist.rsample()`` (no sample shape) ``n_samples`` times. For each draw it
    back-propagates ``cos(omega . z)`` and records the gradient of the ``(1, 0)``
    entry of the lower-triangular matrix ``L``. The mean of those per-sample
    gradients is then compared with ``analytic_grad`` to precision 0.01.

    :param mvn_dist: name of the distribution under test, either
        ``"OMTMultivariateNormal"`` or ``"AVFMultivariateNormal"``.
    :param k: middle dimension of the ``(2, k, 3)`` ``CV`` tensor passed to
        ``AVFMultivariateNormal``; unused for the OMT distribution.
    :param L21: value of the ``(1, 0)`` entry of ``L``; the gradient is taken w.r.t. it.
    :param omega1: first component of ``omega``.
    :param L11: first diagonal entry of ``L``.
    :param L22: second diagonal entry of ``L``.
    :param L33: third diagonal entry of ``L``.
    :param omega2: second component of ``omega`` (the third is fixed at 0).
    :param n_samples: number of single-sample gradients to average.
    :raises ValueError: if ``mvn_dist`` is not one of the supported names.
    """
    omega = torch.tensor([omega1, omega2, 0.0])
    loc = torch.zeros(3, requires_grad=True)
    zero_vec = [0.0, 0.0, 0.0]
    # Leaf tensor holding the single off-diagonal entry; its .grad is what we inspect.
    off_diag = torch.tensor([zero_vec, [L21, 0.0, 0.0], zero_vec], requires_grad=True)
    L = torch.diag(torch.tensor([L11, L22, L33])) + off_diag

    if mvn_dist == "OMTMultivariateNormal":
        dist = OMTMultivariateNormal(loc, L)
    elif mvn_dist == "AVFMultivariateNormal":
        CV = (0.2 * torch.rand(2, k, 3)).requires_grad_(True)
        dist = AVFMultivariateNormal(loc, L, CV)
    else:
        # Without this branch `dist` would be unbound and the failure would
        # surface as an opaque UnboundLocalError further down.
        raise ValueError("unknown mvn_dist: %r" % (mvn_dist,))

    computed_grads = []
    for _ in range(n_samples):
        z = dist.rsample()
        torch.cos((omega * z).sum(-1)).mean().backward()
        assert off_diag.grad.size() == off_diag.size()
        assert loc.grad.size() == loc.size()
        # The strictly upper-triangular part of the gradient must sum to zero.
        assert torch.triu(off_diag.grad, 1).sum() == 0.0

        computed_grads.append(off_diag.grad.cpu()[1, 0].item())
        # Reset accumulated gradients so every recorded value comes from one sample.
        off_diag.grad.zero_()
        loc.grad.zero_()

    computed_grad = np.mean(computed_grads)
    analytic = analytic_grad(L11=L11, L22=L22, L21=L21, omega1=omega1, omega2=omega2)
    assert_equal(
        analytic,
        computed_grad,
        prec=0.01,
        msg="bad cholesky grad for %s (expected %.5f, got %.5f)"
        % (mvn_dist, analytic, computed_grad),
    )
def test_log_prob(mvn_dist):
    """Check that ``log_prob`` agrees with a reference ``MultivariateNormal``.

    Builds a 5x5 lower-triangular factor ``L`` and the covariance ``L @ L.T``. It
    then asserts that the distribution under test, parameterized by ``(loc, L)``,
    gives the same ``log_prob(x)`` as ``MultivariateNormal(loc, cov)``.

    :param mvn_dist: the distribution under test, given either as the class
        (``OMTMultivariateNormal`` / ``AVFMultivariateNormal``) or as its name
        string (``"OMTMultivariateNormal"`` / ``"AVFMultivariateNormal"``).
    :raises ValueError: if ``mvn_dist`` is neither supported class nor name.
    """
    loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0])
    D = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0])
    W = torch.tensor([[1.0, -1.0, 2.0, 2.0, 4.0], [2.0, 1.0, 1.0, 2.0, 6.0]])
    x = torch.tensor([2.0, 3.0, 4.0, 1.0, 7.0])
    # D is strictly positive and diag(W^T W) is a sum of squares, so L has a
    # positive diagonal and is a valid lower-triangular factor.
    L = D.diag() + torch.tril(W.t().matmul(W))
    cov = torch.mm(L, L.t())

    mvn = MultivariateNormal(loc, cov)
    # Accept both the class and its name so this test can share the string-based
    # parametrization style without breaking class-based callers.
    if mvn_dist in (OMTMultivariateNormal, "OMTMultivariateNormal"):
        mvn_prime = OMTMultivariateNormal(loc, L)
    elif mvn_dist in (AVFMultivariateNormal, "AVFMultivariateNormal"):
        # CV is random; the comparison assumes log_prob does not depend on it.
        CV = 0.2 * torch.rand(2, 2, 5)
        mvn_prime = AVFMultivariateNormal(loc, L, CV)
    else:
        # Otherwise `mvn_prime` would be unbound (UnboundLocalError below).
        raise ValueError("unknown mvn_dist: %r" % (mvn_dist,))
    assert_equal(mvn.log_prob(x), mvn_prime.log_prob(x))
def test_mean_gradient(
    mvn_dist, k, sample_shape, L21, omega1, L11, L22=0.8, L33=0.9, omega2=0.75
):
    """Check the batched reparameterized gradient w.r.t. the ``(1, 0)`` entry of ``L``.

    Draws ``sample_shape`` samples in a single ``rsample`` call and
    back-propagates the mean of ``cos(omega . z)``. The resulting gradient of the
    ``(1, 0)`` entry of the lower-triangular matrix ``L`` is compared with
    ``analytic_grad`` to precision 0.005.

    :param mvn_dist: name of the distribution under test, either
        ``"OMTMultivariateNormal"`` or ``"AVFMultivariateNormal"``.
    :param k: middle dimension of the ``(2, k, 3)`` ``CV`` tensor passed to
        ``AVFMultivariateNormal``; unused for the OMT distribution.
    :param sample_shape: passed straight to ``dist.rsample``.
    :param L21: value of the ``(1, 0)`` entry of ``L``; the gradient is taken w.r.t. it.
    :param omega1: first component of ``omega``.
    :param L11: first diagonal entry of ``L``.
    :param L22: second diagonal entry of ``L``.
    :param L33: third diagonal entry of ``L``.
    :param omega2: second component of ``omega`` (the third is fixed at 0).
    :raises ValueError: if ``mvn_dist`` is not one of the supported names.
    """
    if mvn_dist == "OMTMultivariateNormal" and k > 1:
        # k only shapes the CV tensor used by AVF, so for OMT every k > 1 would
        # repeat the k == 1 case; return early (the case is reported as passed).
        return

    omega = torch.tensor([omega1, omega2, 0.0])
    loc = torch.zeros(3, requires_grad=True)
    zero_vec = [0.0, 0.0, 0.0]
    # Leaf tensor holding the single off-diagonal entry; its .grad is what we inspect.
    off_diag = torch.tensor([zero_vec, [L21, 0.0, 0.0], zero_vec], requires_grad=True)
    L = torch.diag(torch.tensor([L11, L22, L33])) + off_diag

    if mvn_dist == "OMTMultivariateNormal":
        dist = OMTMultivariateNormal(loc, L)
    elif mvn_dist == "AVFMultivariateNormal":
        CV = (1.1 * torch.rand(2, k, 3)).requires_grad_(True)
        dist = AVFMultivariateNormal(loc, L, CV)
    else:
        # Without this branch `dist` would be unbound (UnboundLocalError below).
        raise ValueError("unknown mvn_dist: %r" % (mvn_dist,))

    z = dist.rsample(sample_shape)
    torch.cos((omega * z).sum(-1)).mean().backward()

    # Validate the gradient's structure before reading a value out of it.
    assert off_diag.grad.size() == off_diag.size()
    assert loc.grad.size() == loc.size()
    # The strictly upper-triangular part of the gradient must sum to zero.
    assert torch.triu(off_diag.grad, 1).sum() == 0.0

    # .item() replaces the deprecated .data.numpy() path and matches the
    # single-sample variant of this test.
    computed_grad = off_diag.grad.cpu()[1, 0].item()
    analytic = analytic_grad(L11=L11, L22=L22, L21=L21, omega1=omega1, omega2=omega2)
    assert_equal(
        analytic,
        computed_grad,
        prec=0.005,
        msg="bad cholesky grad for %s (expected %.5f, got %.5f)"
        % (mvn_dist, analytic, computed_grad),
    )