Ejemplo n.º 1
0
 def func(chol_vec, delta):
     """Return the ``MvNormalLogp`` output for a covariance built from ``chol_vec``.

     The three entries of ``chol_vec`` fill the 2x2 lower-triangular factor
     ``[[exp(0.1 * v0), 0], [v1, 2 * exp(v2)]]``; the covariance passed to
     the op is that factor multiplied by its own transpose.
     """
     top_row = at.stack([at.exp(0.1 * chol_vec[0]), 0])
     bottom_row = at.stack([chol_vec[1], 2 * at.exp(chol_vec[2])])
     factor = at.stack([top_row, bottom_row])
     covariance = at.dot(factor, factor.T)
     logp_op = MvNormalLogp()
     return logp_op(covariance, delta)
Ejemplo n.º 2
0
 def test_hessian(self):
     """Smoke-test that ``MvNormalLogp`` can be differentiated twice.

     Builds a 2x2 covariance from a three-element vector, takes the gradient
     of the log-probability with respect to ``cov`` and ``delta``, and then
     differentiates the sum of those first-order gradients again.  Nothing
     is asserted: the test passes as long as building the graphs does not
     raise.
     """
     chol_vec = at.vector("chol_vec")
     # Create the test value in the variable's own dtype.  A plain NumPy
     # array is float64, which is not accepted as a test value for a
     # float32 variable (i.e. when ``floatX`` is float32).
     chol_vec.tag.test_value = np.array([0.1, 2, 3], dtype=chol_vec.dtype)
     chol = at.stack([
         at.stack([at.exp(0.1 * chol_vec[0]), 0]),
         at.stack([chol_vec[1], 2 * at.exp(chol_vec[2])]),
     ])
     cov = at.dot(chol, chol.T)
     delta = at.matrix("delta")
     # Same dtype consideration as for ``chol_vec`` above.
     delta.tag.test_value = np.ones((5, 2), dtype=delta.dtype)
     logp = MvNormalLogp()(cov, delta)
     g_cov, g_delta = at.grad(logp, [cov, delta])
     # Second differentiation; the result is intentionally discarded.
     at.grad(g_delta.sum() + g_cov.sum(), [delta, cov])
Ejemplo n.º 3
0
 def test_hessian(self):
     """Build second-order gradients of ``MvNormalLogp``; no values are checked."""
     # Symbolic inputs with floatX test values attached.
     delta = at.matrix("delta")
     delta.tag.test_value = floatX(np.ones((5, 2)))
     chol_vec = at.vector("chol_vec")
     chol_vec.tag.test_value = floatX(np.array([0.1, 2, 3]))

     # 2x2 factor assembled row by row, then cov = factor . factor^T.
     top_row = at.stack([at.exp(0.1 * chol_vec[0]), 0])
     bottom_row = at.stack([chol_vec[1], 2 * at.exp(chol_vec[2])])
     factor = at.stack([top_row, bottom_row])
     cov = at.dot(factor, factor.T)

     log_prob = MvNormalLogp()(cov, delta)
     first_order = at.grad(log_prob, [cov, delta])
     g_cov, g_delta = first_order
     summed_grads = g_delta.sum() + g_cov.sum()
     # TODO: What's the test?  Something needs to be asserted.
     at.grad(summed_grads, [delta, cov])
Ejemplo n.º 4
0
    def test_logp(self):
        """Check the compiled ``MvNormalLogp`` value against SciPy.

        The reference is the sum of ``scipy.stats.multivariate_normal``
        log-densities (zero mean, same covariance) over the rows of the
        sampled ``delta`` values.
        """
        np.random.seed(42)

        # Numeric inputs, cast with floatX.
        factor_val = floatX(np.array([[1, 0.9], [0, 2]]))
        cov_val = floatX(np.dot(factor_val, factor_val.T))
        delta_val = floatX(np.random.randn(5, 2))

        # Symbolic inputs carrying those values as test values.
        cov = at.matrix("cov")
        cov.tag.test_value = cov_val
        delta = at.matrix("delta")
        delta.tag.test_value = delta_val

        # Reference value from SciPy.
        reference_dist = stats.multivariate_normal(mean=np.zeros(2), cov=cov_val)
        expected = reference_dist.logpdf(delta_val).sum()

        # Value from the op, evaluated through a compiled function.
        symbolic_logp = MvNormalLogp()(cov, delta)
        compiled_logp = aesara.function([cov, delta], symbolic_logp)
        actual = compiled_logp(cov_val, delta_val)

        npt.assert_allclose(actual, expected)