コード例 #1
0
ファイル: test_omt_mvn.py プロジェクト: pyro-ppl/pyro
def test_mean_single_gradient(mvn_dist,
                              k,
                              L21,
                              omega1,
                              L11,
                              L22=0.8,
                              L33=0.9,
                              omega2=0.75,
                              n_samples=20000):
    """Check the averaged single-sample gradient of the ``[1, 0]`` scale entry.

    A 3x3 lower-triangular matrix is built with diagonal ``(L11, L22, L33)``
    and a single off-diagonal entry ``L21`` at position ``[1, 0]``.  For each
    of ``n_samples`` draws, one reparameterized sample ``z`` is taken,
    ``cos(sum(omega * z))`` is back-propagated, and the gradient landing on
    the ``[1, 0]`` entry is recorded.  The mean of those gradients is compared
    with ``analytic_grad(...)`` to within ``prec=0.01``.

    Args:
        mvn_dist: Either ``"OMTMultivariateNormal"`` or
            ``"AVFMultivariateNormal"``; selects the distribution under test.
        k: Middle dimension of the random ``CV`` tensor (shape ``(2, k, 3)``)
            handed to ``AVFMultivariateNormal``; unused for the OMT variant.
        L21: Value of the ``[1, 0]`` entry of the scale matrix.
        omega1, omega2: First two components of the weight vector ``omega``
            (the third component is fixed to ``0.0``).
        L11, L22, L33: Diagonal entries of the scale matrix.
        n_samples: Number of single-sample gradient estimates to average.

    Raises:
        ValueError: If ``mvn_dist`` names neither supported distribution.
    """
    omega = torch.tensor([omega1, omega2, 0.0])
    loc = torch.zeros(3, requires_grad=True)
    zero_vec = [0.0, 0.0, 0.0]
    # Only the [1, 0] entry is non-zero; keeping it in its own leaf tensor
    # lets us read the gradient of exactly that entry afterwards.
    off_diag = torch.tensor([zero_vec, [L21, 0.0, 0.0], zero_vec],
                            requires_grad=True)
    L = torch.diag(torch.tensor([L11, L22, L33])) + off_diag

    if mvn_dist == "OMTMultivariateNormal":
        dist = OMTMultivariateNormal(loc, L)
    elif mvn_dist == "AVFMultivariateNormal":
        CV = (0.2 * torch.rand(2, k, 3)).requires_grad_(True)
        dist = AVFMultivariateNormal(loc, L, CV)
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # `dist` further down.
        raise ValueError("unknown mvn_dist: %r" % (mvn_dist,))

    computed_grads = []

    for _ in range(n_samples):
        z = dist.rsample()
        torch.cos((omega * z).sum(-1)).mean().backward()
        assert off_diag.grad.size() == off_diag.size()
        assert loc.grad.size() == loc.size()
        # Nothing strictly above the diagonal may receive gradient.
        assert torch.triu(off_diag.grad, 1).sum() == 0.0

        computed_grad = off_diag.grad.cpu()[1, 0].item()
        computed_grads.append(computed_grad)
        # backward() accumulates into .grad, so reset before the next draw.
        off_diag.grad.zero_()
        loc.grad.zero_()

    computed_grad = np.mean(computed_grads)
    analytic = analytic_grad(L11=L11,
                             L22=L22,
                             L21=L21,
                             omega1=omega1,
                             omega2=omega2)
    assert_equal(
        analytic,
        computed_grad,
        prec=0.01,
        msg="bad cholesky grad for %s (expected %.5f, got %.5f)" %
        (mvn_dist, analytic, computed_grad),
    )
コード例 #2
0
ファイル: test_omt_mvn.py プロジェクト: youngshingjun/pyro
def test_log_prob(mvn_dist):
    """Check that ``mvn_dist.log_prob`` agrees with ``MultivariateNormal``.

    A 5x5 lower-triangular matrix ``L`` is assembled from a diagonal ``D`` and
    the lower triangle of ``W^T W``.  The reference distribution is built from
    the covariance ``L L^T``, while the distribution under test is built
    directly from ``L``; both are evaluated at the same point ``x``.

    Args:
        mvn_dist: The distribution class under test, either
            ``OMTMultivariateNormal`` or ``AVFMultivariateNormal``.

    Raises:
        ValueError: If ``mvn_dist`` is neither supported class.
    """
    loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0])
    D = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0])
    W = torch.tensor([[1.0, -1.0, 2.0, 2.0, 4.0], [2.0, 1.0, 1.0, 2.0, 6.0]])
    x = torch.tensor([2.0, 3.0, 4.0, 1.0, 7.0])
    L = D.diag() + torch.tril(W.t().matmul(W))
    cov = torch.mm(L, L.t())

    mvn = MultivariateNormal(loc, cov)
    if mvn_dist == OMTMultivariateNormal:
        mvn_prime = OMTMultivariateNormal(loc, L)
    elif mvn_dist == AVFMultivariateNormal:
        CV = 0.2 * torch.rand(2, 2, 5)
        mvn_prime = AVFMultivariateNormal(loc, L, CV)
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # `mvn_prime` in the comparison below.
        raise ValueError("unknown mvn_dist: %r" % (mvn_dist,))
    assert_equal(mvn.log_prob(x), mvn_prime.log_prob(x))
コード例 #3
0
ファイル: test_omt_mvn.py プロジェクト: pyro-ppl/pyro
def test_mean_gradient(mvn_dist,
                       k,
                       sample_shape,
                       L21,
                       omega1,
                       L11,
                       L22=0.8,
                       L33=0.9,
                       omega2=0.75):
    """Check the batched gradient estimate of the ``[1, 0]`` scale entry.

    A 3x3 lower-triangular matrix is built with diagonal ``(L11, L22, L33)``
    and a single off-diagonal entry ``L21`` at position ``[1, 0]``.  A batch
    of reparameterized samples of shape ``sample_shape`` is drawn, the mean of
    ``cos(sum(omega * z))`` is back-propagated, and the gradient landing on
    the ``[1, 0]`` entry is compared with ``analytic_grad(...)`` to within
    ``prec=0.005``.

    Args:
        mvn_dist: Either ``"OMTMultivariateNormal"`` or
            ``"AVFMultivariateNormal"``; selects the distribution under test.
        k: Middle dimension of the random ``CV`` tensor (shape ``(2, k, 3)``)
            handed to ``AVFMultivariateNormal``.  The OMT variant does not use
            it, so OMT cases with ``k > 1`` return immediately.
        sample_shape: Shape passed to ``dist.rsample``.
        L21: Value of the ``[1, 0]`` entry of the scale matrix.
        omega1, omega2: First two components of the weight vector ``omega``
            (the third component is fixed to ``0.0``).
        L11, L22, L33: Diagonal entries of the scale matrix.

    Raises:
        ValueError: If ``mvn_dist`` names neither supported distribution.
    """
    # `k` only affects the AVF variant; skip the redundant OMT combinations.
    if mvn_dist == "OMTMultivariateNormal" and k > 1:
        return

    omega = torch.tensor([omega1, omega2, 0.0])
    loc = torch.zeros(3, requires_grad=True)
    zero_vec = [0.0, 0.0, 0.0]
    # Only the [1, 0] entry is non-zero; keeping it in its own leaf tensor
    # lets us read the gradient of exactly that entry afterwards.
    off_diag = torch.tensor([zero_vec, [L21, 0.0, 0.0], zero_vec],
                            requires_grad=True)
    L = torch.diag(torch.tensor([L11, L22, L33])) + off_diag

    if mvn_dist == "OMTMultivariateNormal":
        dist = OMTMultivariateNormal(loc, L)
    elif mvn_dist == "AVFMultivariateNormal":
        CV = (1.1 * torch.rand(2, k, 3)).requires_grad_(True)
        dist = AVFMultivariateNormal(loc, L, CV)
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # `dist` further down.
        raise ValueError("unknown mvn_dist: %r" % (mvn_dist,))

    z = dist.rsample(sample_shape)
    torch.cos((omega * z).sum(-1)).mean().backward()

    # `.detach()` replaces the legacy `.data` accessor; value and type of the
    # extracted scalar are unchanged.
    computed_grad = off_diag.grad.detach().cpu().numpy()[1, 0]
    analytic = analytic_grad(L11=L11,
                             L22=L22,
                             L21=L21,
                             omega1=omega1,
                             omega2=omega2)
    assert off_diag.grad.size() == off_diag.size()
    assert loc.grad.size() == loc.size()
    # Nothing strictly above the diagonal may receive gradient.
    assert torch.triu(off_diag.grad, 1).sum() == 0.0
    assert_equal(
        analytic,
        computed_grad,
        prec=0.005,
        msg="bad cholesky grad for %s (expected %.5f, got %.5f)" %
        (mvn_dist, analytic, computed_grad),
    )