def _unary_complexity_penalty(self):
  """Computes the ELBO complexity term for the GPs of the unary potentials.

  For each GP, let q(u) = N(mu, Sigma), with Sigma = L L^T, be the
  variational distribution over the GP values at the inducing inputs. Let
  p(u) = N(0, K_mm) be the prior. The per-GP term computed here is

      0.5 * (log|Sigma| - log|K_mm| - tr(Sigma K_mm^-1) - mu^T K_mm^-1 mu)
        = -KL(q || p) - M / 2,

  where M is the number of inducing inputs.

  Note the sign: despite the word "penalty", the value is the NEGATED
  KL divergence. That is the sign of an ELBO contribution. The constant
  -M/2 is dropped, which does not affect gradients.
  NOTE(review): confirm that callers add this value to the ELBO rather
  than subtract it.

  The identities above assume the following (TODO confirm against the
  helpers):
    * `_kron_logdet(L)` returns log|L L^T|, not log|det L|.
    * The second element of `kron.slog_determinant` is log|det K_mm|.
    * Each entry of `self.mus` is a column TT-matrix, so that
      `tt_tt_flat_inner(mu, K^-1 mu)` is the quadratic form mu^T K^-1 mu.

  TODO: this function has no dedicated unit test yet.

  Returns:
    A scalar `tf.Tensor`: the per-GP terms summed with `tf.reduce_sum`
    (presumably one term per unary GP held in `self.mus` -- confirm).
  """
  mus = self.mus
  sigma_ls = _kron_tril(self.sigma_ls)
  # Sigma = L L^T is symmetric by construction.
  sigmas = ops.tt_tt_matmul(sigma_ls, ops.transpose(sigma_ls))
  sigmas_logdet = _kron_logdet(sigma_ls)

  K_mms = self._K_mms()
  K_mms_inv = kron.inv(K_mms)
  K_mms_logdet = kron.slog_determinant(K_mms)[1]

  # flat_inner(A, B) = sum_ij A_ij B_ij = tr(A^T B). Because Sigma is
  # symmetric, this equals tr(Sigma K_mm^-1).
  trace_term = ops.tt_tt_flat_inner(sigmas, K_mms_inv)
  quad_term = ops.tt_tt_flat_inner(mus, ops.tt_tt_matmul(K_mms_inv, mus))

  # Same left-to-right summation order as the original accumulator.
  penalty = -K_mms_logdet + sigmas_logdet - trace_term - quad_term
  return tf.reduce_sum(penalty) / 2
def testInv(self):
  """kr.inv of a rank-1 TT-matrix must match NumPy's dense inverse."""
  # Identical row/column mode tuples give a square matrix, as inversion needs.
  modes = ((2, 3, 2), (2, 3, 2))
  tt_init = initializers.random_matrix(modes, tt_rank=1)
  matrix = variables.get_variable('kron_mat', initializer=tt_init)
  init_op = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init_op)
    expected = np.linalg.inv(sess.run(ops.full(matrix)))
    computed = sess.run(ops.full(kr.inv(matrix)))
    self.assertAllClose(expected, computed)
def testInv(self):
  """kr.inv of a rank-1 TT-matrix must match NumPy's dense inverse."""
  tt_init = initializers.random_matrix(
      ((2, 3, 2), (2, 3, 2)), tt_rank=1, dtype=self.dtype)
  matrix = variables.get_variable('kron_mat', initializer=tt_init)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  # Reference: invert the densified matrix with NumPy.
  expected = np.linalg.inv(self.evaluate(ops.full(matrix)))
  computed = self.evaluate(ops.full(kr.inv(matrix)))
  self.assertAllClose(expected, computed)
def testInv(self):
  """kr.inv on a batch of 3 rank-1 TT-matrices must match NumPy's inverse."""
  batch_init = initializers.random_matrix_batch(
      ((2, 3, 2), (2, 3, 2)), tt_rank=1, batch_size=3, dtype=self.dtype)
  matrices = variables.get_variable('kron_mat_batch', initializer=batch_init)
  init_op = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init_op)
    # ops.full presumably yields a stacked (batch, N, N) array.
    # np.linalg.inv inverts each matrix in such a stack.
    expected = np.linalg.inv(sess.run(ops.full(matrices)))
    computed = sess.run(ops.full(kr.inv(matrices)))
    # The batched comparison uses a looser absolute tolerance.
    self.assertAllClose(expected, computed, atol=1e-4)
def complexity_penalty(self):
  """Returns the complexity penalty term for ELBO.

  With variational q(u) = N(mu, Sigma), Sigma = L L^T, and prior
  p(u) = N(0, K_mm), the value returned is

      0.5 * (log|Sigma| - log|K_mm| - tr(Sigma K_mm^-1) - mu^T K_mm^-1 mu),

  i.e. -KL(q || p) - M / 2 for M inducing inputs. This holds assuming
  `_kron_logdet(L)` returns log|L L^T| and `mus` is a column TT-matrix
  (NOTE(review): confirm both). The sign is that of an ELBO contribution.
  Unlike a batched variant, the result is not reduced with `tf.reduce_sum`.
  """
  mean = self.mus
  chol = _kron_tril(self.sigma_ls)
  cov = ops.tt_tt_matmul(chol, ops.transpose(chol))
  cov_logdet = _kron_logdet(chol)

  prior_cov = self._K_mms()
  prior_prec = kron.inv(prior_cov)
  prior_logdet = kron.slog_determinant(prior_cov)[1]

  trace_term = ops.tt_tt_flat_inner(cov, prior_prec)
  quad_term = ops.tt_tt_flat_inner(mean, ops.tt_tt_matmul(prior_prec, mean))

  return (-prior_logdet + cov_logdet - trace_term - quad_term) / 2
def complexity_penalty(self):
  """Returns the complexity penalty term for ELBO of different GP models.

  With variational q(u) = N(mu, Sigma), Sigma = L L^T, and prior
  p(u) = N(0, K_mm), the value returned is

      0.5 * (log|Sigma| - log|K_mm| - tr(Sigma K_mm^-1) - mu^T K_mm^-1 mu),

  i.e. -KL(q || p) - M / 2 for M inducing inputs. This holds assuming
  `_kron_logdet(L)` returns log|L L^T| and `mu` is a column TT-matrix
  (NOTE(review): confirm both). The sign is that of an ELBO contribution.
  """
  mean = self.mu
  chol = _kron_tril(self.sigma_l)
  cov = ops.tt_tt_matmul(chol, ops.transpose(chol))
  cov_logdet = _kron_logdet(chol)

  prior_cov = self.K_mm()
  prior_prec = kron.inv(prior_cov)
  prior_logdet = kron.slog_determinant(prior_cov)[1]

  trace_term = ops.tt_tt_flat_inner(cov, prior_prec)
  quad_term = ops.tt_tt_flat_inner(mean, ops.tt_tt_matmul(prior_prec, mean))

  return (-prior_logdet + cov_logdet - trace_term - quad_term) / 2