def _unary_complexity_penalty(self):
    """Complexity penalty contributed by the GPs behind the unary potentials.

    Builds the variational covariance from its Kronecker-structured
    lower-triangular factor, then combines four terms: the log-determinant
    of the prior covariance at the inducing inputs, the log-determinant
    computed from the variational factor, the flat inner product of the
    variational covariance with the inverse prior covariance, and the
    quadratic form of the variational means in the inverse prior covariance.

    NOTE(review): with these signs the result looks like the *negated*
    KL-divergence between the variational distribution and the prior,
    without the constant dimension term -- confirm against the derivation.

    Returns:
        A scalar `tf.Tensor`: the combined terms summed over all of their
        entries with `tf.reduce_sum` and divided by two.
    """
    # TODO: test this
    means = self.mus
    chol = _kron_tril(self.sigma_ls)
    var_cov = ops.tt_tt_matmul(chol, ops.transpose(chol))
    var_logdet = _kron_logdet(chol)

    prior_cov = self._K_mms()
    prior_inv = kron.inv(prior_cov)
    # slog_determinant's result is indexed at 1 for the log-determinant.
    prior_logdet = kron.slog_determinant(prior_cov)[1]

    trace_term = ops.tt_tt_flat_inner(var_cov, prior_inv)
    quad_term = ops.tt_tt_flat_inner(
        means, ops.tt_tt_matmul(prior_inv, means))

    combined = -prior_logdet + var_logdet - trace_term - quad_term
    return tf.reduce_sum(combined) / 2
def _predict_process_values(self, x, with_variance=False, test=False):
    """Predicts the GP values (and optionally their variances) at `x`.

    The inputs are projected with `self.cov.project`, turned into
    interpolation weights by `self.inputs.interpolate_on_batch`, and the
    predictive mean is the pairwise flat inner product of those weights
    with the variational means `self.mus`.

    Args:
        x: batch of inputs; passed straight to `self.cov.project`.
        with_variance: if True, the predictive variances are computed and
            returned together with the mean.
        test: forwarded to `self.cov.project` as its `test` keyword.

    Returns:
        `mean` alone when `with_variance` is False, otherwise the tuple
        `(mean, variances)`. Callers must handle both return shapes.
    """
    w = self.inputs.interpolate_on_batch(self.cov.project(x, test=test))
    mean = batch_ops.pairwise_flat_inner(w, self.mus)
    if not with_variance:
        return mean

    K_mms = self._K_mms()
    sigma_ls = _kron_tril(self.sigma_ls)
    # Variational covariance rebuilt from its lower-triangular factor.
    sigmas = ops.tt_tt_matmul(sigma_ls, ops.transpose(sigma_ls))

    # variance = w^T Sigma w - w^T K_mm w + cov_0; cov_0 gets a new leading
    # axis so that it broadcasts against the quadratic forms
    # (presumably the batch axis -- confirm against callers).
    variances = pairwise_quadratic_form(sigmas, w, w)
    variances -= pairwise_quadratic_form(K_mms, w, w)
    variances += self.cov.cov_0()[None, :]
    return mean, variances
def complexity_penalty(self):
    """Returns the complexity penalty term for ELBO.

    The value is half of: minus the log-determinant of the prior
    covariance at the inducing inputs, plus the log-determinant computed
    from the variational factor, minus the flat inner product of the
    variational covariance with the inverse prior covariance, minus the
    quadratic form of the variational means in the inverse prior
    covariance. No reduction is applied to the result here.
    """
    means = self.mus
    tril_factor = _kron_tril(self.sigma_ls)
    var_cov = ops.tt_tt_matmul(tril_factor, ops.transpose(tril_factor))
    var_logdet = _kron_logdet(tril_factor)

    prior = self._K_mms()
    prior_inv = kron.inv(prior)
    # slog_determinant's result is indexed at 1 for the log-determinant.
    prior_logdet = kron.slog_determinant(prior)[1]

    cov_inner = ops.tt_tt_flat_inner(var_cov, prior_inv)
    mean_quad = ops.tt_tt_flat_inner(
        means, ops.tt_tt_matmul(prior_inv, means))

    return (-prior_logdet + var_logdet - cov_inner - mean_quad) / 2
def complexity_penalty(self):
    """Returns the complexity penalty term for ELBO of different GP models.

    Uses the variational mean `self.mu`, the Kronecker lower-triangular
    factor `self.sigma_l` and the prior covariance `self.K_mm()`. The
    result is half the sum of the signed terms listed in the body.
    """
    mean = self.mu
    factor = _kron_tril(self.sigma_l)
    var_cov = ops.tt_tt_matmul(factor, ops.transpose(factor))
    var_logdet = _kron_logdet(factor)

    prior = self.K_mm()
    prior_inv = kron.inv(prior)
    prior_logdet = kron.slog_determinant(prior)[1]

    signed_terms = (
        -prior_logdet,
        var_logdet,
        -ops.tt_tt_flat_inner(var_cov, prior_inv),
        -ops.tt_tt_flat_inner(mean, ops.tt_tt_matmul(prior_inv, mean)),
    )

    # Accumulate from 0 in the listed order.
    total = 0
    for term in signed_terms:
        total += term
    return total / 2
def elbo(self, w, y):
    """Negated evidence lower bound for one mini-batch.

    Args:
        w: interpolation vector for the current batch.
        y: target values for the current batch.

    Returns:
        Element 0 of the assembled bound, negated (so it can be minimised).
    """
    batch_size = tf.cast(tf.shape(y)[0], tf.float64)
    n_total = tf.cast(self.N, dtype=tf.float64)
    targets = tf.reshape(y, [-1])

    var_mean = self.gp.mu
    factor = _kron_tril(self.gp.sigma_l)
    var_cov = ops.tt_tt_matmul(factor, ops.transpose(factor))
    # NOTE(review): this value is squared below although the accessor is
    # called `noise_variance` -- confirm what it actually returns.
    noise = self.gp.cov.noise_variance()
    prior_cov = self.gp.K_mm()

    # batch_size * cov_0 minus the summed quadratic forms w^T K_mm w.
    prior_quad = tf.reduce_sum(
        ops.tt_tt_flat_inner(w, ops.tt_tt_matmul(prior_cov, w)))
    tilde_k = batch_size * self.gp.cov.cov_0() - prior_quad

    residual = targets - ops.tt_tt_flat_inner(w, var_mean)
    sq_error = tf.reduce_sum(tf.square(residual))
    # The quadratic form in the variational covariance is summed over the
    # batch as well, like the other data terms.
    var_quad = tf.reduce_sum(
        ops.tt_tt_flat_inner(w, ops.tt_tt_matmul(var_cov, w)))

    bound = -sq_error - tilde_k - var_quad
    bound = bound / (2 * noise**2 * batch_size)
    bound = bound + self.gp.complexity_penalty() / n_total
    # TODO: a `- tf.log(tf.abs(noise))` term was disabled in the original
    # code with an open question; it is still not part of the bound.
    return -bound[0]