Beispiel #1
0
    def _compute_bound(self, R, logdet=None, inv=None, gradient=False, terms=False):
        r"""
        Rotate q(X) as X->RX: q(X)=N(R*mu, R*Cov*R')

        Assume:
        :math:`p(\mathbf{X}) = \prod^M_{m=1}
               N(\mathbf{x}_m|0, \mathbf{\Lambda})`

        Parameters
        ----------
        R : array
            Rotation matrix applied to X.
        logdet : float, optional
            Precomputed ``log|det(R)|``. Computed from ``R`` when omitted.
        inv : array, optional
            Precomputed inverse of ``R``. Computed from ``R`` when omitted and
            the gradient is requested.
        gradient : bool, optional
            If True, also return the gradient of the bound with respect to
            ``R``.
        terms : bool, optional
            If True, return the bound (and the gradient) as a dict keyed by
            ``self.X`` instead of a plain value.

        Returns
        -------
        ``bound`` when ``gradient`` is False, otherwise ``(bound, d_bound)``.
        """

        # TODO/FIXME: X and alpha should NOT contain observed values!! Check
        # that.

        # TODO/FIXME: Allow non-zero prior mean!

        # Assume constant mean and precision matrix over plates..

        # Compute rotated moments
        XX_R = dot(R, self.XX, R.T)

        # The caller may omit log|det(R)|; compute it here instead of letting
        # None propagate into the entropy term.
        if logdet is None:
            logdet_R = np.linalg.slogdet(R)[1]
        else:
            logdet_R = logdet

        # Compute entropy H(X)
        logH_X = random.gaussian_entropy(-2 * self.N * logdet_R, 0)

        # Compute <log p(X)>
        logp_X = random.gaussian_logpdf(np.vdot(XX_R, self.Lambda), 0, 0, 0, 0)

        # Compute the bound
        bound_value = logp_X + logH_X
        if terms:
            bound = {self.X: bound_value}
        else:
            bound = bound_value

        if not gradient:
            return bound

        # The inverse of R is needed only for the gradient of the entropy.
        if inv is None:
            inv_R = np.linalg.inv(R)
        else:
            inv_R = inv

        # Compute dH(X)
        dlogH_X = random.gaussian_entropy(-2 * self.N * inv_R.T, 0)

        # Compute d<log p(X)>
        dXX = 2 * dot(self.Lambda, R, self.XX)
        dlogp_X = random.gaussian_logpdf(dXX, 0, 0, 0, 0)

        if terms:
            d_bound = {self.X: dlogp_X + dlogH_X}
        else:
            d_bound = dlogp_X + dlogH_X

        return (bound, d_bound)

    def test_message_to_child(self):
        """
        Test the updating of GaussianMarkovChain.

        Check that the moments and the lower bound contribution are computed
        correctly.

        The reference values are computed by building the joint prior
        precision of the whole chain explicitly (an N*D x N*D block-banded
        matrix), adding the observation term, inverting it, and comparing the
        resulting moments and bound against what the node computes. A nested
        ``check`` helper repeats the comparison of ``_message_to_child`` for
        many parent configurations (plates, broadcasting, inputs).
        """

        # TODO: Add plates and missing values!

        # Dimensionalities
        D = 3
        N = 5
        (Y, X, Mu, Lambda, A, V) = self.create_model(N, D)

        # Inference with arbitrary observations
        y = np.random.randn(N, D)
        Y.observe(y)
        X.update()
        (x_vb, xnxn_vb, xpxn_vb) = X.get_moments()

        # Get parameter moments
        (mu0, mumu0) = Mu.get_moments()
        (icov0, logdet0) = Lambda.get_moments()
        (a, aa) = A.get_moments()
        (icov_x, logdetx) = V.get_moments()
        icov_x = np.diag(icov_x)
        # Prior precision
        Z = np.einsum("...kij,...kk->...ij", aa, icov_x)
        U_diag = [icov0 + Z] + (N - 2) * [icov_x + Z] + [icov_x]
        U_super = (N - 1) * [-np.dot(a.T, icov_x)]
        U = misc.block_banded(U_diag, U_super)
        # Prior mean
        mu_prior = np.zeros(D * N)
        mu_prior[:D] = np.dot(icov0, mu0)
        # Data (the identity term presumably reflects unit observation noise
        # in create_model -- confirm there)
        Cov = np.linalg.inv(U + np.identity(D * N))
        mu = np.dot(Cov, mu_prior + y.flatten())
        # Moments
        xx = mu[:, np.newaxis] * mu[np.newaxis, :] + Cov
        mu = np.reshape(mu, (N, D))
        xx = np.reshape(xx, (N, D, N, D))

        # Check results
        self.assertAllClose(x_vb, mu, msg="Incorrect mean")
        for n in range(N):
            self.assertAllClose(xnxn_vb[n, :, :], xx[n, :, n, :], msg="Incorrect second moment")
        for n in range(N - 1):
            self.assertAllClose(xpxn_vb[n, :, :], xx[n, :, n + 1, :], msg="Incorrect lagged second moment")

        # Compute the true bound <log p(X|...)> + H(X) from the explicit
        # posterior and compare it with the node's contribution.

        # Compute the entropy H(X)
        ldet_cov = linalg.logdet_cov(Cov)
        H = random.gaussian_entropy(-ldet_cov, N * D)
        # Compute <log p(X|...)>
        xx = np.reshape(xx, (N * D, N * D))
        mu = np.reshape(mu, (N * D,))
        # log-determinant of the prior covariance of the whole chain
        ldet_prior = -logdet0 - np.sum(np.ones((N - 1, D)) * logdetx)
        P = random.gaussian_logpdf(
            np.einsum("...ij,...ij", xx, U),
            np.einsum("...i,...i", mu, mu_prior),
            np.einsum("...ij,...ij", mumu0, icov0),
            -ldet_prior,
            N * D,
        )

        # The VB bound from the net
        bound_vb = X.lower_bound_contribution()

        self.assertAllClose(bound_vb, H + P)

        #
        # Simple tests
        #

        def check(N, D, plates=None, mu=None, Lambda=None, A=None, V=None):
            # Build the node with the given (or random) parents and compare
            # its moments with those of the explicitly inverted precision.
            if mu is None:
                mu = np.random.randn(D)
            if Lambda is None:
                Lambda = random.covariance(D)
            if A is None:
                A = np.random.randn(D, D)
            if V is None:
                V = np.random.rand(D)
            X = GaussianMarkovChain(mu, Lambda, A, V, plates=plates, n=N)
            (u0, u1, u2) = X._message_to_child()
            (mu, _) = Gaussian._ensure_moments(mu, GaussianMoments, ndim=1).get_moments()
            (Lambda, _) = Wishart._ensure_moments(Lambda, WishartMoments, ndim=1).get_moments()
            (a, aa) = Gaussian._ensure_moments(A, GaussianMoments, ndim=1).get_moments()
            a = a * np.ones((N - 1, D, D))  # explicit broadcasting for simplicity
            aa = aa * np.ones((N - 1, D, D, D))  # explicit broadcasting for simplicity
            (v, _) = Gamma._ensure_moments(V, GammaMoments).get_moments()
            v = v * np.ones((N - 1, D))
            plates_C = X.plates
            C = np.zeros(plates_C + (N, D, N, D))
            plates_mu = np.shape(mu)[:-1]
            m = np.zeros(plates_mu + (N, D))
            m[..., 0, :] = np.einsum("...ij,...j->...i", Lambda, mu)
            # Diagonal blocks of the joint precision
            C[..., 0, :, 0, :] = Lambda + np.einsum("...dij,...d->...ij", aa[..., 0, :, :, :], v[..., 0, :])
            for n in range(1, N - 1):
                C[..., n, :, n, :] = np.einsum("...dij,...d->...ij", aa[..., n, :, :, :], v[..., n, :]) + v[
                    ..., n, :, None
                ] * np.identity(D)
            # Off-diagonal blocks coupling consecutive time steps
            for n in range(N - 1):
                C[..., n, :, n + 1, :] = -np.einsum("...di,...d->...id", a[..., n, :, :], v[..., n, :])
                C[..., n + 1, :, n, :] = -np.einsum("...di,...d->...di", a[..., n, :, :], v[..., n, :])
            C[..., -1, :, -1, :] = v[..., -1, :, None] * np.identity(D)
            C = np.reshape(C, plates_C + (N * D, N * D))
            Cov = np.linalg.inv(C)
            Cov = np.reshape(Cov, plates_C + (N, D, N, D))
            m0 = np.einsum("...minj,...nj->...mi", Cov, m)
            m1 = np.zeros(plates_C + (N, D, D))
            m2 = np.zeros(plates_C + (N - 1, D, D))
            for n in range(N):
                m1[..., n, :, :] = Cov[..., n, :, n, :] + np.einsum("...i,...j->...ij", m0[..., n, :], m0[..., n, :])
            for n in range(N - 1):
                m2[..., n, :, :] = Cov[..., n, :, n + 1, :] + np.einsum(
                    "...i,...j->...ij", m0[..., n, :], m0[..., n + 1, :]
                )
            self.assertAllClose(m0, u0 * np.ones(np.shape(m0)))
            self.assertAllClose(m1, u1 * np.ones(np.shape(m1)))
            self.assertAllClose(m2, u2 * np.ones(np.shape(m2)))

        check(4, 1)
        check(4, 3)

        #
        # Test mu
        #

        # Simple
        check(4, 3, mu=Gaussian(np.random.randn(3), random.covariance(3)))
        # Plates
        check(4, 3, mu=Gaussian(np.random.randn(5, 6, 3), random.covariance(3), plates=(5, 6)))
        # Plates with moments broadcasted over plates
        check(4, 3, mu=Gaussian(np.random.randn(3), random.covariance(3), plates=(5,)))
        check(4, 3, mu=Gaussian(np.random.randn(1, 3), random.covariance(3), plates=(5,)))
        # Plates broadcasting
        check(4, 3, plates=(5,), mu=Gaussian(np.random.randn(3), random.covariance(3), plates=()))
        check(4, 3, plates=(5,), mu=Gaussian(np.random.randn(1, 3), random.covariance(3), plates=(1,)))

        #
        # Test Lambda
        #

        # Simple
        check(4, 3, Lambda=Wishart(10 + np.random.rand(), random.covariance(3)))
        # Plates
        check(4, 3, Lambda=Wishart(10 + np.random.rand(), random.covariance(3), plates=(5, 6)))
        # Plates with moments broadcasted over plates
        check(4, 3, Lambda=Wishart(10 + np.random.rand(), random.covariance(3), plates=(5,)))
        check(4, 3, Lambda=Wishart(10 + np.random.rand(1), random.covariance(3), plates=(5,)))
        # Plates broadcasting
        check(4, 3, plates=(5,), Lambda=Wishart(10 + np.random.rand(), random.covariance(3), plates=()))
        check(4, 3, plates=(5,), Lambda=Wishart(10 + np.random.rand(), random.covariance(3), plates=(1,)))

        #
        # Test A
        #

        # Simple
        check(4, 3, A=GaussianARD(np.random.randn(3, 3), np.random.rand(3, 3), shape=(3,), plates=(3,)))
        # Plates on time axis
        check(5, 3, A=GaussianARD(np.random.randn(4, 3, 3), np.random.rand(4, 3, 3), shape=(3,), plates=(4, 3)))
        # Plates on time axis with broadcasted moments
        check(5, 3, A=GaussianARD(np.random.randn(1, 3, 3), np.random.rand(1, 3, 3), shape=(3,), plates=(4, 3)))
        check(5, 3, A=GaussianARD(np.random.randn(3, 3), np.random.rand(3, 3), shape=(3,), plates=(4, 3)))
        # Plates
        check(
            4,
            3,
            A=GaussianARD(
                np.random.randn(5, 6, 1, 3, 3), np.random.rand(5, 6, 1, 3, 3), shape=(3,), plates=(5, 6, 1, 3)
            ),
        )
        # Plates with moments broadcasted over plates
        check(4, 3, A=GaussianARD(np.random.randn(3, 3), np.random.rand(3, 3), shape=(3,), plates=(5, 1, 3)))
        check(
            4, 3, A=GaussianARD(np.random.randn(1, 1, 3, 3), np.random.rand(1, 1, 3, 3), shape=(3,), plates=(5, 1, 3))
        )
        # Plates broadcasting
        check(4, 3, plates=(5,), A=GaussianARD(np.random.randn(3, 3), np.random.rand(3, 3), shape=(3,), plates=(3,)))
        check(
            4, 3, plates=(5,), A=GaussianARD(np.random.randn(3, 3), np.random.rand(3, 3), shape=(3,), plates=(1, 1, 3))
        )

        #
        # Test v
        #

        # Simple
        check(4, 3, V=Gamma(np.random.rand(1, 3), np.random.rand(1, 3), plates=(1, 3)))
        check(4, 3, V=Gamma(np.random.rand(3), np.random.rand(3), plates=(3,)))
        # Plates
        check(4, 3, V=Gamma(np.random.rand(5, 6, 1, 3), np.random.rand(5, 6, 1, 3), plates=(5, 6, 1, 3)))
        # Plates with moments broadcasted over plates
        check(4, 3, V=Gamma(np.random.rand(1, 3), np.random.rand(1, 3), plates=(5, 1, 3)))
        check(4, 3, V=Gamma(np.random.rand(1, 1, 3), np.random.rand(1, 1, 3), plates=(5, 1, 3)))
        # Plates broadcasting
        check(4, 3, plates=(5,), V=Gamma(np.random.rand(1, 3), np.random.rand(1, 3), plates=(1, 3)))
        check(4, 3, plates=(5,), V=Gamma(np.random.rand(1, 1, 3), np.random.rand(1, 1, 3), plates=(1, 1, 3)))

        #
        # Check with input signals
        #

        mu = 2
        Lambda = 3
        A = 4
        B = 5
        v = 6
        inputs = [[-2], [3]]
        X = GaussianMarkovChain([mu], [[Lambda]], [[A, B]], [v], inputs=inputs)
        V = np.array([[v * A ** 2, -v * A, 0], [-v * A, v * A ** 2, -v * A], [0, -v * A, 0]]) + np.array(
            [[Lambda, 0, 0], [0, v, 0], [0, 0, v]]
        )
        m = (
            np.array([Lambda * mu, 0, 0])
            + np.array([0, v * B * inputs[0][0], v * B * inputs[1][0]])
            - np.array([v * A * B * inputs[0][0], v * A * B * inputs[1][0], 0])
        )
        Cov = np.linalg.inv(V)
        mean = np.dot(Cov, m)

        X.update()
        u = X.get_moments()

        self.assertAllClose(u[0], mean[:, None])
        self.assertAllClose(u[1] - u[0][..., None, :] * u[0][..., :, None], Cov[(0, 1, 2), (0, 1, 2), None, None])
        self.assertAllClose(u[2] - u[0][..., :-1, :, None] * u[0][..., 1:, None, :], Cov[(0, 1), (1, 2), None, None])
Beispiel #3
0
    def _compute_bound(self, R, logdet=None, inv=None, Q=None, gradient=False, terms=False):
        """
        Rotate q(X) and q(alpha).

        Assume:
        p(X|alpha) = prod_m N(x_m|0,diag(alpha))
        p(alpha) = prod_d G(a_d,b_d)

        Parameters
        ----------
        R : array
            Rotation matrix applied to X.
        logdet : float, optional
            Precomputed ``log|det(R)|``; computed from ``R`` when omitted.
        inv : array, optional
            Precomputed inverse of ``R``; computed from ``R`` when omitted and
            the gradient is requested.
        Q : array, optional
            Rotation of the plate axis. Required when ``self.plate_axis`` is
            not None.
        gradient : bool, optional
            If True, also return the gradient(s) of the bound.
        terms : bool, optional
            If True, return the bound as a dict keyed by node. Gradients are
            not implemented in this mode.

        Returns
        -------
        ``bound`` when ``gradient`` is False; ``(bound, dR_bound)`` when
        ``self.plate_axis`` is None; otherwise ``(bound, dR_bound, dQ_bound)``.

        Raises
        ------
        ValueError
            If the plates should be rotated but ``Q`` is None.
        NotImplementedError
            If both ``gradient`` and ``terms`` are True.
        """

        ## R = self._full_rotation_matrix(R)
        ## if inv is not None:
        ##     inv = self._full_rotation_matrix(inv)

        #
        # Transform the distributions and moments
        #

        plates_alpha = self.plates_alpha
        plates_X = self.plates_X

        # Compute rotated second moment
        if self.plate_axis is not None:
            # The plate axis has been moved to be the last plate axis

            if Q is None:
                raise ValueError("Plates should be rotated but no Q given")

            # Transform covariance
            sumQ = np.sum(Q, axis=0)
            QCovQ = sumQ[:, None, None] ** 2 * self.CovX

            # Rotate plates
            if self.precompute:
                QX_QX = np.einsum("...kalb,...ik,...il->...iab", self.X_X, Q, Q)
                XX = QX_QX + QCovQ
                XX = sum_to_plates(XX, plates_alpha[:-1], ndim=2)
                Xmu = np.einsum("...kaib,...ik->...iab", self.X_mu, Q)
                Xmu = sum_to_plates(Xmu, plates_alpha[:-1], ndim=2)
            else:
                X = self.X
                mu = self.mu
                QX = np.einsum("...ik,...kj->...ij", Q, X)
                XX = sum_to_plates(QCovQ, plates_alpha[:-1], ndim=2) + sum_to_plates(
                    linalg.outer(QX, QX), plates_alpha[:-1], ndim=2, plates_from=plates_X
                )
                Xmu = sum_to_plates(linalg.outer(QX, self.mu), plates_alpha[:-1], ndim=2, plates_from=plates_X)

            mu2 = self.mu2
            D = np.shape(XX)[-1]
            logdet_Q = D * np.log(np.abs(sumQ))

        else:
            XX = self.XX
            mu2 = self.mu2
            Xmu = self.Xmu
            logdet_Q = 0

        # Compute transformed moments
        # mu2 = np.einsum('...ii->...i', mu2)
        RXmu = np.einsum("...ik,...ki->...i", R, Xmu)
        RXX = np.einsum("...ik,...kj->...ij", R, XX)
        RXXR = np.einsum("...ik,...ik->...i", RXX, R)

        # <(X-mu) * (X-mu)'>_R
        XmuXmu = RXXR - 2 * RXmu + mu2

        D = np.shape(R)[0]

        # Compute q(alpha)
        if self.update_alpha:
            # Parameters
            a0 = self.a0
            b0 = self.b0
            a = self.a
            b = b0 + 0.5 * sum_to_plates(XmuXmu, plates_alpha, plates_from=None, ndim=0)
            # Some expectations
            alpha = a / b
            logb = np.log(b)
            logalpha = -logb  # + const
            b0_alpha = b0 * alpha
            a0_logalpha = a0 * logalpha
        else:
            alpha = self.alpha
            logalpha = 0

        #
        # Compute the cost
        #

        def sum_plates(V, *plates):
            # Sum V, scaled by how many times it is broadcast over plates.
            full_plates = misc.broadcasted_shape(*plates)

            r = self.node_X.broadcasting_multiplier(full_plates, np.shape(V))
            return r * np.sum(V)

        XmuXmu_alpha = XmuXmu * alpha

        # Resolve log|det(R)| independently of the inverse so that callers
        # may pass either one (or both) without the other becoming None.
        if logdet is None:
            logdet_R = np.linalg.slogdet(R)[1]
        else:
            logdet_R = logdet

        # Compute entropy H(X)
        logH_X = random.gaussian_entropy(-2 * sum_plates(logdet_R + logdet_Q, plates_X), 0)

        # Compute <log p(X|alpha)>
        logp_X = random.gaussian_logpdf(
            sum_plates(XmuXmu_alpha, plates_alpha[:-1] + [D]), 0, 0, sum_plates(logalpha, plates_X + [D]), 0
        )

        if self.update_alpha:

            # Compute entropy H(alpha)
            # This cancels out with the log(alpha) term in log(p(alpha))
            logH_alpha = 0

            # Compute <log p(alpha)>
            logp_alpha = random.gamma_logpdf(
                sum_plates(b0_alpha, plates_alpha), 0, sum_plates(a0_logalpha, plates_alpha), 0, 0
            )
        else:
            logH_alpha = 0
            logp_alpha = 0

        # Compute the bound
        if terms:
            bound = {self.node_X: logp_X + logH_X}
            if self.update_alpha:
                bound.update({self.node_alpha: logp_alpha + logH_alpha})
        else:
            bound = 0 + logp_X + logp_alpha + logH_X + logH_alpha

        if not gradient:
            return bound

        #
        # Compute the gradient with respect R
        #

        broadcasting_multiplier = self.node_X.broadcasting_multiplier

        def sum_plates(V, plates):
            # Sum the last two axes of V (matrix-shaped gradient terms) and
            # scale by the plate broadcasting multiplier.
            ones = np.ones(np.shape(R))
            r = broadcasting_multiplier(plates, np.shape(V)[:-2])
            return r * misc.sum_multiply(V, ones, axis=(-1, -2), sumaxis=False, keepdims=False)

        D_XmuXmu = 2 * RXX - 2 * gaussian.transpose_covariance(Xmu)

        DXmuXmu_alpha = np.einsum("...i,...ij->...ij", alpha, D_XmuXmu)
        if self.update_alpha:
            D_b = 0.5 * D_XmuXmu
            XmuXmu_Dalpha = np.einsum(
                "...i,...i,...i,...ij->...ij",
                sum_to_plates(XmuXmu, plates_alpha, plates_from=None, ndim=0),
                alpha,
                -1 / b,
                D_b,
            )
            D_b0_alpha = np.einsum("...i,...i,...i,...ij->...ij", b0, alpha, -1 / b, D_b)
            D_logb = np.einsum("...i,...ij->...ij", 1 / b, D_b)
            D_logalpha = -D_logb
            D_a0_logalpha = a0 * D_logalpha
        else:
            XmuXmu_Dalpha = 0
            D_logalpha = 0

        D_XmuXmu_alpha = DXmuXmu_alpha + XmuXmu_Dalpha

        # The inverse of R is needed only for the gradient; use the supplied
        # one when given, otherwise compute it here.
        if inv is None:
            inv_R = np.linalg.inv(R)
        else:
            inv_R = inv
        D_logR = inv_R.T

        # Compute dH(X)
        dlogH_X = random.gaussian_entropy(-2 * sum_plates(D_logR, plates_X), 0)

        # Compute d<log p(X|alpha)>
        dlogp_X = random.gaussian_logpdf(
            sum_plates(D_XmuXmu_alpha, plates_alpha[:-1]),
            0,
            0,
            (sum_plates(D_logalpha, plates_X) * broadcasting_multiplier((D,), plates_alpha[-1:])),
            0,
        )

        if self.update_alpha:

            # Compute dH(alpha)
            # This cancels out with the log(alpha) term in log(p(alpha))
            dlogH_alpha = 0

            # Compute d<log p(alpha)>
            dlogp_alpha = random.gamma_logpdf(
                sum_plates(D_b0_alpha, plates_alpha[:-1]), 0, sum_plates(D_a0_logalpha, plates_alpha[:-1]), 0, 0
            )
        else:
            dlogH_alpha = 0
            dlogp_alpha = 0

        if terms:
            # Per-node gradients are not supported yet; the code below the
            # raise sketches the intended structure and is unreachable.
            raise NotImplementedError()
            dR_bound = {self.node_X: dlogp_X + dlogH_X}
            if self.update_alpha:
                dR_bound.update({self.node_alpha: dlogp_alpha + dlogH_alpha})
        else:
            # NOTE(review): the leading 0 * dlogp_X presumably forces an
            # array-shaped result when other terms are scalar 0 -- confirm.
            dR_bound = 0 * dlogp_X + dlogp_X + dlogp_alpha + dlogH_X + dlogH_alpha

        if self.subset:
            indices = np.ix_(self.subset, self.subset)
            dR_bound = dR_bound[indices]

        if self.plate_axis is None:
            return (bound, dR_bound)

        #
        # Compute the gradient with respect to Q (if Q given)
        #

        # Some pre-computations
        Q_RCovR = np.einsum("...ik,...kl,...il,...->...i", R, self.CovX, R, sumQ)
        if self.precompute:
            Xr_rX = np.einsum("...abcd,...jb,...jd->...jac", self.X_X, R, R)
            QXr_rX = np.einsum("...akj,...ik->...aij", Xr_rX, Q)
            RX_mu = np.einsum("...jk,...akbj->...jab", R, self.X_mu)

        else:
            RX = np.einsum("...ik,...k->...i", R, X)
            QXR = np.einsum("...ik,...kj->...ij", Q, RX)
            QXr_rX = np.einsum("...ik,...jk->...kij", QXR, RX)
            RX_mu = np.einsum("...ik,...jk->...kij", RX, mu)

            QXr_rX = sum_to_plates(QXr_rX, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])
            RX_mu = sum_to_plates(RX_mu, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])

        def psi(v):
            """
            Compute: d/dQ 1/2*trace(diag(v)*<(X-mu)*(X-mu)>)

            = Q*<X>'*R'*diag(v)*R*<X> + ones * Q diag( tr(R'*diag(v)*R*Cov) )
              + mu*diag(v)*R*<X>
            """

            # Precompute all terms to plates_alpha because v has shape
            # plates_alpha.

            # Gradient of 0.5*v*<x>*<x>
            v_QXrrX = np.einsum("...kij,...ik->...ij", QXr_rX, v)

            # Gradient of 0.5*v*Cov
            Q_tr_R_v_R_Cov = np.einsum("...k,...k->...", Q_RCovR, v)[..., None, :]

            # Gradient of mu*v*x
            mu_v_R_X = np.einsum("...ik,...kji->...ij", v, RX_mu)

            return v_QXrrX + Q_tr_R_v_R_Cov - mu_v_R_X

        def sum_plates(V, plates):
            # Same as the R-gradient helper but shaped after Q.
            ones = np.ones(np.shape(Q))
            r = self.node_X.broadcasting_multiplier(plates, np.shape(V)[:-2])

            return r * misc.sum_multiply(V, ones, axis=(-1, -2), sumaxis=False, keepdims=False)

        if self.update_alpha:
            D_logb = psi(1 / b)
            XX_Dalpha = -psi(alpha / b * sum_to_plates(XmuXmu, plates_alpha))
            D_logalpha = -D_logb
        else:
            XX_Dalpha = 0
            D_logalpha = 0
        DXX_alpha = 2 * psi(alpha)
        D_XX_alpha = DXX_alpha + XX_Dalpha
        D_logdetQ = D / sumQ
        N = np.shape(Q)[-1]

        # Compute dH(X)
        dQ_logHX = random.gaussian_entropy(-2 * sum_plates(D_logdetQ, plates_X[:-1]), 0)

        # Compute d<log p(X|alpha)>
        dQ_logpX = random.gaussian_logpdf(
            sum_plates(D_XX_alpha, plates_alpha[:-2]),
            0,
            0,
            (sum_plates(D_logalpha, plates_X[:-1]) * broadcasting_multiplier((N, D), plates_alpha[-2:])),
            0,
        )

        if self.update_alpha:

            D_alpha = -psi(alpha / b)
            D_b0_alpha = b0 * D_alpha
            D_a0_logalpha = a0 * D_logalpha

            # Compute dH(alpha)
            # This cancels out with the log(alpha) term in log(p(alpha))
            dQ_logHalpha = 0

            # Compute d<log p(alpha)>
            dQ_logpalpha = random.gamma_logpdf(
                sum_plates(D_b0_alpha, plates_alpha[:-2]), 0, sum_plates(D_a0_logalpha, plates_alpha[:-2]), 0, 0
            )
        else:

            dQ_logHalpha = 0
            dQ_logpalpha = 0

        if terms:
            # Unreachable in practice: the R-gradient section already raised.
            raise NotImplementedError()
            dQ_bound = {self.node_X: dQ_logpX + dQ_logHX}
            if self.update_alpha:
                dQ_bound.update({self.node_alpha: dQ_logpalpha + dQ_logHalpha})
        else:
            dQ_bound = 0 * dQ_logpX + dQ_logpX + dQ_logpalpha + dQ_logHX + dQ_logHalpha
        return (bound, dR_bound, dQ_bound)
Beispiel #4
0
    def _compute_bound(self, R, logdet=None, inv=None, gradient=False, terms=False):
        r"""
        Rotate q(X) as X->RX: q(X)=N(R*mu, R*Cov*R')

        Assume:
        :math:`p(\mathbf{X}) = \prod^M_{m=1}
               N(\mathbf{x}_m|0, \mathbf{\Lambda})`

        Assume unit innovation noise covariance.

        Parameters
        ----------
        R : array
            Rotation matrix applied to X.
        logdet : float, optional
            Precomputed ``log|det(R)|``; computed from ``R`` when omitted.
        inv : array, optional
            Precomputed inverse of ``R``; computed from ``R`` when omitted and
            the gradient is requested.
        gradient : bool, optional
            If True, also return the gradient of the bound with respect to
            ``R``.
        terms : bool, optional
            If True, return the bound (and gradient) as a dict keyed by
            ``self.X_node`` instead of a plain value.

        Returns
        -------
        ``bound`` when ``gradient`` is False, otherwise ``(bound, d_bound)``.
        """

        # TODO/FIXME: X and alpha should NOT contain observed values!! Check
        # that.

        # Assume constant mean and precision matrix over plates..

        if logdet is None:
            logdetR = np.linalg.slogdet(R)[1]
        else:
            logdetR = logdet

        # Transform moments of X and A:

        Lambda_R_X0X0 = sum_to_plates(dot(self.Lambda, R, self.X0X0), (), plates_from=self.X_node.plates, ndim=2)
        R_XnXn = dot(R, self.XnXn)
        RA_XpXp_A = dot(R, self.A_XpXp_A)
        sumr = np.sum(R, axis=0)
        R_CovA_XpXp = sumr * self.CovA_XpXp

        # Compute entropy H(X)
        M = self.N * np.prod(self.X_node.plates)  # total number of rotated vectors
        logH_X = random.gaussian_entropy(-2 * M * logdetR, 0)

        # Compute <log p(X)>
        yy = tracedot(R_XnXn, R.T) + tracedot(Lambda_R_X0X0, R.T)
        yz = tracedot(dot(R, self.A_XpXn), R.T) + tracedot(self.Lambda_mu_X0, R.T)
        zz = tracedot(RA_XpXp_A, R.T) + np.einsum("...k,...k->...", R_CovA_XpXp, sumr)
        logp_X = random.gaussian_logpdf(yy, yz, zz, 0, 0)

        # Compute the bound
        if terms:
            bound = {self.X_node: logp_X + logH_X}
        else:
            bound = logp_X + logH_X

        if not gradient:
            return bound

        # The inverse of R is only needed for the gradient of the entropy.
        if inv is None:
            invR = np.linalg.inv(R)
        else:
            invR = inv

        # Compute dH(X)
        dlogH_X = random.gaussian_entropy(-2 * M * invR.T, 0)

        # Compute d<log p(X)>
        dyy = 2 * (R_XnXn + Lambda_R_X0X0)
        dyz = dot(R, self.A_XpXn + self.A_XpXn.T) + self.Lambda_mu_X0
        dzz = 2 * (RA_XpXp_A + R_CovA_XpXp[None, :])
        dlogp_X = random.gaussian_logpdf(dyy, dyz, dzz, 0, 0)

        if terms:
            d_bound = {self.X_node: dlogp_X + dlogH_X}
        else:
            d_bound = dlogp_X + dlogH_X

        return (bound, d_bound)

    def test_message_to_child(self):
        """
        Test the updating of GaussianMarkovChain.

        Check that the moments and the lower bound contribution are computed
        correctly.
        """

        # TODO: Add plates and missing values!

        # Dimensionalities
        D = 3
        N = 5
        (Y, X, Mu, Lambda, A, V) = self.create_model(N, D)

        # Inference with arbitrary observations
        y = np.random.randn(N,D)
        Y.observe(y)
        X.update()
        (x_vb, xnxn_vb, xpxn_vb) = X.get_moments()

        # Get parameter moments
        (mu0, mumu0) = Mu.get_moments()
        (icov0, logdet0) = Lambda.get_moments()
        (a, aa) = A.get_moments()
        (icov_x, logdetx) = V.get_moments()
        icov_x = np.diag(icov_x)
        # Prior precision
        Z = np.einsum('...kij,...kk->...ij', aa, icov_x)
        U_diag = [icov0+Z] + (N-2)*[icov_x+Z] + [icov_x]
        U_super = (N-1) * [-np.dot(a.T, icov_x)]
        U = misc.block_banded(U_diag, U_super)
        # Prior mean
        mu_prior = np.zeros(D*N)
        mu_prior[:D] = np.dot(icov0,mu0)
        # Data 
        Cov = np.linalg.inv(U + np.identity(D*N))
        mu = np.dot(Cov, mu_prior + y.flatten())
        # Moments
        xx = mu[:,np.newaxis]*mu[np.newaxis,:] + Cov
        mu = np.reshape(mu, (N,D))
        xx = np.reshape(xx, (N,D,N,D))

        # Check results
        self.assertAllClose(x_vb, mu,
                            msg="Incorrect mean")
        for n in range(N):
            self.assertAllClose(xnxn_vb[n,:,:], xx[n,:,n,:],
                                msg="Incorrect second moment")
        for n in range(N-1):
            self.assertAllClose(xpxn_vb[n,:,:], xx[n,:,n+1,:],
                                msg="Incorrect lagged second moment")


        # Compute the entropy H(X)
        ldet = linalg.logdet_cov(Cov)
        H = random.gaussian_entropy(-ldet, N*D)
        # Compute <log p(X|...)>
        xx = np.reshape(xx, (N*D, N*D))
        mu = np.reshape(mu, (N*D,))
        ldet = -logdet0 - np.sum(np.ones((N-1,D))*logdetx)
        P = random.gaussian_logpdf(np.einsum('...ij,...ij', 
                                                   xx, 
                                                   U),
                                         np.einsum('...i,...i', 
                                                   mu, 
                                                   mu_prior),
                                         np.einsum('...ij,...ij', 
                                                   mumu0,
                                                   icov0),
                                         -ldet,
                                         N*D)
                                                   
        # The VB bound from the net
        l = X.lower_bound_contribution()

        self.assertAllClose(l, H+P)
                                                   

        # Compute the true bound <log p(X|...)> + H(X)


        #
        # Simple tests
        #

        def check(N, D, plates=None, mu=None, Lambda=None, A=None, V=None):
            if mu is None:
                mu = np.random.randn(D)
            if Lambda is None:
                Lambda = random.covariance(D)
            if A is None:
                A = np.random.randn(D,D)
            if V is None:
                V = np.random.rand(D)
            X = GaussianMarkovChain(mu,
                                    Lambda,
                                    A,
                                    V,
                                    plates=plates,
                                    n=N)
            (u0, u1, u2) = X._message_to_child()
            (mu, mumu) = Gaussian._ensure_moments(mu, GaussianMoments, ndim=1).get_moments()
            (Lambda, _) = Wishart._ensure_moments(Lambda, WishartMoments, ndim=1).get_moments()
            (a, aa) = Gaussian._ensure_moments(A, GaussianMoments, ndim=1).get_moments()
            a = a * np.ones((N-1,D,D))     # explicit broadcasting for simplicity
            aa = aa * np.ones((N-1,D,D,D)) # explicit broadcasting for simplicity
            (v, _) = Gamma._ensure_moments(V, GammaMoments).get_moments()
            v = v * np.ones((N-1,D))
            plates_C = X.plates
            plates_mu = X.plates
            C = np.zeros(plates_C + (N,D,N,D))
            plates_mu = np.shape(mu)[:-1]
            m = np.zeros(plates_mu + (N,D))
            m[...,0,:] = np.einsum('...ij,...j->...i', Lambda, mu)
            C[...,0,:,0,:] = Lambda + np.einsum('...dij,...d->...ij',
                                                aa[...,0,:,:,:],
                                                v[...,0,:])
            for n in range(1,N-1):
                C[...,n,:,n,:] = (np.einsum('...dij,...d->...ij',
                                            aa[...,n,:,:,:],
                                            v[...,n,:])
                                  + v[...,n,:,None] * np.identity(D))
            for n in range(N-1):
                C[...,n,:,n+1,:] = -np.einsum('...di,...d->...id',
                                              a[...,n,:,:],
                                              v[...,n,:])
                C[...,n+1,:,n,:] = -np.einsum('...di,...d->...di',
                                              a[...,n,:,:],
                                              v[...,n,:])
            C[...,-1,:,-1,:] = v[...,-1,:,None]*np.identity(D)
            C = np.reshape(C, plates_C+(N*D,N*D))
            Cov = np.linalg.inv(C)
            Cov = np.reshape(Cov, plates_C+(N,D,N,D))
            m0 = np.einsum('...minj,...nj->...mi', Cov, m)
            m1 = np.zeros(plates_C+(N,D,D))
            m2 = np.zeros(plates_C+(N-1,D,D))
            for n in range(N):
                m1[...,n,:,:] = Cov[...,n,:,n,:] + np.einsum('...i,...j->...ij',
                                                             m0[...,n,:],
                                                             m0[...,n,:])
            for n in range(N-1):
                m2[...,n,:,:] = Cov[...,n,:,n+1,:] + np.einsum('...i,...j->...ij',
                                                               m0[...,n,:],
                                                               m0[...,n+1,:])
            self.assertAllClose(m0, u0*np.ones(np.shape(m0)))
            self.assertAllClose(m1, u1*np.ones(np.shape(m1)))
            self.assertAllClose(m2, u2*np.ones(np.shape(m2)))

            pass

        check(4,1)
        check(4,3)

        #
        # Test mu
        #

        # Simple
        check(4,3,
              mu=Gaussian(np.random.randn(3),
                          random.covariance(3)))
        # Plates
        check(4,3,
              mu=Gaussian(np.random.randn(5,6,3),
                          random.covariance(3),
                          plates=(5,6)))
        # Plates with moments broadcasted over plates
        check(4,3,
              mu=Gaussian(np.random.randn(3),
                          random.covariance(3),
                          plates=(5,)))
        check(4,3,
              mu=Gaussian(np.random.randn(1,3),
                          random.covariance(3),
                          plates=(5,)))
        # Plates broadcasting
        check(4,3,
              plates=(5,),
              mu=Gaussian(np.random.randn(3),
                          random.covariance(3),
                          plates=()))
        check(4,3,
              plates=(5,),
              mu=Gaussian(np.random.randn(1,3),
                          random.covariance(3),
                          plates=(1,)))

        #
        # Test Lambda
        #
            
        # Simple
        check(4,3,
              Lambda=Wishart(10+np.random.rand(),
                             random.covariance(3)))
        # Plates
        check(4,3,
              Lambda=Wishart(10+np.random.rand(),
                             random.covariance(3),
                             plates=(5,6)))
        # Plates with moments broadcasted over plates
        check(4,3,
              Lambda=Wishart(10+np.random.rand(),
                             random.covariance(3),
                             plates=(5,)))
        check(4,3,
              Lambda=Wishart(10+np.random.rand(1),
                             random.covariance(3),
                             plates=(5,)))
        # Plates broadcasting
        check(4,3,
              plates=(5,),
              Lambda=Wishart(10+np.random.rand(),
                             random.covariance(3),
                             plates=()))
        check(4,3,
              plates=(5,),
              Lambda=Wishart(10+np.random.rand(),
                             random.covariance(3),
                             plates=(1,)))

        #
        # Test A
        #

        # Simple
        check(4,3,
              A=GaussianARD(np.random.randn(3,3),
                            np.random.rand(3,3),
                            shape=(3,),
                            plates=(3,)))
        # Plates on time axis
        check(5,3,
              A=GaussianARD(np.random.randn(4,3,3),
                            np.random.rand(4,3,3),
                            shape=(3,),
                            plates=(4,3)))
        # Plates on time axis with broadcasted moments
        check(5,3,
              A=GaussianARD(np.random.randn(1,3,3),
                            np.random.rand(1,3,3),
                            shape=(3,),
                            plates=(4,3)))
        check(5,3,
              A=GaussianARD(np.random.randn(3,3),
                            np.random.rand(3,3),
                            shape=(3,),
                            plates=(4,3)))
        # Plates
        check(4,3,
              A=GaussianARD(np.random.randn(5,6,1,3,3),
                            np.random.rand(5,6,1,3,3),
                            shape=(3,),
                            plates=(5,6,1,3)))
        # Plates with moments broadcasted over plates
        check(4,3,
              A=GaussianARD(np.random.randn(3,3),
                            np.random.rand(3,3),
                            shape=(3,),
                            plates=(5,1,3)))
        check(4,3,
              A=GaussianARD(np.random.randn(1,1,3,3),
                            np.random.rand(1,1,3,3),
                            shape=(3,),
                            plates=(5,1,3)))
        # Plates broadcasting
        check(4,3,
              plates=(5,),
              A=GaussianARD(np.random.randn(3,3),
                            np.random.rand(3,3),
                            shape=(3,),
                            plates=(3,)))
        check(4,3,
              plates=(5,),
              A=GaussianARD(np.random.randn(3,3),
                            np.random.rand(3,3),
                            shape=(3,),
                            plates=(1,1,3)))

        #
        # Test v
        #
        
        # Simple
        check(4,3,
              V=Gamma(np.random.rand(1,3),
                      np.random.rand(1,3),
                      plates=(1,3)))
        check(4,3,
              V=Gamma(np.random.rand(3),
                      np.random.rand(3),
                      plates=(3,)))
        # Plates
        check(4,3,
              V=Gamma(np.random.rand(5,6,1,3),
                      np.random.rand(5,6,1,3),
                      plates=(5,6,1,3)))
        # Plates with moments broadcasted over plates
        check(4,3,
              V=Gamma(np.random.rand(1,3),
                      np.random.rand(1,3),
                      plates=(5,1,3)))
        check(4,3,
              V=Gamma(np.random.rand(1,1,3),
                      np.random.rand(1,1,3),
                      plates=(5,1,3)))
        # Plates broadcasting
        check(4,3,
              plates=(5,),
              V=Gamma(np.random.rand(1,3),
                      np.random.rand(1,3),
                      plates=(1,3)))
        check(4,3,
              plates=(5,),
              V=Gamma(np.random.rand(1,1,3),
                      np.random.rand(1,1,3),
                      plates=(1,1,3)))

        #
        # Check with input signals
        #

        mu = 2
        Lambda = 3
        A = 4
        B = 5
        v = 6
        inputs = [[-2], [3]]
        X = GaussianMarkovChain([mu], [[Lambda]], [[A,B]], [v], inputs=inputs)
        V = (np.array([[v*A**2, -v*A,    0],
                       [-v*A,    v*A**2, -v*A],
                       [0,       -v*A,   0]]) +
             np.array([[Lambda, 0, 0],
                       [0,      v, 0],
                       [0,      0, v]]))
        m = (np.array([Lambda*mu, 0, 0]) +
             np.array([0, v*B*inputs[0][0], v*B*inputs[1][0]]) -
             np.array([v*A*B*inputs[0][0], v*A*B*inputs[1][0], 0]))
        Cov = np.linalg.inv(V)
        mean = np.dot(Cov, m)

        X.update()
        u = X.get_moments()

        self.assertAllClose(u[0], mean[:,None])
        self.assertAllClose(u[1] - u[0][...,None,:]*u[0][...,:,None],
                            Cov[(0,1,2),(0,1,2),None,None])
        self.assertAllClose(u[2] - u[0][...,:-1,:,None]*u[0][...,1:,None,:],
                            Cov[(0,1),(1,2),None,None])

        pass
    def test_moments(self):
        """
        Compare GaussianMarkovChain moments and bound with a dense reference.

        The chain is updated given arbitrary observations. Its first, second
        and lagged second moments are then compared with those of the joint
        Gaussian posterior over all N*D state variables, computed densely.
        Its lower bound contribution is compared with that posterior's
        entropy plus the expected log prior.
        """

        # TODO: Add plates and missing values!

        D = 3  # dimensionality of a single state vector
        N = 5  # number of time steps in the chain
        (Y, X, Mu, Lambda, A, V) = self.create_model(N, D)

        # Observe arbitrary data and let the chain node update its posterior
        obs = np.random.randn(N, D)
        Y.observe(obs)
        X.update()
        (x_vb, xnxn_vb, xpxn_vb) = X.get_moments()

        # Moments of the parent nodes
        (mu0, mumu0) = Mu.get_moments()
        (icov0, logdet0) = Lambda.get_moments()
        (a, aa) = A.get_moments()
        (innov_prec, innov_logdet) = V.get_moments()
        # V gives per-dimension precisions; build the diagonal matrix
        innov_icov = np.diag(innov_prec)

        # Dense prior precision of the whole chain, built from diagonal and
        # super-diagonal DxD blocks (utils.block_banded is assumed to mirror
        # the super-diagonal below the diagonal - confirm in utils)
        ata_v = np.einsum('...kij,...kk->...ij', aa, innov_icov)
        diag_blocks = ([icov0 + ata_v]
                       + [innov_icov + ata_v] * (N - 2)
                       + [innov_icov])
        super_blocks = [-np.dot(a.T, innov_icov)] * (N - 1)
        U = utils.block_banded(diag_blocks, super_blocks)

        # Linear (precision-weighted mean) term of the prior: only the first
        # state has a non-zero prior mean
        prior_lin = np.zeros(D * N)
        prior_lin[:D] = np.dot(icov0, mu0)

        # Posterior of all states. The reference assumes unit-precision
        # observations of X (identity added to the precision, raw data added
        # to the linear term); this presumably matches create_model.
        post_cov = np.linalg.inv(U + np.identity(D * N))
        post_mean = np.dot(post_cov, prior_lin + obs.flatten())
        post_xx = np.outer(post_mean, post_mean) + post_cov

        # Per-time-step views of the dense moments
        mean_blocks = post_mean.reshape(N, D)
        xx_blocks = post_xx.reshape(N, D, N, D)

        # Compare the node's moments with the dense reference
        testing.assert_allclose(x_vb, mean_blocks,
                                err_msg="Incorrect mean")
        for t in range(N):
            testing.assert_allclose(xnxn_vb[t,:,:], xx_blocks[t,:,t,:],
                                    err_msg="Incorrect second moment")
        for t in range(N - 1):
            testing.assert_allclose(xpxn_vb[t,:,:], xx_blocks[t,:,t+1,:],
                                    err_msg="Incorrect lagged second moment")

        # Entropy H(X); gaussian_entropy takes the log-det of the precision
        H = random.gaussian_entropy(-linalg.logdet_cov(post_cov), N * D)

        # Expected log prior <log p(X|...)>; log-det of the prior precision
        # is the initial term plus N-1 copies of the innovation term
        prior_logdet = logdet0 + np.sum(np.ones((N - 1, D)) * innov_logdet)
        xx_dot_U = np.einsum('...ij,...ij', post_xx, U)
        mean_dot_lin = np.einsum('...i,...i', post_mean, prior_lin)
        mumu_dot_icov = np.einsum('...ij,...ij', mumu0, icov0)
        P = random.gaussian_logpdf(xx_dot_U,
                                   mean_dot_lin,
                                   mumu_dot_icov,
                                   prior_logdet,
                                   N * D)

        # The lower bound term reported by the node itself
        bound = X.lower_bound_contribution()

        testing.assert_allclose(bound, H + P)