Ejemplo n.º 1
0
    def setup(self):
        """
        Cache the moment statistics used by the rotation optimizer.

        This method should be called just before optimization, because it
        snapshots the current moments of ``self.X_node`` and the first moments
        of its first two parents, precomputes the statistics for rotating A
        and finally calls ``self.A_rotator.setup(plate_axis=-1)``.
        """
        x_node = self.X_node

        # Moments of X. The third statistic is presumably the lagged cross
        # moment <x_{n-1} x_n^T> (judging by its name) -- TODO confirm.
        (X, XnXn, XpXn) = x_node.get_moments()

        # TODO/FIXME: Sum to plates of A/CovA
        # Second moments of every step except the last one along the N-long
        # axis (presumably time)
        XpXp = XnXn[..., :-1, :, :]

        #
        # Expectations with respect to X
        #

        # Statistics of the first step
        self.X0 = X[..., 0, :]
        self.X0X0 = XnXn[..., 0, :, :]

        # Second moments of the remaining N-1 steps, summed over that axis
        # and over the plates of X
        remaining = XnXn[..., 1:, :, :]
        self.XnXn = sum_to_plates(remaining,
                                  (),
                                  plates_from=x_node.plates + (self.N - 1,),
                                  ndim=2)

        # First moments of the fixed parameter nodes (parents 0 and 1)
        mu = x_node.parents[0].get_moments()[0]
        self.Lambda = x_node.parents[1].get_moments()[0]

        # outer(Lambda * mu, X0), summed over the plates of X
        Lambda_mu = np.einsum("...ik,...k->...i", self.Lambda, mu)
        self.Lambda_mu_X0 = sum_to_plates(linalg.outer(Lambda_mu, self.X0),
                                          (),
                                          plates_from=x_node.plates,
                                          ndim=2)

        #
        # Prepare the rotation for A
        #
        A_stats = self._computations_for_A_and_X(XpXn, XpXp)
        (self.A_XpXn, self.A_XpXp_A, self.CovA_XpXp) = A_stats

        self.A_rotator.setup(plate_axis=-1)
Ejemplo n.º 2
0
    def test_rotate_plates(self):
        """
        Check how rotating the plates of a GaussianARD node changes its
        first and second moments.
        """

        def rotate_and_check(node):
            # Moments and covariance before the rotation
            (mean, second) = node.get_moments()
            cov = second - linalg.outer(mean, mean, ndim=1)
            Q = np.random.randn(3, 3)
            # Expected moments after rotating along the last plate axis
            expected_mean = np.einsum('ik,kj->ij', Q, mean)
            scale = np.sum(Q, axis=0) ** 2
            expected_cov = np.einsum('k,kij->kij', scale, cov)
            expected_second = expected_cov + linalg.outer(expected_mean,
                                                          expected_mean,
                                                          ndim=1)
            node.rotate_plates(Q, plate_axis=-1)
            (mean, second) = node.get_moments()
            self.assertAllClose(mean, expected_mean)
            self.assertAllClose(second, expected_second)

        # Basic test for Gaussian vectors
        X = GaussianARD(np.random.randn(3, 2),
                        np.random.rand(3, 2),
                        shape=(2,),
                        plates=(3,))
        rotate_and_check(X)

        # Test full covariance, that is, with observations
        X = GaussianARD(np.random.randn(3, 2),
                        np.random.rand(3, 2),
                        shape=(2,),
                        plates=(3,))
        Y = Gaussian(X, [[2.0, 1.5], [1.5, 3.0]], plates=(3,))
        Y.observe(np.random.randn(3, 2))
        X.update()
        rotate_and_check(X)
Ejemplo n.º 3
0
    def test_initialization(self):
        """
        Exercise the prior, parameter, value and random initialization
        methods of GaussianARD.
        """

        def expected_xx(mean, precision):
            # outer(mean, mean) plus a diagonal holding 1/precision
            return (linalg.outer(mean, mean, ndim=1)
                    + misc.diag(1 / precision, ndim=1))

        X = GaussianARD(1, 2, shape=(2, ), plates=(3, ))

        # Prior initialization: mean 1 and precision 2 broadcast everywhere
        prior_mean = np.ones((3, 2))
        prior_precision = 2 * np.ones((3, 2))
        X.initialize_from_prior()
        u = X._message_to_child()
        self.assertAllClose(u[0] * np.ones((3, 2)), prior_mean)
        self.assertAllClose(u[1] * np.ones((3, 2, 2)),
                            expected_xx(prior_mean, prior_precision))

        # Parameter initialization
        mean = np.random.randn(3, 2)
        precision = np.random.rand(3, 2)
        X.initialize_from_parameters(mean, precision)
        u = X._message_to_child()
        self.assertAllClose(u[0], mean)
        self.assertAllClose(u[1], expected_xx(mean, precision))

        # Value initialization: the second moment is just outer(x, x)
        value = np.random.randn(3, 2)
        X.initialize_from_value(value)
        u = X._message_to_child()
        self.assertAllClose(u[0], value)
        self.assertAllClose(u[1], linalg.outer(value, value, ndim=1))

        # Random initialization only has to run without errors
        X.initialize_from_random()
Ejemplo n.º 4
0
    def test_initialization(self):
        """
        Verify every initialization method of GaussianARD against the
        moments it should produce.
        """
        node = GaussianARD(1, 2, shape=(2,), plates=(3,))

        def xx_from(m, a):
            # Expected second moment: outer(m, m) + diag(1/a)
            return linalg.outer(m, m, ndim=1) + misc.diag(1/a, ndim=1)

        # From the prior (mean 1, precision 2)
        m0 = np.ones((3, 2))
        a0 = 2 * np.ones((3, 2))
        node.initialize_from_prior()
        msg = node._message_to_child()
        self.assertAllClose(msg[0]*np.ones((3, 2)), m0)
        self.assertAllClose(msg[1]*np.ones((3, 2, 2)), xx_from(m0, a0))

        # From explicit parameters
        m1 = np.random.randn(3, 2)
        a1 = np.random.rand(3, 2)
        node.initialize_from_parameters(m1, a1)
        msg = node._message_to_child()
        self.assertAllClose(msg[0], m1)
        self.assertAllClose(msg[1], xx_from(m1, a1))

        # From a fixed value
        v = np.random.randn(3, 2)
        node.initialize_from_value(v)
        msg = node._message_to_child()
        self.assertAllClose(msg[0], v)
        self.assertAllClose(msg[1], linalg.outer(v, v, ndim=1))

        # Random initialization is only smoke-tested
        node.initialize_from_random()
Ejemplo n.º 5
0
    def _compute_moments(self, *u_parents):
        """
        Compute the moments of the sum of the parent variables.

        The first moment is the sum of the parents' first moments. The second
        moment is the sum of the parents' second moments plus, for every pair
        i < j among the first ``self.N`` parents, ``outer(<x_i>, <x_j>)`` and
        its transpose. Only the means enter the cross terms, so the parents
        are treated as uncorrelated.
        """
        means = [u[0] for u in u_parents]
        total_mean = functools.reduce(np.add, means)
        total_second = functools.reduce(np.add, [u[1] for u in u_parents])

        for i in range(self.N):
            for j in range(i + 1, self.N):
                cross = linalg.outer(means[i], means[j], ndim=self.ndim)
                total_second = (total_second
                                + cross
                                + linalg.transpose(cross, ndim=self.ndim))

        return [total_mean, total_second]
Ejemplo n.º 6
0
    def _compute_moments(self, *u_parents):
        """
        Compute the moments of the sum of the parent variables.

        Returns ``[first, second]``. ``first`` is the sum of the parents'
        first moments. ``second`` is the sum of their second moments plus the
        symmetric cross terms between the means of each pair of distinct
        parents (the first ``self.N`` of them).
        """
        first = functools.reduce(np.add, (u[0] for u in u_parents))
        second = functools.reduce(np.add, (u[1] for u in u_parents))

        n_terms = self.N
        for i in range(n_terms):
            x_i = u_parents[i][0]
            for j in range(i + 1, n_terms):
                cross = linalg.outer(x_i, u_parents[j][0], ndim=self.ndim)
                second = second + cross + linalg.transpose(cross, ndim=self.ndim)

        return [first, second]
Ejemplo n.º 7
0
    def _compute_moments(self, *u_nodes):
        """
        Compute the moments of the vector obtained by concatenating the
        parent vectors along their last axis.

        The second moment is block-diagonal in the parents' own second
        moments. Each off-diagonal block is filled with the outer product of
        the corresponding means, because covariances are zero between
        factorized nodes in the VB approximation.
        """
        x = misc.concatenate(*[u[0] for u in u_nodes], axis=-1)
        xx = misc.block_diag(*[u[1] for u in u_nodes])

        # Explicitly broadcast xx to plates of x
        plate_ones = np.ones(np.shape(x)[:-1])
        xx = plate_ones[..., None, None] * xx

        # bounds[k]:bounds[k+1] is the slice of the k-th parent inside x
        bounds = [0]
        for u in u_nodes:
            bounds.append(bounds[-1] + np.shape(u[0])[-1])

        for m in range(len(u_nodes)):
            rows = slice(bounds[m], bounds[m + 1])
            for n in range(m):
                cols = slice(bounds[n], bounds[n + 1])
                block = linalg.outer(u_nodes[m][0], u_nodes[n][0], ndim=1)
                xx[..., rows, cols] = block
                xx[..., cols, rows] = misc.T(block)

        return [x, xx]
Ejemplo n.º 8
0
    def _compute_moments(self, *u_nodes):
        """
        Compute the moments of the concatenation of the parent vectors.

        The first moment concatenates the parents' means along the last axis.
        The second moment starts from the block diagonal of the parents'
        second moments, broadcast to the plates of the mean. Its off-diagonal
        blocks are then set to the outer products of the means, since the
        factorized VB approximation has zero cross-covariances.
        """
        means = [u[0] for u in u_nodes]
        x = misc.concatenate(*means, axis=-1)
        xx = misc.block_diag(*[u[1] for u in u_nodes])

        # Explicitly broadcast xx to plates of x
        xx = np.ones(np.shape(x)[:-1])[..., None, None] * xx

        sizes = [np.shape(mean)[-1] for mean in means]
        starts = []
        offset = 0
        for size in sizes:
            starts.append(offset)
            offset += size

        for m, (start_m, size_m) in enumerate(zip(starts, sizes)):
            stop_m = start_m + size_m
            for n in range(m):
                start_n = starts[n]
                stop_n = start_n + sizes[n]
                mean_cross = linalg.outer(means[m], means[n], ndim=1)
                xx[..., start_m:stop_m, start_n:stop_n] = mean_cross
                xx[..., start_n:stop_n, start_m:stop_m] = misc.T(mean_cross)

        return [x, xx]
Ejemplo n.º 9
0
    def test_message_to_child(self):
        """
        Check the moments that SumMultiply sends to its children for a range
        of einsum-like operations.
        """

        def compare_moments(u0, u1, *args):
            # Build SumMultiply from args and compare its first two moments
            Y = SumMultiply(*args)
            u_Y = Y.get_moments()
            self.assertAllClose(u_Y[0], u0)
            self.assertAllClose(u_Y[1], u1)

        def ard(dims, **kwargs):
            # GaussianARD with a randn mean and a rand precision of shape dims
            return GaussianARD(np.random.randn(*dims),
                               np.random.rand(*dims),
                               **kwargs)

        # Test constant parent
        y = np.random.randn(2, 3, 4)
        compare_moments(y, linalg.outer(y, y, ndim=2), 'ij->ij', y)

        # Do nothing for 2-D array
        Y = ard((5, 2, 3), plates=(5,), shape=(2, 3))
        y = Y.get_moments()
        compare_moments(y[0], y[1], 'ij->ij', Y)
        compare_moments(y[0], y[1], Y, [0, 1], [0, 1])

        # Sum over the rows of a matrix
        Y = ard((5, 2, 3), plates=(5,), shape=(2, 3))
        y = Y.get_moments()
        mu = np.einsum('...ij->...j', y[0])
        cov = np.einsum('...ijkl->...jl', y[1])
        compare_moments(mu, cov, 'ij->j', Y)
        compare_moments(mu, cov, Y, [0, 1], [1])

        # Inner product of three vectors
        X1 = ard((2,), plates=(), shape=(2,))
        x1 = X1.get_moments()
        X2 = ard((6, 1, 2), plates=(6, 1), shape=(2,))
        x2 = X2.get_moments()
        X3 = ard((7, 6, 5, 2), plates=(7, 6, 5), shape=(2,))
        x3 = X3.get_moments()
        mu = np.einsum('...i,...i,...i->...', x1[0], x2[0], x3[0])
        cov = np.einsum('...ij,...ij,...ij->...', x1[1], x2[1], x3[1])
        compare_moments(mu, cov, 'i,i,i', X1, X2, X3)
        compare_moments(mu, cov, 'i,i,i->', X1, X2, X3)
        compare_moments(mu, cov, X1, [9], X2, [9], X3, [9])
        compare_moments(mu, cov, X1, [9], X2, [9], X3, [9], [])

        # Outer product of two vectors
        X1 = ard((2,), plates=(5,), shape=(2,))
        x1 = X1.get_moments()
        X2 = ard((6, 1, 2), plates=(6, 1), shape=(2,))
        x2 = X2.get_moments()
        mu = np.einsum('...i,...j->...ij', x1[0], x2[0])
        cov = np.einsum('...ik,...jl->...ijkl', x1[1], x2[1])
        compare_moments(mu, cov, 'i,j->ij', X1, X2)
        compare_moments(mu, cov, X1, [9], X2, [7], [9, 7])

        # Matrix product
        Y1 = ard((3, 2), plates=(), shape=(3, 2))
        y1 = Y1.get_moments()
        Y2 = ard((5, 2, 3), plates=(5,), shape=(2, 3))
        y2 = Y2.get_moments()
        mu = np.einsum('...ik,...kj->...ij', y1[0], y2[0])
        cov = np.einsum('...ikjl,...kmln->...imjn', y1[1], y2[1])
        compare_moments(mu, cov, 'ik,kj->ij', Y1, Y2)
        compare_moments(mu, cov, Y1, ['i', 'k'], Y2, ['k', 'j'], ['i', 'j'])

        # Trace of a matrix product
        Y1 = ard((3, 2), plates=(), shape=(3, 2))
        y1 = Y1.get_moments()
        Y2 = ard((5, 2, 3), plates=(5,), shape=(2, 3))
        y2 = Y2.get_moments()
        mu = np.einsum('...ij,...ji->...', y1[0], y2[0])
        cov = np.einsum('...ikjl,...kilj->...', y1[1], y2[1])
        compare_moments(mu, cov, 'ij,ji', Y1, Y2)
        compare_moments(mu, cov, 'ij,ji->', Y1, Y2)
        compare_moments(mu, cov, Y1, ['i', 'j'], Y2, ['j', 'i'])
        compare_moments(mu, cov, Y1, ['i', 'j'], Y2, ['j', 'i'], [])

        # Vector-matrix-vector product
        X1 = ard((3,), plates=(), shape=(3,))
        x1 = X1.get_moments()
        X2 = ard((6, 1, 2), plates=(6, 1), shape=(2,))
        x2 = X2.get_moments()
        Y = ard((3, 2), plates=(), shape=(3, 2))
        y = Y.get_moments()
        mu = np.einsum('...i,...ij,...j->...', x1[0], y[0], x2[0])
        cov = np.einsum('...ia,...ijab,...jb->...', x1[1], y[1], x2[1])
        compare_moments(mu, cov, 'i,ij,j', X1, Y, X2)
        compare_moments(mu, cov, X1, [1], Y, [1, 2], X2, [2])

        # Complex sum-product of 0-D, 1-D, 2-D and 3-D arrays
        V = ard((7, 6, 5), plates=(7, 6, 5), shape=())
        v = V.get_moments()
        X = ard((6, 1, 2), plates=(6, 1), shape=(2,))
        x = X.get_moments()
        Y = ard((3, 4), plates=(5,), shape=(3, 4))
        y = Y.get_moments()
        Z = ard((4, 2, 3), plates=(6, 5), shape=(4, 2, 3))
        z = Z.get_moments()
        mu = np.einsum('...,...i,...kj,...jik->...k', v[0], x[0], y[0], z[0])
        cov = np.einsum('...,...ia,...kjcb,...jikbac->...kc',
                        v[1], x[1], y[1], z[1])
        compare_moments(mu, cov, ',i,kj,jik->k', V, X, Y, Z)
        compare_moments(mu, cov,
                        V, [],
                        X, ['i'],
                        Y, ['k', 'j'],
                        Z, ['j', 'i', 'k'],
                        ['k'])

        #
        # Gaussian-gamma parents
        #

        # Outer product of vectors
        X1 = ard((2,), shape=(2,))
        x1 = X1.get_moments()
        X2 = GaussianGamma(np.random.randn(6, 1, 2),
                           random.covariance(2),
                           np.random.rand(6, 1),
                           np.random.rand(6, 1),
                           plates=(6, 1))
        x2 = X2.get_moments()
        Y = SumMultiply('i,j->ij', X1, X2)
        u = Y._message_to_child()
        y = np.einsum('...i,...j->...ij', x1[0], x2[0])
        yy = np.einsum('...ik,...jl->...ijkl', x1[1], x2[1])
        self.assertAllClose(u[0], y)
        self.assertAllClose(u[1], yy)
        # The gamma-related moments are passed through from X2
        for k in (2, 3):
            self.assertAllClose(u[k], x2[k])
Ejemplo n.º 10
0
    def test_message_to_child(self):
        """
        Check the four moments that a GaussianGammaISO node sends to its
        children, for constant parents, stochastic parents and plates.
        """

        def check(u, mu, Lambda, tau, logtau):
            # Expected message: tau*mu, inv(Lambda) + tau*outer(mu, mu),
            # tau and log(tau)
            self.assertAllClose(u[0], tau[..., None] * mu)
            self.assertAllClose(
                u[1],
                (linalg.inv(Lambda) + tau[..., None, None] * linalg.outer(mu, mu)))
            self.assertAllClose(u[2], tau)
            self.assertAllClose(u[3], logtau)

        # Simple test
        mu = np.array([1, 2, 3])
        Lambda = np.identity(3)
        a = 2
        b = 10
        X_alpha = GaussianGammaISO(mu, Lambda, a, b)
        u = X_alpha._message_to_child()
        self.assertEqual(len(u), 4)
        tau = np.array(a / b)
        check(u, mu, Lambda, tau, -np.log(b) + special.psi(a))

        # Test with unknown parents
        mu_node = Gaussian(np.arange(3), 10 * np.identity(3))
        Lambda_node = Wishart(10, np.identity(3))
        a = 2
        b_node = Gamma(3, 15)
        X_alpha = GaussianGammaISO(mu_node, Lambda_node, a, b_node)
        u = X_alpha._message_to_child()
        (mu, mumu) = mu_node._message_to_child()
        Cov_mu = mumu - linalg.outer(mu, mu)
        (Lambda, _) = Lambda_node._message_to_child()
        (b, _) = b_node._message_to_child()
        (tau, logtau) = Gamma(
            a, b + 0.5 * np.sum(Lambda * Cov_mu))._message_to_child()
        check(u, mu, Lambda, tau, logtau)

        # Test with plates
        mu_node = Gaussian(np.reshape(np.arange(3 * 4), (4, 3)),
                           10 * np.identity(3),
                           plates=(4, ))
        Lambda_node = Wishart(10, np.identity(3))
        a = 2
        b_node = Gamma(3, 15)
        X_alpha = GaussianGammaISO(mu_node, Lambda_node, a, b_node,
                                   plates=(4, ))
        u = X_alpha._message_to_child()
        (mu, mumu) = mu_node._message_to_child()
        Cov_mu = mumu - linalg.outer(mu, mu)
        (Lambda, _) = Lambda_node._message_to_child()
        (b, _) = b_node._message_to_child()
        (tau, logtau) = Gamma(
            a, b +
            0.5 * np.sum(Lambda * Cov_mu, axis=(-1, -2)))._message_to_child()
        # Broadcast both sides over the 4 plates before comparing
        ones = np.ones(4)
        self.assertAllClose(u[0] * ones[:, None],
                            ones[:, None] * tau[..., None] * mu)
        self.assertAllClose(
            u[1] * ones[:, None, None],
            ones[:, None, None] *
            (linalg.inv(Lambda) + tau[..., None, None] * linalg.outer(mu, mu)))
        self.assertAllClose(u[2] * ones, ones * tau)
        self.assertAllClose(u[3] * ones, ones * logtau)
Ejemplo n.º 11
0
    def test_message_to_child(self):
        """
        Check the four moments that a GaussianGamma node sends to its
        children, for constant parents, stochastic parents and plates.
        """

        def assert_message(u, mu, Lambda, tau, logtau):
            # Expected: tau*mu, inv(Lambda) + tau*outer(mu, mu), tau, log(tau)
            self.assertAllClose(u[0], tau[..., None] * mu)
            self.assertAllClose(u[1],
                                (linalg.inv(Lambda)
                                 + tau[..., None, None] * linalg.outer(mu, mu)))
            self.assertAllClose(u[2], tau)
            self.assertAllClose(u[3], logtau)

        # Simple test with constant parents
        mu = np.array([1, 2, 3])
        Lambda = np.identity(3)
        a = 2
        b = 10
        node = GaussianGamma(mu, Lambda, a, b)
        u = node._message_to_child()
        self.assertEqual(len(u), 4)
        tau = np.array(a/b)
        assert_message(u, mu, Lambda, tau, -np.log(b) + special.psi(a))

        # Test with unknown parents
        parent_mu = Gaussian(np.arange(3), 10*np.identity(3))
        parent_Lambda = Wishart(10, np.identity(3))
        a = 2
        parent_b = Gamma(3, 15)
        node = GaussianGamma(parent_mu, parent_Lambda, a, parent_b)
        u = node._message_to_child()
        (mu, mumu) = parent_mu._message_to_child()
        Cov_mu = mumu - linalg.outer(mu, mu)
        (Lambda, _) = parent_Lambda._message_to_child()
        (b, _) = parent_b._message_to_child()
        (tau, logtau) = Gamma(a, b + 0.5*np.sum(Lambda*Cov_mu))._message_to_child()
        assert_message(u, mu, Lambda, tau, logtau)

        # Test with plates
        parent_mu = Gaussian(np.reshape(np.arange(3*4), (4, 3)),
                             10*np.identity(3),
                             plates=(4,))
        parent_Lambda = Wishart(10, np.identity(3))
        a = 2
        parent_b = Gamma(3, 15)
        node = GaussianGamma(parent_mu, parent_Lambda, a, parent_b, plates=(4,))
        u = node._message_to_child()
        (mu, mumu) = parent_mu._message_to_child()
        Cov_mu = mumu - linalg.outer(mu, mu)
        (Lambda, _) = parent_Lambda._message_to_child()
        (b, _) = parent_b._message_to_child()
        (tau, logtau) = Gamma(a,
                              b + 0.5*np.sum(Lambda*Cov_mu,
                                             axis=(-1, -2)))._message_to_child()
        # Multiply by ones so both sides are compared over all 4 plates
        ones = np.ones(4)
        self.assertAllClose(u[0] * ones[:, None],
                            ones[:, None] * tau[..., None] * mu)
        self.assertAllClose(u[1] * ones[:, None, None],
                            ones[:, None, None] * (linalg.inv(Lambda)
                                                   + tau[..., None, None] * linalg.outer(mu, mu)))
        self.assertAllClose(u[2] * ones, ones * tau)
        self.assertAllClose(u[3] * ones, ones * logtau)
Ejemplo n.º 12
0
 def covariance(self):
     """
     Return ``xx - outer(x, x)`` built from the first two entries of
     ``self.get()``. Presumably those are the mean and the second moment,
     which makes this the covariance -- confirm against ``get``.
     """
     mean, second = self.get()[:2]
     return second - linalg.outer(mean, mean, ndim=1)
Ejemplo n.º 13
0
    def test_message_to_child(self):
        """
        Check the moments that SumMultiply sends to its children, including
        Gaussian-gamma parents and constant (array) operands.
        """

        def compare_moments(u0, u1, *args):
            # Build SumMultiply from args and compare its first two moments
            Y = SumMultiply(*args)
            u_Y = Y.get_moments()
            self.assertAllClose(u_Y[0], u0)
            self.assertAllClose(u_Y[1], u1)

        def ard(dims, **kwargs):
            # GaussianARD with a randn mean and a rand precision of shape dims
            return GaussianARD(np.random.randn(*dims),
                               np.random.rand(*dims),
                               **kwargs)

        # Test constant parent
        y = np.random.randn(2, 3, 4)
        compare_moments(y, linalg.outer(y, y, ndim=2), 'ij->ij', y)

        # Do nothing for 2-D array
        Y = ard((5, 2, 3), plates=(5, ), shape=(2, 3))
        y = Y.get_moments()
        compare_moments(y[0], y[1], 'ij->ij', Y)
        compare_moments(y[0], y[1], Y, [0, 1], [0, 1])

        # Sum over the rows of a matrix
        Y = ard((5, 2, 3), plates=(5, ), shape=(2, 3))
        y = Y.get_moments()
        mu = np.einsum('...ij->...j', y[0])
        cov = np.einsum('...ijkl->...jl', y[1])
        compare_moments(mu, cov, 'ij->j', Y)
        compare_moments(mu, cov, Y, [0, 1], [1])

        # Inner product of three vectors
        X1 = ard((2, ), plates=(), shape=(2, ))
        x1 = X1.get_moments()
        X2 = ard((6, 1, 2), plates=(6, 1), shape=(2, ))
        x2 = X2.get_moments()
        X3 = ard((7, 6, 5, 2), plates=(7, 6, 5), shape=(2, ))
        x3 = X3.get_moments()
        mu = np.einsum('...i,...i,...i->...', x1[0], x2[0], x3[0])
        cov = np.einsum('...ij,...ij,...ij->...', x1[1], x2[1], x3[1])
        compare_moments(mu, cov, 'i,i,i', X1, X2, X3)
        compare_moments(mu, cov, 'i,i,i->', X1, X2, X3)
        compare_moments(mu, cov, X1, [9], X2, [9], X3, [9])
        compare_moments(mu, cov, X1, [9], X2, [9], X3, [9], [])

        # Outer product of two vectors
        X1 = ard((2, ), plates=(5, ), shape=(2, ))
        x1 = X1.get_moments()
        X2 = ard((6, 1, 2), plates=(6, 1), shape=(2, ))
        x2 = X2.get_moments()
        mu = np.einsum('...i,...j->...ij', x1[0], x2[0])
        cov = np.einsum('...ik,...jl->...ijkl', x1[1], x2[1])
        compare_moments(mu, cov, 'i,j->ij', X1, X2)
        compare_moments(mu, cov, X1, [9], X2, [7], [9, 7])

        # Matrix product
        Y1 = ard((3, 2), plates=(), shape=(3, 2))
        y1 = Y1.get_moments()
        Y2 = ard((5, 2, 3), plates=(5, ), shape=(2, 3))
        y2 = Y2.get_moments()
        mu = np.einsum('...ik,...kj->...ij', y1[0], y2[0])
        cov = np.einsum('...ikjl,...kmln->...imjn', y1[1], y2[1])
        compare_moments(mu, cov, 'ik,kj->ij', Y1, Y2)
        compare_moments(mu, cov, Y1, ['i', 'k'], Y2, ['k', 'j'], ['i', 'j'])

        # Trace of a matrix product
        Y1 = ard((3, 2), plates=(), shape=(3, 2))
        y1 = Y1.get_moments()
        Y2 = ard((5, 2, 3), plates=(5, ), shape=(2, 3))
        y2 = Y2.get_moments()
        mu = np.einsum('...ij,...ji->...', y1[0], y2[0])
        cov = np.einsum('...ikjl,...kilj->...', y1[1], y2[1])
        compare_moments(mu, cov, 'ij,ji', Y1, Y2)
        compare_moments(mu, cov, 'ij,ji->', Y1, Y2)
        compare_moments(mu, cov, Y1, ['i', 'j'], Y2, ['j', 'i'])
        compare_moments(mu, cov, Y1, ['i', 'j'], Y2, ['j', 'i'], [])

        # Vector-matrix-vector product
        X1 = ard((3, ), plates=(), shape=(3, ))
        x1 = X1.get_moments()
        X2 = ard((6, 1, 2), plates=(6, 1), shape=(2, ))
        x2 = X2.get_moments()
        Y = ard((3, 2), plates=(), shape=(3, 2))
        y = Y.get_moments()
        mu = np.einsum('...i,...ij,...j->...', x1[0], y[0], x2[0])
        cov = np.einsum('...ia,...ijab,...jb->...', x1[1], y[1], x2[1])
        compare_moments(mu, cov, 'i,ij,j', X1, Y, X2)
        compare_moments(mu, cov, X1, [1], Y, [1, 2], X2, [2])

        # Complex sum-product of 0-D, 1-D, 2-D and 3-D arrays
        V = ard((7, 6, 5), plates=(7, 6, 5), shape=())
        v = V.get_moments()
        X = ard((6, 1, 2), plates=(6, 1), shape=(2, ))
        x = X.get_moments()
        Y = ard((3, 4), plates=(5, ), shape=(3, 4))
        y = Y.get_moments()
        Z = ard((4, 2, 3), plates=(6, 5), shape=(4, 2, 3))
        z = Z.get_moments()
        mu = np.einsum('...,...i,...kj,...jik->...k', v[0], x[0], y[0], z[0])
        cov = np.einsum('...,...ia,...kjcb,...jikbac->...kc',
                        v[1], x[1], y[1], z[1])
        compare_moments(mu, cov, ',i,kj,jik->k', V, X, Y, Z)
        compare_moments(mu, cov,
                        V, [],
                        X, ['i'],
                        Y, ['k', 'j'],
                        Z, ['j', 'i', 'k'],
                        ['k'])

        # Test with constant nodes
        N = 10
        D = 5
        a = np.random.randn(N, D)
        B = Gaussian(
            np.random.randn(D),
            random.covariance(D),
        )
        X = SumMultiply('i,i->', B, a)
        np.testing.assert_allclose(
            X.get_moments()[0],
            np.einsum('ni,i->n', a, B.get_moments()[0]),
        )
        np.testing.assert_allclose(
            X.get_moments()[1],
            np.einsum('ni,nj,ij->n', a, a, B.get_moments()[1]),
        )

        #
        # Gaussian-gamma parents
        #

        # Outer product of vectors
        X1 = ard((2, ), shape=(2, ))
        x1 = X1.get_moments()
        X2 = GaussianGamma(np.random.randn(6, 1, 2),
                           random.covariance(2),
                           np.random.rand(6, 1),
                           np.random.rand(6, 1),
                           plates=(6, 1))
        x2 = X2.get_moments()
        Y = SumMultiply('i,j->ij', X1, X2)
        u = Y._message_to_child()
        y = np.einsum('...i,...j->...ij', x1[0], x2[0])
        yy = np.einsum('...ik,...jl->...ijkl', x1[1], x2[1])
        self.assertAllClose(u[0], y)
        self.assertAllClose(u[1], yy)
        # The gamma-related moments are passed through from X2
        for k in (2, 3):
            self.assertAllClose(u[k], x2[k])

        # Test with constant nodes
        N = 10
        M = 8
        D = 5
        a = np.random.randn(N, 1, D)
        B = GaussianGamma(
            np.random.randn(M, D),
            random.covariance(D, size=(M, )),
            np.random.rand(M),
            np.random.rand(M),
            ndim=1,
        )
        X = SumMultiply('i,i->', B, a)
        np.testing.assert_allclose(
            X.get_moments()[0],
            np.einsum('nmi,mi->nm', a, B.get_moments()[0]),
        )
        np.testing.assert_allclose(
            X.get_moments()[1],
            np.einsum('nmi,nmj,mij->nm', a, a, B.get_moments()[1]),
        )
        # The gamma-related moments are passed through from B
        for k in (2, 3):
            np.testing.assert_allclose(
                X.get_moments()[k],
                B.get_moments()[k],
            )
Ejemplo n.º 14
0
    def _compute_bound(self, R, logdet=None, inv=None, Q=None, gradient=False, terms=False):
        """
        Compute the lower bound terms of q(X) and q(alpha) after rotation.

        X is rotated by R along the rotated dimension. If setup() was called
        with a plate_axis, the plates of X are also rotated by Q. Optionally
        the gradients with respect to R (and Q) are returned.

        Assume:
        p(X|alpha) = prod_m N(x_m|0,diag(alpha))
        p(alpha) = prod_d G(a_d,b_d)

        The mean mu stored by setup() is subtracted from X in the
        computations, i.e., the quadratic term is <(X-mu)*(X-mu)'>.

        Parameters
        ----------
        R : ndarray
            Rotation matrix for the rotated dimension.
        logdet : float, optional
            Precomputed log|det(R)|. If None, it is computed with
            np.linalg.slogdet together with inv(R).
        inv : ndarray, optional
            Precomputed inverse of R. Only used when logdet is given.
        Q : ndarray, optional
            Plate rotation matrix. Required when plates are rotated
            (self.plate_axis is not None).
        gradient : bool
            If True, also return the gradient(s) of the bound.
        terms : bool
            If True, return the bound as a dict {node: term} instead of a sum.
            Not supported together with gradient=True.

        Returns
        -------
        bound
            If gradient is False (a dict if terms is True).
        (bound, dR_bound)
            If gradient is True and plates are not rotated.
        (bound, dR_bound, dQ_bound)
            If gradient is True and plates are rotated.

        Raises
        ------
        ValueError
            If plates should be rotated but Q is None.
        NotImplementedError
            If both gradient and terms are True.
        """

        #
        # Transform the distributions and moments
        #

        plates_alpha = self.plates_alpha
        plates_X = self.plates_X

        # Compute rotated second moment
        if self.plate_axis is not None:
            # The plate axis has been moved to be the last plate axis

            if Q is None:
                raise ValueError("Plates should be rotated but no Q given")

            # Transform covariance: each plate's covariance is scaled by the
            # squared column sum of Q
            sumQ = np.sum(Q, axis=0)
            QCovQ = sumQ[:, None, None] ** 2 * self.CovX

            # Rotate plates
            if self.precompute:
                QX_QX = np.einsum("...kalb,...ik,...il->...iab", self.X_X, Q, Q)
                XX = QX_QX + QCovQ
                XX = sum_to_plates(XX, plates_alpha[:-1], ndim=2)
                Xmu = np.einsum("...kaib,...ik->...iab", self.X_mu, Q)
                Xmu = sum_to_plates(Xmu, plates_alpha[:-1], ndim=2)
            else:
                X = self.X
                mu = self.mu
                QX = np.einsum("...ik,...kj->...ij", Q, X)
                XX = sum_to_plates(QCovQ, plates_alpha[:-1], ndim=2) + sum_to_plates(
                    linalg.outer(QX, QX), plates_alpha[:-1], ndim=2, plates_from=plates_X
                )
                Xmu = sum_to_plates(linalg.outer(QX, self.mu), plates_alpha[:-1], ndim=2, plates_from=plates_X)

            mu2 = self.mu2
            D = np.shape(XX)[-1]
            logdet_Q = D * np.log(np.abs(sumQ))

        else:
            XX = self.XX
            mu2 = self.mu2
            Xmu = self.Xmu
            logdet_Q = 0

        # Compute transformed moments
        RXmu = np.einsum("...ik,...ki->...i", R, Xmu)
        RXX = np.einsum("...ik,...kj->...ij", R, XX)
        RXXR = np.einsum("...ik,...ik->...i", RXX, R)

        # <(X-mu) * (X-mu)'>_R (diagonal only)
        XmuXmu = RXXR - 2 * RXmu + mu2

        D = np.shape(R)[0]

        # Compute q(alpha)
        if self.update_alpha:
            # Parameters
            a0 = self.a0
            b0 = self.b0
            a = self.a
            b = b0 + 0.5 * sum_to_plates(XmuXmu, plates_alpha, plates_from=None, ndim=0)
            # Some expectations
            alpha = a / b
            logb = np.log(b)
            logalpha = -logb  # + const
            b0_alpha = b0 * alpha
            a0_logalpha = a0 * logalpha
        else:
            alpha = self.alpha
            logalpha = 0

        #
        # Compute the cost
        #

        def sum_plates(V, *plates):
            # Sum V and account for plates that V has been broadcast over
            full_plates = misc.broadcasted_shape(*plates)

            r = self.node_X.broadcasting_multiplier(full_plates, np.shape(V))
            return r * np.sum(V)

        XmuXmu_alpha = XmuXmu * alpha

        if logdet is None:
            logdet_R = np.linalg.slogdet(R)[1]
            inv_R = np.linalg.inv(R)
        else:
            logdet_R = logdet
            inv_R = inv

        # Compute entropy H(X)
        logH_X = random.gaussian_entropy(-2 * sum_plates(logdet_R + logdet_Q, plates_X), 0)

        # Compute <log p(X|alpha)>
        logp_X = random.gaussian_logpdf(
            sum_plates(XmuXmu_alpha, plates_alpha[:-1] + [D]), 0, 0, sum_plates(logalpha, plates_X + [D]), 0
        )

        if self.update_alpha:

            # Compute entropy H(alpha)
            # This cancels out with the log(alpha) term in log(p(alpha))
            logH_alpha = 0

            # Compute <log p(alpha)>
            logp_alpha = random.gamma_logpdf(
                sum_plates(b0_alpha, plates_alpha), 0, sum_plates(a0_logalpha, plates_alpha), 0, 0
            )
        else:
            logH_alpha = 0
            logp_alpha = 0

        # Compute the bound
        if terms:
            bound = {self.node_X: logp_X + logH_X}
            if self.update_alpha:
                bound.update({self.node_alpha: logp_alpha + logH_alpha})
        else:
            bound = 0 + logp_X + logp_alpha + logH_X + logH_alpha

        if not gradient:
            return bound

        #
        # Compute the gradient with respect R
        #

        broadcasting_multiplier = self.node_X.broadcasting_multiplier

        def sum_plates(V, plates):
            # Sum V over plates (keeping the trailing matrix axes) and account
            # for plates that V has been broadcast over
            ones = np.ones(np.shape(R))
            r = broadcasting_multiplier(plates, np.shape(V)[:-2])
            return r * misc.sum_multiply(V, ones, axis=(-1, -2), sumaxis=False, keepdims=False)

        D_XmuXmu = 2 * RXX - 2 * gaussian.transpose_covariance(Xmu)

        DXmuXmu_alpha = np.einsum("...i,...ij->...ij", alpha, D_XmuXmu)
        if self.update_alpha:
            D_b = 0.5 * D_XmuXmu
            XmuXmu_Dalpha = np.einsum(
                "...i,...i,...i,...ij->...ij",
                sum_to_plates(XmuXmu, plates_alpha, plates_from=None, ndim=0),
                alpha,
                -1 / b,
                D_b,
            )
            D_b0_alpha = np.einsum("...i,...i,...i,...ij->...ij", b0, alpha, -1 / b, D_b)
            D_logb = np.einsum("...i,...ij->...ij", 1 / b, D_b)
            D_logalpha = -D_logb
            D_a0_logalpha = a0 * D_logalpha
        else:
            XmuXmu_Dalpha = 0
            D_logalpha = 0

        D_XmuXmu_alpha = DXmuXmu_alpha + XmuXmu_Dalpha
        # d log|det(R)| / dR = inv(R)'
        D_logR = inv_R.T

        # Compute dH(X)
        dlogH_X = random.gaussian_entropy(-2 * sum_plates(D_logR, plates_X), 0)

        # Compute d<log p(X|alpha)>
        dlogp_X = random.gaussian_logpdf(
            sum_plates(D_XmuXmu_alpha, plates_alpha[:-1]),
            0,
            0,
            (sum_plates(D_logalpha, plates_X) * broadcasting_multiplier((D,), plates_alpha[-1:])),
            0,
        )

        if self.update_alpha:

            # Compute dH(alpha)
            # This cancels out with the log(alpha) term in log(p(alpha))
            dlogH_alpha = 0

            # Compute d<log p(alpha)>
            dlogp_alpha = random.gamma_logpdf(
                sum_plates(D_b0_alpha, plates_alpha[:-1]), 0, sum_plates(D_a0_logalpha, plates_alpha[:-1]), 0, 0
            )
        else:
            dlogH_alpha = 0
            dlogp_alpha = 0

        if terms:
            # Per-term gradients are not implemented
            raise NotImplementedError()

        dR_bound = 0 * dlogp_X + dlogp_X + dlogp_alpha + dlogH_X + dlogH_alpha

        # Use the same "subset given" test as setup() (a NumPy array subset
        # would make a plain truth test ambiguous).
        if self.subset is not None:
            # NOTE(review): setup() already reduced the moments to the subset,
            # so R (and thus dR_bound) is presumably subset-sized here;
            # indexing it again with the original subset indices looks
            # suspicious - confirm against the caller.
            indices = np.ix_(self.subset, self.subset)
            dR_bound = dR_bound[indices]

        if self.plate_axis is None:
            return (bound, dR_bound)

        #
        # Compute the gradient with respect to Q (if Q given)
        #

        # Some pre-computations
        Q_RCovR = np.einsum("...ik,...kl,...il,...->...i", R, self.CovX, R, sumQ)
        if self.precompute:
            Xr_rX = np.einsum("...abcd,...jb,...jd->...jac", self.X_X, R, R)
            QXr_rX = np.einsum("...akj,...ik->...aij", Xr_rX, Q)
            RX_mu = np.einsum("...jk,...akbj->...jab", R, self.X_mu)

        else:
            # X and mu were bound above in the non-precomputed plate branch
            RX = np.einsum("...ik,...k->...i", R, X)
            QXR = np.einsum("...ik,...kj->...ij", Q, RX)
            QXr_rX = np.einsum("...ik,...jk->...kij", QXR, RX)
            RX_mu = np.einsum("...ik,...jk->...kij", RX, mu)

            QXr_rX = sum_to_plates(QXr_rX, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])
            RX_mu = sum_to_plates(RX_mu, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])

        def psi(v):
            """
            Compute: d/dQ 1/2*trace(diag(v)*<(X-mu)*(X-mu)>)

            = Q*<X>'*R'*diag(v)*R*<X> + ones * Q diag( tr(R'*diag(v)*R*Cov) )
              - mu*diag(v)*R*<X>
            """

            # Precompute all terms to plates_alpha because v has shape
            # plates_alpha.

            # Gradient of 0.5*v*<x>*<x>
            v_QXrrX = np.einsum("...kij,...ik->...ij", QXr_rX, v)

            # Gradient of 0.5*v*Cov
            Q_tr_R_v_R_Cov = np.einsum("...k,...k->...", Q_RCovR, v)[..., None, :]

            # Gradient of mu*v*x
            mu_v_R_X = np.einsum("...ik,...kji->...ij", v, RX_mu)

            return v_QXrrX + Q_tr_R_v_R_Cov - mu_v_R_X

        def sum_plates(V, plates):
            # Same as above but with the trailing matrix axes shaped like Q
            ones = np.ones(np.shape(Q))
            r = self.node_X.broadcasting_multiplier(plates, np.shape(V)[:-2])

            return r * misc.sum_multiply(V, ones, axis=(-1, -2), sumaxis=False, keepdims=False)

        if self.update_alpha:
            D_logb = psi(1 / b)
            XX_Dalpha = -psi(alpha / b * sum_to_plates(XmuXmu, plates_alpha))
            D_logalpha = -D_logb
        else:
            XX_Dalpha = 0
            D_logalpha = 0
        DXX_alpha = 2 * psi(alpha)
        D_XX_alpha = DXX_alpha + XX_Dalpha
        D_logdetQ = D / sumQ
        N = np.shape(Q)[-1]

        # Compute dH(X)
        dQ_logHX = random.gaussian_entropy(-2 * sum_plates(D_logdetQ, plates_X[:-1]), 0)

        # Compute d<log p(X|alpha)>
        dQ_logpX = random.gaussian_logpdf(
            sum_plates(D_XX_alpha, plates_alpha[:-2]),
            0,
            0,
            (sum_plates(D_logalpha, plates_X[:-1]) * broadcasting_multiplier((N, D), plates_alpha[-2:])),
            0,
        )

        if self.update_alpha:

            D_alpha = -psi(alpha / b)
            D_b0_alpha = b0 * D_alpha
            D_a0_logalpha = a0 * D_logalpha

            # Compute dH(alpha)
            # This cancels out with the log(alpha) term in log(p(alpha))
            dQ_logHalpha = 0

            # Compute d<log p(alpha)>
            dQ_logpalpha = random.gamma_logpdf(
                sum_plates(D_b0_alpha, plates_alpha[:-2]), 0, sum_plates(D_a0_logalpha, plates_alpha[:-2]), 0, 0
            )
        else:

            dQ_logHalpha = 0
            dQ_logpalpha = 0

        # terms=True has already raised NotImplementedError above, so only
        # the summed form is possible here.
        dQ_bound = 0 * dQ_logpX + dQ_logpX + dQ_logpalpha + dQ_logHX + dQ_logHalpha
        return (bound, dR_bound, dQ_bound)
Ejemplo n.º 15
0
    def setup(self, plate_axis=None):
        """
        This method should be called just before optimization.

        Reads the current moments of X and its parent (and of alpha's parents
        when alpha is updated). It moves the rotated axis (self.axis) to be
        the last axis and stores the moments as attributes used by the bound
        computations. For efficiency, it sums over axes that are not in mu,
        alpha nor rotation.

        If using Q (plate rotation), pass the plate axis of X to be rotated as
        plate_axis. That plate axis is then moved to be the last plate axis.

        Parameters
        ----------
        plate_axis : int or None
            Plate axis of X rotated with Q, or None for no plate rotation.
            Non-negative values are converted to count from the last plate.

        Raises
        ------
        ValueError
            If plate_axis is not an integer or is out of bounds.
        NotImplementedError
            If a subset is used together with precomputation.
        """

        # Store the original plate_axis parameter for later use in other methods
        self.plate_axis = plate_axis

        # Manipulate the plate_axis parameter to suit the needs of this method
        if plate_axis is not None:
            if not isinstance(plate_axis, int):
                raise ValueError("Plate axis must be integer")
            if plate_axis >= 0:
                plate_axis -= len(self.node_X.plates)
            if plate_axis < -len(self.node_X.plates) or plate_axis >= 0:
                raise ValueError("Axis out of bounds")
            plate_axis -= self.ndim - 1  # Why -1? Because one axis is preserved!

        # Get the mean parameter. It will not be rotated. This assumes that mu
        # and alpha are really independent.
        (alpha_mu, alpha_mu2, alpha, _) = self.node_parent.get_moments()
        (X, XX) = self.node_X.get_moments()

        # The parent's first two moments are presumably alpha*mu and
        # alpha*mu^2, so dividing by alpha recovers mu and mu^2.
        mu = alpha_mu / alpha
        mu2 = alpha_mu2 / alpha
        # For simplicity, force mu to have the same shape as X
        mu = mu * np.ones(self.node_X.dims[0])
        mu2 = mu2 * np.ones(self.node_X.dims[0])

        # Take diagonal of covariances to variances for axes that are not in R
        # (and move those axes to be the last)
        XX = covariance_to_variance(XX, ndim=self.ndim, covariance_axis=self.axis)

        # Move axes of X and mu and compute their outer product
        X = misc.moveaxis(X, self.axis, -1)
        mu = misc.moveaxis(mu, self.axis, -1)
        mu2 = misc.moveaxis(mu2, self.axis, -1)
        Xmu = linalg.outer(X, mu, ndim=1)

        # Move axes of alpha related variables
        def safe_move_axis(x):
            # Arrays without the rotated axis (self.axis presumably counts
            # from the end) get a trailing singleton axis instead.
            if np.ndim(x) >= -self.axis:
                return misc.moveaxis(x, self.axis, -1)
            else:
                return x[..., np.newaxis]

        if self.update_alpha:
            a = safe_move_axis(self.node_alpha.phi[1])
            a0 = safe_move_axis(self.node_alpha.parents[0].get_moments()[0])
            b0 = safe_move_axis(self.node_alpha.parents[1].get_moments()[0])
            plates_alpha = list(self.node_alpha.plates)
        else:
            alpha = safe_move_axis(self.node_parent.get_moments()[2])
            plates_alpha = list(self.node_parent.get_shape(2))

        # Move plates of alpha for R
        if len(plates_alpha) >= -self.axis:
            plate = plates_alpha.pop(self.axis)
            plates_alpha.append(plate)
        else:
            plates_alpha.append(1)

        plates_X = list(self.node_X.get_shape(0))
        plates_X.pop(self.axis)

        def sum_to_alpha(V, ndim=2):
            # TODO/FIXME: This could be improved so that it is not required to
            # explicitly repeat to alpha plates. Multiplying by ones was just a
            # simple bug fix.
            #
            # NOTE: plates_alpha is looked up when this is called, so it sees
            # the rebinding done in the plate rotation branch below.
            return sum_to_plates(
                V * np.ones(plates_alpha[:-1] + ndim * [1]), plates_alpha[:-1], ndim=ndim, plates_from=plates_X
            )

        if plate_axis is not None:
            # Move plate axis just before the rotated dimensions (which are
            # last)
            def safe_move_plate_axis(x, ndim):
                # Arrays without the plate axis get a singleton axis there
                if np.ndim(x) - ndim >= -plate_axis:
                    return misc.moveaxis(x, plate_axis - ndim, -ndim - 1)
                else:
                    inds = (Ellipsis, None) + ndim * (slice(None),)
                    return x[inds]

            X = safe_move_plate_axis(X, 1)
            mu = safe_move_plate_axis(mu, 1)
            XX = safe_move_plate_axis(XX, 2)
            mu2 = safe_move_plate_axis(mu2, 1)
            if self.update_alpha:
                a = safe_move_plate_axis(a, 1)
                a0 = safe_move_plate_axis(a0, 1)
                b0 = safe_move_plate_axis(b0, 1)
            else:
                alpha = safe_move_plate_axis(alpha, 1)
            # Move plates of X and alpha
            plate = plates_X.pop(plate_axis)
            plates_X.append(plate)
            if len(plates_alpha) >= -plate_axis + 1:
                plate = plates_alpha.pop(plate_axis - 1)
            else:
                plate = 1
            plates_alpha = plates_alpha[:-1] + [plate] + plates_alpha[-1:]

            CovX = XX - linalg.outer(X, X)
            self.CovX = sum_to_plates(CovX, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])
            # Broadcast mu2 to ensure shape
            mu2 = mu2 * np.ones(np.shape(X)[-2:])
            self.mu2 = sum_to_alpha(mu2, ndim=1)

            if self.precompute:
                # Precompute some stuff for the gradient of plate rotation
                #
                # NOTE: These terms may require a lot of memory if alpha has the
                # same or almost the same plates as X.
                self.X_X = sum_to_plates(
                    X[..., :, :, None, None] * X[..., None, None, :, :],
                    plates_alpha[:-2],
                    ndim=4,
                    plates_from=plates_X[:-1],
                )
                self.X_mu = sum_to_plates(
                    X[..., :, :, None, None] * mu[..., None, None, :, :],
                    plates_alpha[:-2],
                    ndim=4,
                    plates_from=plates_X[:-1],
                )
            else:
                self.X = X
                self.mu = mu

        else:
            # Sum axes that are not in the plates of alpha
            self.XX = sum_to_alpha(XX)
            self.mu2 = sum_to_alpha(mu2, ndim=1)
            self.Xmu = sum_to_alpha(Xmu)

        if self.update_alpha:
            self.a = a
            self.a0 = a0
            self.b0 = b0
        else:
            self.alpha = alpha

        self.plates_X = plates_X
        self.plates_alpha = plates_alpha

        # Take only a subset of the matrix for rotation
        if self.subset is not None:
            if self.precompute:
                raise NotImplementedError("Precomputation not implemented when " "using a subset")

            def take_subset_matrix(V):
                # Keep all leading axes and take the subset along both of the
                # two trailing (rotated) axes
                inds = [range(n) for n in np.shape(V)[:-2]]
                inds.append(self.subset)
                inds.append(self.subset)
                return V[np.ix_(*inds)]

            self.mu2 = self.mu2[..., self.subset]
            if plate_axis is not None:
                # from X (self.X is only stored when plates are rotated)
                self.X = self.X[..., self.subset]
                # from CovX
                self.CovX = take_subset_matrix(self.CovX)
                # from mu
                self.mu = self.mu[..., self.subset]
            else:
                # from XX
                self.XX = take_subset_matrix(self.XX)
                # from Xmu: like XX, both trailing axes are rotated dimensions
                # (outer product of X and mu), so subset both of them
                self.Xmu = take_subset_matrix(self.Xmu)
            # from alpha
            if self.update_alpha:
                if np.shape(self.a)[-1] > 1:
                    self.a = self.a[..., self.subset]
                if np.shape(self.a0)[-1] > 1:
                    self.a0 = self.a0[..., self.subset]
                if np.shape(self.b0)[-1] > 1:
                    self.b0 = self.b0[..., self.subset]
            else:
                if np.shape(self.alpha)[-1] > 1:
                    self.alpha = self.alpha[..., self.subset]

            self.plates_alpha[-1] = min(self.plates_alpha[-1], len(self.subset))