def setup(self):
    """
    Cache the quantities needed by the optimizer.

    This method should be called just before optimization. It reads the
    current moments of ``self.X_node`` and of its first two parents, stores
    the derived expectations on ``self`` and finally sets up the rotator
    of A.
    """

    x_node = self.X_node

    # Moments of X
    (mean, second_moment, cross_moment) = x_node.get_moments()

    # TODO/FIXME: Sum to plates of A/CovA
    previous_second_moment = second_moment[..., :-1, :, :]

    #
    # Expectations with respect to X
    #
    self.X0 = mean[..., 0, :]
    self.X0X0 = second_moment[..., 0, :, :]
    # Equivalent to summing second_moment[..., 1:, :, :] over axis -3, but
    # done through sum_to_plates.
    remaining_plates = x_node.plates + (self.N - 1,)
    self.XnXn = sum_to_plates(second_moment[..., 1:, :, :],
                              (),
                              plates_from=remaining_plates,
                              ndim=2)

    # Moments of the fixed parameter nodes (first two parents of X)
    mu = x_node.parents[0].get_moments()[0]
    self.Lambda = x_node.parents[1].get_moments()[0]
    Lambda_mu = np.einsum("...ik,...k->...i", self.Lambda, mu)
    self.Lambda_mu_X0 = linalg.outer(Lambda_mu, self.X0)
    self.Lambda_mu_X0 = sum_to_plates(self.Lambda_mu_X0,
                                      (),
                                      plates_from=x_node.plates,
                                      ndim=2)

    #
    # Prepare the rotation for A
    #
    computed = self._computations_for_A_and_X(cross_moment,
                                              previous_second_moment)
    (self.A_XpXn, self.A_XpXp_A, self.CovA_XpXp) = computed

    self.A_rotator.setup(plate_axis=-1)
def test_rotate_plates(self):
    """
    Check that rotate_plates transforms the moments of GaussianARD.

    The same check is run on a freshly constructed node and on a node that
    has been updated from an observed Gaussian child.
    """

    def check_rotation(node):
        # Moments and covariance before the rotation
        (mean, second_moment) = node.get_moments()
        covariance = second_moment - linalg.outer(mean, mean, ndim=1)

        # Random 3x3 matrix applied along the plate axis
        rotation = np.random.randn(3,3)

        # Expected moments after the rotation
        expected_mean = np.einsum('ik,kj->ij', rotation, mean)
        expected_cov = np.einsum('k,kij->kij',
                                 np.sum(rotation, axis=0)**2,
                                 covariance)
        expected_second = expected_cov + linalg.outer(expected_mean,
                                                      expected_mean,
                                                      ndim=1)

        node.rotate_plates(rotation, plate_axis=-1)
        (mean, second_moment) = node.get_moments()
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(second_moment, expected_second)

    # Basic test for Gaussian vectors
    X = GaussianARD(np.random.randn(3,2),
                    np.random.rand(3,2),
                    shape=(2,),
                    plates=(3,))
    check_rotation(X)

    # Test full covariance, that is, with observations
    X = GaussianARD(np.random.randn(3,2),
                    np.random.rand(3,2),
                    shape=(2,),
                    plates=(3,))
    Y = Gaussian(X, [[2.0, 1.5], [1.5, 3.0]], plates=(3,))
    Y.observe(np.random.randn(3,2))
    X.update()
    check_rotation(X)
def test_initialization(self):
    """
    Exercise the initialization methods of GaussianARD.
    """

    def expected_second_moment(mean, precision):
        # Outer product of the mean plus a diagonal matrix built from the
        # reciprocal of ``precision``
        return (linalg.outer(mean, mean, ndim=1)
                + misc.diag(1 / precision, ndim=1))

    node = GaussianARD(1, 2, shape=(2, ), plates=(3, ))

    # Initialization from the prior
    prior_mean = 1 * np.ones((3, 2))
    prior_precision = 2 * np.ones((3, 2))
    node.initialize_from_prior()
    moments = node._message_to_child()
    self.assertAllClose(moments[0] * np.ones((3, 2)), prior_mean)
    self.assertAllClose(moments[1] * np.ones((3, 2, 2)),
                        expected_second_moment(prior_mean, prior_precision))

    # Initialization from explicit parameters
    given_mean = np.random.randn(3, 2)
    given_precision = np.random.rand(3, 2)
    node.initialize_from_parameters(given_mean, given_precision)
    moments = node._message_to_child()
    self.assertAllClose(moments[0], given_mean)
    self.assertAllClose(moments[1],
                        expected_second_moment(given_mean, given_precision))

    # Initialization from a fixed value
    value = np.random.randn(3, 2)
    node.initialize_from_value(value)
    moments = node._message_to_child()
    self.assertAllClose(moments[0], value)
    self.assertAllClose(moments[1], linalg.outer(value, value, ndim=1))

    # Random initialization: only checked to run without raising
    node.initialize_from_random()
def test_initialization(self):
    """
    Verify every initialization method offered by GaussianARD.
    """

    ard = GaussianARD(1, 2, shape=(2,), plates=(3,))
    ones_vec = np.ones((3,2))
    ones_mat = np.ones((3,2,2))

    # Prior initialization: moments must match mean 1 and precision 2
    m = 1 * np.ones((3, 2))
    prec = 2 * np.ones((3, 2))
    ard.initialize_from_prior()
    msg = ard._message_to_child()
    expected = linalg.outer(m, m, ndim=1) + misc.diag(1/prec, ndim=1)
    self.assertAllClose(msg[0]*ones_vec, m)
    self.assertAllClose(msg[1]*ones_mat, expected)

    # Parameter initialization with random mean and precision
    m = np.random.randn(3, 2)
    prec = np.random.rand(3, 2)
    ard.initialize_from_parameters(m, prec)
    msg = ard._message_to_child()
    expected = linalg.outer(m, m, ndim=1) + misc.diag(1/prec, ndim=1)
    self.assertAllClose(msg[0], m)
    self.assertAllClose(msg[1], expected)

    # Value initialization: second moment is the plain outer product
    val = np.random.randn(3, 2)
    ard.initialize_from_value(val)
    msg = ard._message_to_child()
    self.assertAllClose(msg[0], val)
    self.assertAllClose(msg[1], linalg.outer(val, val, ndim=1))

    # Random initialization (smoke test only)
    ard.initialize_from_random()
def _compute_moments(self, *u_parents):
    """
    Compute the moments of the sum of the parents.

    Each element of ``u_parents`` is indexed as ``[first, second]`` moment.
    The first moments are added together; the second moment is the sum of
    the parents' second moments plus, for every pair of parents among the
    first ``self.N``, the outer product of their first moments and its
    transpose.
    """
    means = [moments[0] for moments in u_parents]
    second_moments = [moments[1] for moments in u_parents]

    total_mean = functools.reduce(np.add, means)
    total_second = functools.reduce(np.add, second_moments)

    # Cross terms between distinct parents
    for i in range(self.N):
        for j in range(i + 1, self.N):
            cross = linalg.outer(means[i], means[j], ndim=self.ndim)
            cross_t = linalg.transpose(cross, ndim=self.ndim)
            total_second = total_second + cross + cross_t

    return [total_mean, total_second]
def _compute_moments(self, *u_parents):
    """
    Return ``[u0, u1]``, the first and second moments of the parents' sum.

    ``u0`` adds up the parents' first moments.  ``u1`` adds up the parents'
    second moments and then the symmetric cross terms built from the outer
    products of the first moments of each pair of parents ``i < j`` with
    ``j < self.N``.
    """
    firsts = [u[0] for u in u_parents]
    seconds = [u[1] for u in u_parents]
    u0 = functools.reduce(np.add, firsts)
    u1 = functools.reduce(np.add, seconds)

    # Enumerate the parent pairs in the same order as a nested loop would
    pairs = [(i, j) for i in range(self.N) for j in range(i+1, self.N)]
    for (i, j) in pairs:
        outer_ij = linalg.outer(firsts[i], firsts[j], ndim=self.ndim)
        outer_ji = linalg.transpose(outer_ij, ndim=self.ndim)
        u1 = u1 + outer_ij + outer_ji

    return [u0, u1]
def _compute_moments(self, *u_nodes):
    """
    Compute the moments of the concatenation of the given nodes.

    The first moments are concatenated along the last axis.  The second
    moment starts as the block-diagonal arrangement of the nodes' second
    moments; the off-diagonal blocks are then filled with the outer
    products of the corresponding first moments.
    """
    means = [u[0] for u in u_nodes]
    x = misc.concatenate(*means, axis=-1)
    xx = misc.block_diag(*[u[1] for u in u_nodes])

    # Explicitly broadcast xx to plates of x
    xx = np.ones(np.shape(x)[:-1])[..., None, None] * xx

    # Index range occupied by each node along the concatenated axis
    bounds = []
    start = 0
    for mean in means:
        stop = start + np.shape(mean)[-1]
        bounds.append((start, stop))
        start = stop

    # Compute the cross-covariance terms using the means of each variable
    # (because covariances are zero for factorized nodes in the VB
    # approximation)
    for (m, (row_lo, row_hi)) in enumerate(bounds):
        for n in range(m):
            (col_lo, col_hi) = bounds[n]
            cross = linalg.outer(means[m], means[n], ndim=1)
            xx[..., row_lo:row_hi, col_lo:col_hi] = cross
            xx[..., col_lo:col_hi, row_lo:row_hi] = misc.T(cross)

    return [x, xx]
def _compute_moments(self, *u_nodes):
    """
    Return ``[x, xx]`` for the variable obtained by stacking the nodes.

    ``x`` concatenates the first moments along the last axis.  ``xx`` is
    the block-diagonal of the second moments, broadcast to the plates of
    ``x``, with each off-diagonal block set to the outer product of the
    two nodes' first moments.
    """
    firsts = [u[0] for u in u_nodes]
    x = misc.concatenate(*firsts, axis=-1)
    xx = misc.block_diag(*[u[1] for u in u_nodes])

    # Explicitly broadcast xx to plates of x
    plates = np.shape(x)[:-1]
    xx = np.ones(plates)[...,None,None] * xx

    # Length of each node along the concatenated axis
    widths = [np.shape(first)[-1] for first in firsts]

    # Compute the cross-covariance terms using the means of each variable
    # (because covariances are zero for factorized nodes in the VB
    # approximation)
    row = 0
    for (m, row_width) in enumerate(widths):
        col = 0
        for n in range(m):
            col_width = widths[n]
            block = linalg.outer(firsts[m], firsts[n], ndim=1)
            xx[...,row:row+row_width,col:col+col_width] = block
            xx[...,col:col+col_width,row:row+row_width] = misc.T(block)
            col += col_width
        row += row_width

    return [x, xx]
def test_message_to_child(self):
    """
    Test the message from SumMultiply to its children.

    Each scenario builds the expected first and second moments with
    ``np.einsum`` on the parents' moments and compares them with the
    moments of a SumMultiply node.  Most scenarios are run with both the
    string-subscript and the list-subscript calling conventions.
    """

    def compare_moments(u0, u1, *args):
        # Build a SumMultiply node from ``args`` and check its first two
        # moments against the expected values ``u0`` and ``u1``.
        Y = SumMultiply(*args)
        u_Y = Y.get_moments()
        self.assertAllClose(u_Y[0], u0)
        self.assertAllClose(u_Y[1], u1)

    # Test constant parent
    y = np.random.randn(2,3,4)
    compare_moments(y,
                    linalg.outer(y, y, ndim=2),
                    'ij->ij',
                    y)

    # Do nothing for 2-D array
    Y = GaussianARD(np.random.randn(5,2,3),
                    np.random.rand(5,2,3),
                    plates=(5,),
                    shape=(2,3))
    y = Y.get_moments()
    compare_moments(y[0],
                    y[1],
                    'ij->ij',
                    Y)
    compare_moments(y[0],
                    y[1],
                    Y,
                    [0,1],
                    [0,1])

    # Sum over the rows of a matrix
    Y = GaussianARD(np.random.randn(5,2,3),
                    np.random.rand(5,2,3),
                    plates=(5,),
                    shape=(2,3))
    y = Y.get_moments()
    mu = np.einsum('...ij->...j', y[0])
    cov = np.einsum('...ijkl->...jl', y[1])
    compare_moments(mu,
                    cov,
                    'ij->j',
                    Y)
    compare_moments(mu,
                    cov,
                    Y,
                    [0,1],
                    [1])

    # Inner product of three vectors
    X1 = GaussianARD(np.random.randn(2),
                     np.random.rand(2),
                     plates=(),
                     shape=(2,))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(6,1,2),
                     np.random.rand(6,1,2),
                     plates=(6,1),
                     shape=(2,))
    x2 = X2.get_moments()
    X3 = GaussianARD(np.random.randn(7,6,5,2),
                     np.random.rand(7,6,5,2),
                     plates=(7,6,5),
                     shape=(2,))
    x3 = X3.get_moments()
    mu = np.einsum('...i,...i,...i->...', x1[0], x2[0], x3[0])
    cov = np.einsum('...ij,...ij,...ij->...', x1[1], x2[1], x3[1])
    # The output subscripts are given both implicitly and explicitly
    compare_moments(mu,
                    cov,
                    'i,i,i',
                    X1,
                    X2,
                    X3)
    compare_moments(mu,
                    cov,
                    'i,i,i->',
                    X1,
                    X2,
                    X3)
    compare_moments(mu,
                    cov,
                    X1,
                    [9],
                    X2,
                    [9],
                    X3,
                    [9])
    compare_moments(mu,
                    cov,
                    X1,
                    [9],
                    X2,
                    [9],
                    X3,
                    [9],
                    [])

    # Outer product of two vectors
    X1 = GaussianARD(np.random.randn(2),
                     np.random.rand(2),
                     plates=(5,),
                     shape=(2,))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(6,1,2),
                     np.random.rand(6,1,2),
                     plates=(6,1),
                     shape=(2,))
    x2 = X2.get_moments()
    mu = np.einsum('...i,...j->...ij', x1[0], x2[0])
    cov = np.einsum('...ik,...jl->...ijkl', x1[1], x2[1])
    compare_moments(mu,
                    cov,
                    'i,j->ij',
                    X1,
                    X2)
    compare_moments(mu,
                    cov,
                    X1,
                    [9],
                    X2,
                    [7],
                    [9,7])

    # Matrix product
    Y1 = GaussianARD(np.random.randn(3,2),
                     np.random.rand(3,2),
                     plates=(),
                     shape=(3,2))
    y1 = Y1.get_moments()
    Y2 = GaussianARD(np.random.randn(5,2,3),
                     np.random.rand(5,2,3),
                     plates=(5,),
                     shape=(2,3))
    y2 = Y2.get_moments()
    mu = np.einsum('...ik,...kj->...ij', y1[0], y2[0])
    cov = np.einsum('...ikjl,...kmln->...imjn', y1[1], y2[1])
    compare_moments(mu,
                    cov,
                    'ik,kj->ij',
                    Y1,
                    Y2)
    # Subscripts may also be given as lists of characters
    compare_moments(mu,
                    cov,
                    Y1,
                    ['i','k'],
                    Y2,
                    ['k','j'],
                    ['i','j'])

    # Trace of a matrix product
    Y1 = GaussianARD(np.random.randn(3,2),
                     np.random.rand(3,2),
                     plates=(),
                     shape=(3,2))
    y1 = Y1.get_moments()
    Y2 = GaussianARD(np.random.randn(5,2,3),
                     np.random.rand(5,2,3),
                     plates=(5,),
                     shape=(2,3))
    y2 = Y2.get_moments()
    mu = np.einsum('...ij,...ji->...', y1[0], y2[0])
    cov = np.einsum('...ikjl,...kilj->...', y1[1], y2[1])
    compare_moments(mu,
                    cov,
                    'ij,ji',
                    Y1,
                    Y2)
    compare_moments(mu,
                    cov,
                    'ij,ji->',
                    Y1,
                    Y2)
    compare_moments(mu,
                    cov,
                    Y1,
                    ['i','j'],
                    Y2,
                    ['j','i'])
    compare_moments(mu,
                    cov,
                    Y1,
                    ['i','j'],
                    Y2,
                    ['j','i'],
                    [])

    # Vector-matrix-vector product
    X1 = GaussianARD(np.random.randn(3),
                     np.random.rand(3),
                     plates=(),
                     shape=(3,))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(6,1,2),
                     np.random.rand(6,1,2),
                     plates=(6,1),
                     shape=(2,))
    x2 = X2.get_moments()
    Y = GaussianARD(np.random.randn(3,2),
                    np.random.rand(3,2),
                    plates=(),
                    shape=(3,2))
    y = Y.get_moments()
    mu = np.einsum('...i,...ij,...j->...', x1[0], y[0], x2[0])
    cov = np.einsum('...ia,...ijab,...jb->...', x1[1], y[1], x2[1])
    compare_moments(mu,
                    cov,
                    'i,ij,j',
                    X1,
                    Y,
                    X2)
    compare_moments(mu,
                    cov,
                    X1,
                    [1],
                    Y,
                    [1,2],
                    X2,
                    [2])

    # Complex sum-product of 0-D, 1-D, 2-D and 3-D arrays
    V = GaussianARD(np.random.randn(7,6,5),
                    np.random.rand(7,6,5),
                    plates=(7,6,5),
                    shape=())
    v = V.get_moments()
    X = GaussianARD(np.random.randn(6,1,2),
                    np.random.rand(6,1,2),
                    plates=(6,1),
                    shape=(2,))
    x = X.get_moments()
    Y = GaussianARD(np.random.randn(3,4),
                    np.random.rand(3,4),
                    plates=(5,),
                    shape=(3,4))
    y = Y.get_moments()
    Z = GaussianARD(np.random.randn(4,2,3),
                    np.random.rand(4,2,3),
                    plates=(6,5),
                    shape=(4,2,3))
    z = Z.get_moments()
    mu = np.einsum('...,...i,...kj,...jik->...k', v[0], x[0], y[0], z[0])
    cov = np.einsum('...,...ia,...kjcb,...jikbac->...kc', v[1], x[1], y[1], z[1])
    compare_moments(mu,
                    cov,
                    ',i,kj,jik->k',
                    V,
                    X,
                    Y,
                    Z)
    compare_moments(mu,
                    cov,
                    V,
                    [],
                    X,
                    ['i'],
                    Y,
                    ['k','j'],
                    Z,
                    ['j','i','k'],
                    ['k'])

    #
    # Gaussian-gamma parents
    #

    # Outer product of vectors
    X1 = GaussianARD(np.random.randn(2),
                     np.random.rand(2),
                     shape=(2,))
    x1 = X1.get_moments()
    X2 = GaussianGamma(
        np.random.randn(6,1,2),
        random.covariance(2),
        np.random.rand(6,1),
        np.random.rand(6,1),
        plates=(6,1)
    )
    x2 = X2.get_moments()
    Y = SumMultiply('i,j->ij', X1, X2)
    u = Y._message_to_child()
    y = np.einsum('...i,...j->...ij', x1[0], x2[0])
    yy = np.einsum('...ik,...jl->...ijkl', x1[1], x2[1])
    self.assertAllClose(u[0], y)
    self.assertAllClose(u[1], yy)
    # The last two moments are expected to equal those of the
    # Gaussian-gamma parent
    self.assertAllClose(u[2], x2[2])
    self.assertAllClose(u[3], x2[3])

    pass
def test_message_to_child(self):
    """
    Test the message to child of GaussianGammaISO node.

    Three scenarios are covered: constant parents, stochastic parents and
    stochastic parents with plates.
    """

    def check_message(message, mean, precision, scale, log_scale):
        # Compare all four moments of the message with the expected ones
        self.assertAllClose(message[0], scale[..., None] * mean)
        self.assertAllClose(
            message[1],
            (linalg.inv(precision)
             + scale[..., None, None] * linalg.outer(mean, mean)))
        self.assertAllClose(message[2], scale)
        self.assertAllClose(message[3], log_scale)

    # Simple test
    mean = np.array([1, 2, 3])
    precision = np.identity(3)
    shape_param = 2
    rate_param = 10
    node = GaussianGammaISO(mean, precision, shape_param, rate_param)
    message = node._message_to_child()
    self.assertEqual(len(message), 4)
    scale = np.array(shape_param / rate_param)
    check_message(message,
                  mean,
                  precision,
                  scale,
                  -np.log(rate_param) + special.psi(shape_param))

    # Test with unknown parents
    mean_node = Gaussian(np.arange(3), 10 * np.identity(3))
    precision_node = Wishart(10, np.identity(3))
    shape_param = 2
    rate_node = Gamma(3, 15)
    node = GaussianGammaISO(mean_node, precision_node, shape_param, rate_node)
    message = node._message_to_child()
    (mean, mean_outer) = mean_node._message_to_child()
    mean_cov = mean_outer - linalg.outer(mean, mean)
    (precision, _) = precision_node._message_to_child()
    (rate, _) = rate_node._message_to_child()
    posterior = Gamma(shape_param, rate + 0.5 * np.sum(precision * mean_cov))
    (scale, log_scale) = posterior._message_to_child()
    check_message(message, mean, precision, scale, log_scale)

    # Test with plates
    mean_node = Gaussian(np.reshape(np.arange(3 * 4), (4, 3)),
                         10 * np.identity(3),
                         plates=(4, ))
    precision_node = Wishart(10, np.identity(3))
    shape_param = 2
    rate_node = Gamma(3, 15)
    node = GaussianGammaISO(mean_node,
                            precision_node,
                            shape_param,
                            rate_node,
                            plates=(4, ))
    message = node._message_to_child()
    (mean, mean_outer) = mean_node._message_to_child()
    mean_cov = mean_outer - linalg.outer(mean, mean)
    (precision, _) = precision_node._message_to_child()
    (rate, _) = rate_node._message_to_child()
    posterior = Gamma(
        shape_param,
        rate + 0.5 * np.sum(precision * mean_cov, axis=(-1, -2)))
    (scale, log_scale) = posterior._message_to_child()
    # Both sides are multiplied by ones so that they are compared with the
    # plate axis of length 4 made explicit
    self.assertAllClose(message[0] * np.ones((4, 1)),
                        np.ones((4, 1)) * scale[..., None] * mean)
    self.assertAllClose(
        message[1] * np.ones((4, 1, 1)),
        np.ones((4, 1, 1)) * (linalg.inv(precision)
                              + scale[..., None, None]
                              * linalg.outer(mean, mean)))
    self.assertAllClose(message[2] * np.ones(4), np.ones(4) * scale)
    self.assertAllClose(message[3] * np.ones(4), np.ones(4) * log_scale)
def test_message_to_child(self):
    """
    Test the message to child of GaussianGamma node.

    The message is checked for constant parents, for stochastic parents
    and for stochastic parents with plates.
    """

    def expected_second(m, L, t):
        # Expected second moment given the mean, precision and scale
        return linalg.inv(L) + t[...,None,None] * linalg.outer(m, m)

    # Simple test
    m = np.array([1,2,3])
    L = np.identity(3)
    shape_a = 2
    rate_b = 10
    node = GaussianGamma(m, L, shape_a, rate_b)
    msg = node._message_to_child()
    self.assertEqual(len(msg), 4)
    t = np.array(shape_a/rate_b)
    self.assertAllClose(msg[0], t[...,None] * m)
    self.assertAllClose(msg[1], expected_second(m, L, t))
    self.assertAllClose(msg[2], t)
    self.assertAllClose(msg[3], -np.log(rate_b) + special.psi(shape_a))

    # Test with unknown parents
    m_node = Gaussian(np.arange(3), 10*np.identity(3))
    L_node = Wishart(10, np.identity(3))
    shape_a = 2
    b_node = Gamma(3, 15)
    node = GaussianGamma(m_node, L_node, shape_a, b_node)
    msg = node._message_to_child()
    (m, mm) = m_node._message_to_child()
    cov_m = mm - linalg.outer(m, m)
    (L, _) = L_node._message_to_child()
    (rate_b, _) = b_node._message_to_child()
    gamma_node = Gamma(shape_a, rate_b + 0.5*np.sum(L*cov_m))
    (t, log_t) = gamma_node._message_to_child()
    self.assertAllClose(msg[0], t[...,None] * m)
    self.assertAllClose(msg[1], expected_second(m, L, t))
    self.assertAllClose(msg[2], t)
    self.assertAllClose(msg[3], log_t)

    # Test with plates
    m_node = Gaussian(np.reshape(np.arange(3*4), (4,3)),
                      10*np.identity(3),
                      plates=(4,))
    L_node = Wishart(10, np.identity(3))
    shape_a = 2
    b_node = Gamma(3, 15)
    node = GaussianGamma(m_node, L_node, shape_a, b_node, plates=(4,))
    msg = node._message_to_child()
    (m, mm) = m_node._message_to_child()
    cov_m = mm - linalg.outer(m, m)
    (L, _) = L_node._message_to_child()
    (rate_b, _) = b_node._message_to_child()
    gamma_node = Gamma(shape_a,
                       rate_b + 0.5*np.sum(L*cov_m, axis=(-1,-2)))
    (t, log_t) = gamma_node._message_to_child()
    # Multiply both sides by ones to make the plate axis explicit
    self.assertAllClose(msg[0] * np.ones((4,1)),
                        np.ones((4,1)) * t[...,None] * m)
    self.assertAllClose(msg[1] * np.ones((4,1,1)),
                        np.ones((4,1,1)) * expected_second(m, L, t))
    self.assertAllClose(msg[2] * np.ones(4),
                        np.ones(4) * t)
    self.assertAllClose(msg[3] * np.ones(4),
                        np.ones(4) * log_t)
def covariance(self):
    """
    Return the covariance computed from the first two moments.

    The first two items returned by ``self.get()`` are taken as the first
    and second moments; the result is the second moment minus the outer
    product of the first moment with itself.
    """
    moments = self.get()
    (mean, second_moment) = moments[:2]
    return second_moment - linalg.outer(mean, mean, ndim=1)
def test_message_to_child(self):
    """
    Test the message from SumMultiply to its children.

    Each scenario builds the expected first and second moments with
    ``np.einsum`` on the parents' moments and compares them with the
    moments of a SumMultiply node.  Most scenarios are run with both the
    string-subscript and the list-subscript calling conventions.  Plain
    NumPy arrays are also used as parents next to Gaussian and
    Gaussian-gamma nodes.
    """

    def compare_moments(u0, u1, *args):
        # Build a SumMultiply node from ``args`` and check its first two
        # moments against the expected values ``u0`` and ``u1``.
        Y = SumMultiply(*args)
        u_Y = Y.get_moments()
        self.assertAllClose(u_Y[0], u0)
        self.assertAllClose(u_Y[1], u1)

    # Test constant parent
    y = np.random.randn(2, 3, 4)
    compare_moments(y, linalg.outer(y, y, ndim=2), 'ij->ij', y)

    # Do nothing for 2-D array
    Y = GaussianARD(np.random.randn(5, 2, 3),
                    np.random.rand(5, 2, 3),
                    plates=(5, ),
                    shape=(2, 3))
    y = Y.get_moments()
    compare_moments(y[0], y[1], 'ij->ij', Y)
    compare_moments(y[0], y[1], Y, [0, 1], [0, 1])

    # Sum over the rows of a matrix
    Y = GaussianARD(np.random.randn(5, 2, 3),
                    np.random.rand(5, 2, 3),
                    plates=(5, ),
                    shape=(2, 3))
    y = Y.get_moments()
    mu = np.einsum('...ij->...j', y[0])
    cov = np.einsum('...ijkl->...jl', y[1])
    compare_moments(mu, cov, 'ij->j', Y)
    compare_moments(mu, cov, Y, [0, 1], [1])

    # Inner product of three vectors
    X1 = GaussianARD(np.random.randn(2),
                     np.random.rand(2),
                     plates=(),
                     shape=(2, ))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(6, 1, 2),
                     np.random.rand(6, 1, 2),
                     plates=(6, 1),
                     shape=(2, ))
    x2 = X2.get_moments()
    X3 = GaussianARD(np.random.randn(7, 6, 5, 2),
                     np.random.rand(7, 6, 5, 2),
                     plates=(7, 6, 5),
                     shape=(2, ))
    x3 = X3.get_moments()
    mu = np.einsum('...i,...i,...i->...', x1[0], x2[0], x3[0])
    cov = np.einsum('...ij,...ij,...ij->...', x1[1], x2[1], x3[1])
    # The output subscripts are given both implicitly and explicitly
    compare_moments(mu, cov, 'i,i,i', X1, X2, X3)
    compare_moments(mu, cov, 'i,i,i->', X1, X2, X3)
    compare_moments(mu, cov, X1, [9], X2, [9], X3, [9])
    compare_moments(mu, cov, X1, [9], X2, [9], X3, [9], [])

    # Outer product of two vectors
    X1 = GaussianARD(np.random.randn(2),
                     np.random.rand(2),
                     plates=(5, ),
                     shape=(2, ))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(6, 1, 2),
                     np.random.rand(6, 1, 2),
                     plates=(6, 1),
                     shape=(2, ))
    x2 = X2.get_moments()
    mu = np.einsum('...i,...j->...ij', x1[0], x2[0])
    cov = np.einsum('...ik,...jl->...ijkl', x1[1], x2[1])
    compare_moments(mu, cov, 'i,j->ij', X1, X2)
    compare_moments(mu, cov, X1, [9], X2, [7], [9, 7])

    # Matrix product
    Y1 = GaussianARD(np.random.randn(3, 2),
                     np.random.rand(3, 2),
                     plates=(),
                     shape=(3, 2))
    y1 = Y1.get_moments()
    Y2 = GaussianARD(np.random.randn(5, 2, 3),
                     np.random.rand(5, 2, 3),
                     plates=(5, ),
                     shape=(2, 3))
    y2 = Y2.get_moments()
    mu = np.einsum('...ik,...kj->...ij', y1[0], y2[0])
    cov = np.einsum('...ikjl,...kmln->...imjn', y1[1], y2[1])
    compare_moments(mu, cov, 'ik,kj->ij', Y1, Y2)
    # Subscripts may also be given as lists of characters
    compare_moments(mu, cov, Y1, ['i', 'k'], Y2, ['k', 'j'], ['i', 'j'])

    # Trace of a matrix product
    Y1 = GaussianARD(np.random.randn(3, 2),
                     np.random.rand(3, 2),
                     plates=(),
                     shape=(3, 2))
    y1 = Y1.get_moments()
    Y2 = GaussianARD(np.random.randn(5, 2, 3),
                     np.random.rand(5, 2, 3),
                     plates=(5, ),
                     shape=(2, 3))
    y2 = Y2.get_moments()
    mu = np.einsum('...ij,...ji->...', y1[0], y2[0])
    cov = np.einsum('...ikjl,...kilj->...', y1[1], y2[1])
    compare_moments(mu, cov, 'ij,ji', Y1, Y2)
    compare_moments(mu, cov, 'ij,ji->', Y1, Y2)
    compare_moments(mu, cov, Y1, ['i', 'j'], Y2, ['j', 'i'])
    compare_moments(mu, cov, Y1, ['i', 'j'], Y2, ['j', 'i'], [])

    # Vector-matrix-vector product
    X1 = GaussianARD(np.random.randn(3),
                     np.random.rand(3),
                     plates=(),
                     shape=(3, ))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(6, 1, 2),
                     np.random.rand(6, 1, 2),
                     plates=(6, 1),
                     shape=(2, ))
    x2 = X2.get_moments()
    Y = GaussianARD(np.random.randn(3, 2),
                    np.random.rand(3, 2),
                    plates=(),
                    shape=(3, 2))
    y = Y.get_moments()
    mu = np.einsum('...i,...ij,...j->...', x1[0], y[0], x2[0])
    cov = np.einsum('...ia,...ijab,...jb->...', x1[1], y[1], x2[1])
    compare_moments(mu, cov, 'i,ij,j', X1, Y, X2)
    compare_moments(mu, cov, X1, [1], Y, [1, 2], X2, [2])

    # Complex sum-product of 0-D, 1-D, 2-D and 3-D arrays
    V = GaussianARD(np.random.randn(7, 6, 5),
                    np.random.rand(7, 6, 5),
                    plates=(7, 6, 5),
                    shape=())
    v = V.get_moments()
    X = GaussianARD(np.random.randn(6, 1, 2),
                    np.random.rand(6, 1, 2),
                    plates=(6, 1),
                    shape=(2, ))
    x = X.get_moments()
    Y = GaussianARD(np.random.randn(3, 4),
                    np.random.rand(3, 4),
                    plates=(5, ),
                    shape=(3, 4))
    y = Y.get_moments()
    Z = GaussianARD(np.random.randn(4, 2, 3),
                    np.random.rand(4, 2, 3),
                    plates=(6, 5),
                    shape=(4, 2, 3))
    z = Z.get_moments()
    mu = np.einsum('...,...i,...kj,...jik->...k', v[0], x[0], y[0], z[0])
    cov = np.einsum('...,...ia,...kjcb,...jikbac->...kc', v[1], x[1], y[1], z[1])
    compare_moments(mu, cov, ',i,kj,jik->k', V, X, Y, Z)
    compare_moments(mu, cov, V, [], X, ['i'], Y, ['k', 'j'], Z, ['j', 'i', 'k'], ['k'])

    # Test with constant nodes
    # (a plain array ``a`` is combined with the Gaussian node ``B``)
    N = 10
    D = 5
    a = np.random.randn(N, D)
    B = Gaussian(
        np.random.randn(D),
        random.covariance(D),
    )
    X = SumMultiply('i,i->', B, a)
    np.testing.assert_allclose(
        X.get_moments()[0],
        np.einsum('ni,i->n', a, B.get_moments()[0]),
    )
    np.testing.assert_allclose(
        X.get_moments()[1],
        np.einsum('ni,nj,ij->n', a, a, B.get_moments()[1]),
    )

    #
    # Gaussian-gamma parents
    #

    # Outer product of vectors
    X1 = GaussianARD(np.random.randn(2),
                     np.random.rand(2),
                     shape=(2, ))
    x1 = X1.get_moments()
    X2 = GaussianGamma(np.random.randn(6, 1, 2),
                       random.covariance(2),
                       np.random.rand(6, 1),
                       np.random.rand(6, 1),
                       plates=(6, 1))
    x2 = X2.get_moments()
    Y = SumMultiply('i,j->ij', X1, X2)
    u = Y._message_to_child()
    y = np.einsum('...i,...j->...ij', x1[0], x2[0])
    yy = np.einsum('...ik,...jl->...ijkl', x1[1], x2[1])
    self.assertAllClose(u[0], y)
    self.assertAllClose(u[1], yy)
    # The last two moments are expected to equal those of the
    # Gaussian-gamma parent
    self.assertAllClose(u[2], x2[2])
    self.assertAllClose(u[3], x2[3])

    # Test with constant nodes
    # (a plain array ``a`` is combined with the Gaussian-gamma node ``B``)
    N = 10
    M = 8
    D = 5
    a = np.random.randn(N, 1, D)
    B = GaussianGamma(
        np.random.randn(M, D),
        random.covariance(D, size=(M, )),
        np.random.rand(M),
        np.random.rand(M),
        ndim=1,
    )
    X = SumMultiply('i,i->', B, a)
    np.testing.assert_allclose(
        X.get_moments()[0],
        np.einsum('nmi,mi->nm', a, B.get_moments()[0]),
    )
    np.testing.assert_allclose(
        X.get_moments()[1],
        np.einsum('nmi,nmj,mij->nm', a, a, B.get_moments()[1]),
    )
    # The remaining two moments are compared with those of B unchanged
    np.testing.assert_allclose(
        X.get_moments()[2],
        B.get_moments()[2],
    )
    np.testing.assert_allclose(
        X.get_moments()[3],
        B.get_moments()[3],
    )

    pass
def _compute_bound(self, R, logdet=None, inv=None, Q=None, gradient=False, terms=False):
    """
    Rotate q(X) and q(alpha).

    Assume:
    p(X|alpha) = prod_m N(x_m|0,diag(alpha))
    p(alpha) = prod_d G(a_d,b_d)

    The method relies on the attributes stored by ``setup``.

    Parameters
    ----------
    R : array
        Rotation matrix.
    logdet : optional
        Log-determinant term of ``R``.  If None, both the log-determinant
        and the inverse are computed from ``R``; otherwise ``logdet`` and
        ``inv`` are used as given.
    inv : optional
        Inverse of ``R``; only read when ``logdet`` is not None.
    Q : optional
        Rotation matrix for the plate axis.  Required when
        ``self.plate_axis`` is not None.
    gradient : bool
        If True, gradients are returned in addition to the bound.
    terms : bool
        If True, the bound is returned as a dict keyed by node instead of
        a single sum.  Not supported together with ``gradient``.

    Returns
    -------
    bound
        If ``gradient`` is False.
    (bound, dR_bound)
        If ``gradient`` is True and ``self.plate_axis`` is None.
    (bound, dR_bound, dQ_bound)
        If ``gradient`` is True and ``self.plate_axis`` is not None.

    Raises
    ------
    ValueError
        If the plates should be rotated but ``Q`` is None.
    NotImplementedError
        If both ``gradient`` and ``terms`` are True.
    """

    ## R = self._full_rotation_matrix(R)
    ## if inv is not None:
    ##     inv = self._full_rotation_matrix(inv)

    #
    # Transform the distributions and moments
    #

    plates_alpha = self.plates_alpha
    plates_X = self.plates_X

    # Compute rotated second moment
    if self.plate_axis is not None:
        # The plate axis has been moved to be the last plate axis

        if Q is None:
            raise ValueError("Plates should be rotated but no Q given")

        # Transform covariance
        sumQ = np.sum(Q, axis=0)
        QCovQ = sumQ[:, None, None] ** 2 * self.CovX

        # Rotate plates
        if self.precompute:
            QX_QX = np.einsum("...kalb,...ik,...il->...iab", self.X_X, Q, Q)
            XX = QX_QX + QCovQ
            XX = sum_to_plates(XX, plates_alpha[:-1], ndim=2)
            Xmu = np.einsum("...kaib,...ik->...iab", self.X_mu, Q)
            Xmu = sum_to_plates(Xmu, plates_alpha[:-1], ndim=2)
        else:
            # X and mu are also needed later for the gradient w.r.t. Q
            X = self.X
            mu = self.mu
            QX = np.einsum("...ik,...kj->...ij", Q, X)
            XX = sum_to_plates(QCovQ, plates_alpha[:-1], ndim=2) + sum_to_plates(
                linalg.outer(QX, QX), plates_alpha[:-1], ndim=2, plates_from=plates_X
            )
            Xmu = sum_to_plates(linalg.outer(QX, self.mu), plates_alpha[:-1], ndim=2, plates_from=plates_X)

        mu2 = self.mu2
        D = np.shape(XX)[-1]
        logdet_Q = D * np.log(np.abs(sumQ))

    else:
        XX = self.XX
        mu2 = self.mu2
        Xmu = self.Xmu
        logdet_Q = 0

    # Compute transformed moments
    # mu2 = np.einsum('...ii->...i', mu2)
    RXmu = np.einsum("...ik,...ki->...i", R, Xmu)
    RXX = np.einsum("...ik,...kj->...ij", R, XX)
    RXXR = np.einsum("...ik,...ik->...i", RXX, R)

    # <(X-mu) * (X-mu)'>_R
    XmuXmu = RXXR - 2 * RXmu + mu2

    D = np.shape(R)[0]

    # Compute q(alpha)
    if self.update_alpha:
        # Parameters
        a0 = self.a0
        b0 = self.b0
        a = self.a
        b = b0 + 0.5 * sum_to_plates(XmuXmu, plates_alpha, plates_from=None, ndim=0)
        # Some expectations
        alpha = a / b
        logb = np.log(b)
        logalpha = -logb  # + const
        b0_alpha = b0 * alpha
        a0_logalpha = a0 * logalpha
    else:
        alpha = self.alpha
        logalpha = 0

    #
    # Compute the cost
    #

    def sum_plates(V, *plates):
        # Sum all elements of V and scale the result by the broadcasting
        # multiplier between the broadcasted plates and the shape of V
        full_plates = misc.broadcasted_shape(*plates)
        r = self.node_X.broadcasting_multiplier(full_plates, np.shape(V))
        return r * np.sum(V)

    XmuXmu_alpha = XmuXmu * alpha

    if logdet is None:
        logdet_R = np.linalg.slogdet(R)[1]
        inv_R = np.linalg.inv(R)
    else:
        logdet_R = logdet
        inv_R = inv

    # Compute entropy H(X)
    logH_X = random.gaussian_entropy(-2 * sum_plates(logdet_R + logdet_Q, plates_X), 0)

    # Compute <log p(X|alpha)>
    logp_X = random.gaussian_logpdf(
        sum_plates(XmuXmu_alpha, plates_alpha[:-1] + [D]), 0, 0, sum_plates(logalpha, plates_X + [D]), 0
    )

    if self.update_alpha:

        # Compute entropy H(alpha)
        # This cancels out with the log(alpha) term in log(p(alpha))
        logH_alpha = 0

        # Compute <log p(alpha)>
        logp_alpha = random.gamma_logpdf(
            sum_plates(b0_alpha, plates_alpha), 0, sum_plates(a0_logalpha, plates_alpha), 0, 0
        )
    else:
        logH_alpha = 0
        logp_alpha = 0

    # Compute the bound
    if terms:
        bound = {self.node_X: logp_X + logH_X}
        if self.update_alpha:
            bound.update({self.node_alpha: logp_alpha + logH_alpha})
    else:
        bound = 0 + logp_X + logp_alpha + logH_X + logH_alpha

    if not gradient:
        return bound

    #
    # Compute the gradient with respect R
    #

    broadcasting_multiplier = self.node_X.broadcasting_multiplier

    def sum_plates(V, plates):
        # Reduce V to the shape of R, scaled by the broadcasting multiplier
        ones = np.ones(np.shape(R))
        r = broadcasting_multiplier(plates, np.shape(V)[:-2])
        return r * misc.sum_multiply(V, ones, axis=(-1, -2), sumaxis=False, keepdims=False)

    D_XmuXmu = 2 * RXX - 2 * gaussian.transpose_covariance(Xmu)

    DXmuXmu_alpha = np.einsum("...i,...ij->...ij", alpha, D_XmuXmu)
    if self.update_alpha:
        D_b = 0.5 * D_XmuXmu
        XmuXmu_Dalpha = np.einsum(
            "...i,...i,...i,...ij->...ij",
            sum_to_plates(XmuXmu, plates_alpha, plates_from=None, ndim=0),
            alpha,
            -1 / b,
            D_b,
        )
        D_b0_alpha = np.einsum("...i,...i,...i,...ij->...ij", b0, alpha, -1 / b, D_b)
        D_logb = np.einsum("...i,...ij->...ij", 1 / b, D_b)
        D_logalpha = -D_logb
        D_a0_logalpha = a0 * D_logalpha
    else:
        XmuXmu_Dalpha = 0
        D_logalpha = 0

    D_XmuXmu_alpha = DXmuXmu_alpha + XmuXmu_Dalpha
    D_logR = inv_R.T

    # Compute dH(X)
    dlogH_X = random.gaussian_entropy(-2 * sum_plates(D_logR, plates_X), 0)

    # Compute d<log p(X|alpha)>
    dlogp_X = random.gaussian_logpdf(
        sum_plates(D_XmuXmu_alpha, plates_alpha[:-1]),
        0,
        0,
        (sum_plates(D_logalpha, plates_X) * broadcasting_multiplier((D,), plates_alpha[-1:])),
        0,
    )

    if self.update_alpha:

        # Compute dH(alpha)
        # This cancels out with the log(alpha) term in log(p(alpha))
        dlogH_alpha = 0

        # Compute d<log p(alpha)>
        dlogp_alpha = random.gamma_logpdf(
            sum_plates(D_b0_alpha, plates_alpha[:-1]), 0, sum_plates(D_a0_logalpha, plates_alpha[:-1]), 0, 0
        )
    else:
        dlogH_alpha = 0
        dlogp_alpha = 0

    if terms:
        # Per-node gradients are not supported
        raise NotImplementedError()
    else:
        # The zero term only affects the shape/type of the result; it is
        # kept as is.
        dR_bound = 0 * dlogp_X + dlogp_X + dlogp_alpha + dlogH_X + dlogH_alpha

    if self.subset:
        indices = np.ix_(self.subset, self.subset)
        dR_bound = dR_bound[indices]

    if self.plate_axis is None:
        return (bound, dR_bound)

    #
    # Compute the gradient with respect to Q (if Q given)
    #

    # Some pre-computations
    Q_RCovR = np.einsum("...ik,...kl,...il,...->...i", R, self.CovX, R, sumQ)
    if self.precompute:
        Xr_rX = np.einsum("...abcd,...jb,...jd->...jac", self.X_X, R, R)
        QXr_rX = np.einsum("...akj,...ik->...aij", Xr_rX, Q)
        RX_mu = np.einsum("...jk,...akbj->...jab", R, self.X_mu)
    else:
        RX = np.einsum("...ik,...k->...i", R, X)
        QXR = np.einsum("...ik,...kj->...ij", Q, RX)
        QXr_rX = np.einsum("...ik,...jk->...kij", QXR, RX)
        RX_mu = np.einsum("...ik,...jk->...kij", RX, mu)

        QXr_rX = sum_to_plates(QXr_rX, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])
        RX_mu = sum_to_plates(RX_mu, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])

    def psi(v):
        """
        Compute: d/dQ 1/2*trace(diag(v)*<(X-mu)*(X-mu)>)

        = Q*<X>'*R'*diag(v)*R*<X> + ones * Q diag( tr(R'*diag(v)*R*Cov) )
          + mu*diag(v)*R*<X>
        """

        # Precompute all terms to plates_alpha because v has shape
        # plates_alpha.

        # Gradient of 0.5*v*<x>*<x>
        v_QXrrX = np.einsum("...kij,...ik->...ij", QXr_rX, v)

        # Gradient of 0.5*v*Cov
        Q_tr_R_v_R_Cov = np.einsum("...k,...k->...", Q_RCovR, v)[..., None, :]

        # Gradient of mu*v*x
        mu_v_R_X = np.einsum("...ik,...kji->...ij", v, RX_mu)

        return v_QXrrX + Q_tr_R_v_R_Cov - mu_v_R_X

    def sum_plates(V, plates):
        # Reduce V to the shape of Q, scaled by the broadcasting multiplier
        ones = np.ones(np.shape(Q))
        r = self.node_X.broadcasting_multiplier(plates, np.shape(V)[:-2])
        return r * misc.sum_multiply(V, ones, axis=(-1, -2), sumaxis=False, keepdims=False)

    if self.update_alpha:
        D_logb = psi(1 / b)
        XX_Dalpha = -psi(alpha / b * sum_to_plates(XmuXmu, plates_alpha))
        D_logalpha = -D_logb
    else:
        XX_Dalpha = 0
        D_logalpha = 0
    DXX_alpha = 2 * psi(alpha)
    D_XX_alpha = DXX_alpha + XX_Dalpha
    D_logdetQ = D / sumQ
    N = np.shape(Q)[-1]

    # Compute dH(X)
    dQ_logHX = random.gaussian_entropy(-2 * sum_plates(D_logdetQ, plates_X[:-1]), 0)

    # Compute d<log p(X|alpha)>
    dQ_logpX = random.gaussian_logpdf(
        sum_plates(D_XX_alpha, plates_alpha[:-2]),
        0,
        0,
        (sum_plates(D_logalpha, plates_X[:-1]) * broadcasting_multiplier((N, D), plates_alpha[-2:])),
        0,
    )

    if self.update_alpha:

        D_alpha = -psi(alpha / b)
        D_b0_alpha = b0 * D_alpha
        D_a0_logalpha = a0 * D_logalpha

        # Compute dH(alpha)
        # This cancels out with the log(alpha) term in log(p(alpha))
        dQ_logHalpha = 0

        # Compute d<log p(alpha)>
        dQ_logpalpha = random.gamma_logpdf(
            sum_plates(D_b0_alpha, plates_alpha[:-2]), 0, sum_plates(D_a0_logalpha, plates_alpha[:-2]), 0, 0
        )
    else:
        dQ_logHalpha = 0
        dQ_logpalpha = 0

    if terms:
        # Per-node gradients are not supported
        raise NotImplementedError()
    else:
        dQ_bound = 0 * dQ_logpX + dQ_logpX + dQ_logpalpha + dQ_logHX + dQ_logHalpha

    return (bound, dR_bound, dQ_bound)
def setup(self, plate_axis=None):
    """
    This method should be called just before optimization.

    For efficiency, sum over axes that are not in mu, alpha nor rotation.

    The method reads the moments of ``self.node_X`` and
    ``self.node_parent`` (and, if ``self.update_alpha`` is set, the
    parameters of ``self.node_alpha``), moves the rotated axis to be the
    last one and stores the resulting quantities as attributes of
    ``self``.

    Parameters
    ----------
    plate_axis : int, optional
        Plate axis of ``node_X`` to be rotated with Q.  Non-negative
        values are converted to negative indices relative to the number of
        plates of ``node_X``.  If None, no plate axis is prepared.

    Raises
    ------
    ValueError
        If ``plate_axis`` is not an integer or is out of bounds.
    NotImplementedError
        If ``self.subset`` is given together with ``self.precompute``.

    NOTE(review): the original text said "If using Q, set rotate_plates
    to True", but this method has no ``rotate_plates`` parameter; the
    plate rotation appears to be enabled through ``plate_axis`` - confirm.
    """

    # Store the original plate_axis parameter for later use in other methods
    self.plate_axis = plate_axis

    # Manipulate the plate_axis parameter to suit the needs of this method
    if plate_axis is not None:
        if not isinstance(plate_axis, int):
            raise ValueError("Plate axis must be integer")
        if plate_axis >= 0:
            plate_axis -= len(self.node_X.plates)
        if plate_axis < -len(self.node_X.plates) or plate_axis >= 0:
            raise ValueError("Axis out of bounds")
        plate_axis -= self.ndim - 1  # Why -1? Because one axis is preserved!

    # Get the mean parameter. It will not be rotated. This assumes that mu
    # and alpha are really independent.
    (alpha_mu, alpha_mu2, alpha, _) = self.node_parent.get_moments()
    (X, XX) = self.node_X.get_moments()

    # Divide the alpha-weighted moments by alpha
    mu = alpha_mu / alpha
    mu2 = alpha_mu2 / alpha

    # For simplicity, force mu to have the same shape as X
    mu = mu * np.ones(self.node_X.dims[0])
    mu2 = mu2 * np.ones(self.node_X.dims[0])

    ## (mu, mumu) = gaussian.reshape_gaussian_array(self.node_mu.dims[0],
    ##                                              self.node_X.dims[0],
    ##                                              mu,
    ##                                              mumu)

    # Take diagonal of covariances to variances for axes that are not in R
    # (and move those axes to be the last)
    XX = covariance_to_variance(XX, ndim=self.ndim, covariance_axis=self.axis)
    ## mumu = covariance_to_variance(mumu,
    ##                               ndim=self.ndim,
    ##                               covariance_axis=self.axis)

    # Move axes of X and mu and compute their outer product
    X = misc.moveaxis(X, self.axis, -1)
    mu = misc.moveaxis(mu, self.axis, -1)
    mu2 = misc.moveaxis(mu2, self.axis, -1)
    Xmu = linalg.outer(X, mu, ndim=1)
    D = np.shape(X)[-1]

    # Move axes of alpha related variables
    def safe_move_axis(x):
        # Move self.axis to the end if x has that axis; otherwise append a
        # new trailing axis.
        if np.ndim(x) >= -self.axis:
            return misc.moveaxis(x, self.axis, -1)
        else:
            return x[..., np.newaxis]

    if self.update_alpha:
        a = safe_move_axis(self.node_alpha.phi[1])
        a0 = safe_move_axis(self.node_alpha.parents[0].get_moments()[0])
        b0 = safe_move_axis(self.node_alpha.parents[1].get_moments()[0])
        plates_alpha = list(self.node_alpha.plates)
    else:
        alpha = safe_move_axis(self.node_parent.get_moments()[2])
        plates_alpha = list(self.node_parent.get_shape(2))

    # Move plates of alpha for R
    if len(plates_alpha) >= -self.axis:
        plate = plates_alpha.pop(self.axis)
        plates_alpha.append(plate)
    else:
        plates_alpha.append(1)

    plates_X = list(self.node_X.get_shape(0))
    plates_X.pop(self.axis)

    def sum_to_alpha(V, ndim=2):
        # TODO/FIXME: This could be improved so that it is not required to
        # explicitly repeat to alpha plates. Multiplying by ones was just a
        # simple bug fix.
        return sum_to_plates(
            V * np.ones(plates_alpha[:-1] + ndim * [1]), plates_alpha[:-1], ndim=ndim, plates_from=plates_X
        )

    if plate_axis is not None:

        # Move plate axis just before the rotated dimensions (which are
        # last)
        def safe_move_plate_axis(x, ndim):
            # Move the plate axis if x has it; otherwise insert a new axis
            # just before the last ``ndim`` axes.
            if np.ndim(x) - ndim >= -plate_axis:
                return misc.moveaxis(x, plate_axis - ndim, -ndim - 1)
            else:
                inds = (Ellipsis, None) + ndim * (slice(None),)
                return x[inds]

        X = safe_move_plate_axis(X, 1)
        mu = safe_move_plate_axis(mu, 1)
        XX = safe_move_plate_axis(XX, 2)
        mu2 = safe_move_plate_axis(mu2, 1)
        if self.update_alpha:
            a = safe_move_plate_axis(a, 1)
            a0 = safe_move_plate_axis(a0, 1)
            b0 = safe_move_plate_axis(b0, 1)
        else:
            alpha = safe_move_plate_axis(alpha, 1)
        # Move plates of X and alpha
        plate = plates_X.pop(plate_axis)
        plates_X.append(plate)
        if len(plates_alpha) >= -plate_axis + 1:
            plate = plates_alpha.pop(plate_axis - 1)
        else:
            plate = 1
        # Insert the moved plate just before the last entry
        plates_alpha = plates_alpha[:-1] + [plate] + plates_alpha[-1:]

        CovX = XX - linalg.outer(X, X)
        self.CovX = sum_to_plates(CovX, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])
        # Broadcast mumu to ensure shape
        # mumu = np.ones(np.shape(XX)[-3:]) * mumu
        mu2 = mu2 * np.ones(np.shape(X)[-2:])
        self.mu2 = sum_to_alpha(mu2, ndim=1)

        if self.precompute:
            # Precompute some stuff for the gradient of plate rotation
            #
            # NOTE: These terms may require a lot of memory if alpha has the
            # same or almost the same plates as X.
            self.X_X = sum_to_plates(
                X[..., :, :, None, None] * X[..., None, None, :, :],
                plates_alpha[:-2],
                ndim=4,
                plates_from=plates_X[:-1],
            )
            self.X_mu = sum_to_plates(
                X[..., :, :, None, None] * mu[..., None, None, :, :],
                plates_alpha[:-2],
                ndim=4,
                plates_from=plates_X[:-1],
            )
        else:
            self.X = X
            self.mu = mu

    else:
        # Sum axes that are not in the plates of alpha
        self.XX = sum_to_alpha(XX)
        self.mu2 = sum_to_alpha(mu2, ndim=1)
        self.Xmu = sum_to_alpha(Xmu)

    if self.update_alpha:
        self.a = a
        self.a0 = a0
        self.b0 = b0
    else:
        self.alpha = alpha

    self.plates_X = plates_X
    self.plates_alpha = plates_alpha

    # Take only a subset of the matrix for rotation
    if self.subset is not None:
        if self.precompute:
            raise NotImplementedError("Precomputation not implemented when " "using a subset")
        # from X
        # NOTE(review): within this method self.X is only assigned when
        # plate_axis is not None; confirm that it is defined elsewhere
        # when a subset is used with plate_axis=None.
        self.X = self.X[..., self.subset]
        self.mu2 = self.mu2[..., self.subset]
        if plate_axis is not None:
            # from CovX
            # Keep all leading axes and pick the subset on the last two
            inds = []
            for i in range(np.ndim(self.CovX) - 2):
                inds.append(range(np.shape(self.CovX)[i]))
            inds.append(self.subset)
            inds.append(self.subset)
            indices = np.ix_(*inds)
            self.CovX = self.CovX[indices]
            # from mu
            self.mu = self.mu[..., self.subset]
        else:
            # from XX
            # Keep all leading axes and pick the subset on the last two
            inds = []
            for i in range(np.ndim(self.XX) - 2):
                inds.append(range(np.shape(self.XX)[i]))
            inds.append(self.subset)
            inds.append(self.subset)
            indices = np.ix_(*inds)
            self.XX = self.XX[indices]
            # from Xmu
            self.Xmu = self.Xmu[..., self.subset]
        # from alpha
        # (arrays whose last axis has length 1 are left untouched)
        if self.update_alpha:
            if np.shape(self.a)[-1] > 1:
                self.a = self.a[..., self.subset]
            if np.shape(self.a0)[-1] > 1:
                self.a0 = self.a0[..., self.subset]
            if np.shape(self.b0)[-1] > 1:
                self.b0 = self.b0[..., self.subset]
        else:
            if np.shape(self.alpha)[-1] > 1:
                self.alpha = self.alpha[..., self.subset]
        self.plates_alpha[-1] = min(self.plates_alpha[-1], len(self.subset))