Example #1
0
    def _unary_complexity_penalty(self):
        """Computes the complexity penalty for the GPs of the unary potentials.

        The penalty is the *negative* KL-divergence between the variational
        distribution q(u) = N(mus, sigmas) and the prior p(u) = N(0, K_mms)
        over the GP values u at the inducing inputs, up to an additive
        constant: the -M/2 term of the KL (M = number of inducing points) is
        not included. Because of the sign it is meant to be *added* to the
        ELBO, not subtracted.

        Concretely, with S = L L^T (L = _kron_tril(self.sigma_ls)) and
        K = self._K_mms(), it evaluates

            (-log|K| + _kron_logdet(L) - tr(S K^-1) - mus^T K^-1 mus) / 2

        and sums the result with `tf.reduce_sum` (presumably over one entry
        per GP when `self.mus` is batched — TODO confirm).

        Returns:
          A scalar `tf.Tensor` containing the complexity penalty for GPs
          determining unary potentials.
        """
        # TODO: test this
        mus = self.mus
        sigma_ls = _kron_tril(self.sigma_ls)
        sigmas = ops.tt_tt_matmul(sigma_ls, ops.transpose(sigma_ls))
        # NOTE(review): assumes _kron_logdet(L) returns log|L L^T| (not log|L|);
        # otherwise this term is off by a factor of 2 — confirm.
        sigmas_logdet = _kron_logdet(sigma_ls)

        K_mms = self._K_mms()
        K_mms_inv = kron.inv(K_mms)
        K_mms_logdet = kron.slog_determinant(K_mms)[1]

        penalty = 0
        penalty += -K_mms_logdet
        penalty += sigmas_logdet
        # Trace term tr(S K^-1), written as a flat (elementwise) inner product.
        penalty += -ops.tt_tt_flat_inner(sigmas, K_mms_inv)
        # Mahalanobis term mus^T K^-1 mus.
        penalty += -ops.tt_tt_flat_inner(mus, ops.tt_tt_matmul(K_mms_inv, mus))
        return tf.reduce_sum(penalty) / 2
Example #2
0
    def complexity_penalty(self):
        """Returns the complexity penalty term for the ELBO.

        With L = _kron_tril(self.sigma_ls), S = L L^T and K = self._K_mms(),
        the returned value is

            (-log|K| + _kron_logdet(L) - tr(S K^-1) - mus^T K^-1 mus) / 2,

        i.e. the negative KL(q || p) between N(mus, S) and N(0, K) up to an
        additive constant (assuming _kron_logdet(L) yields log|S| — verify).
        """
        means = self.mus
        chol = _kron_tril(self.sigma_ls)
        var_cov = ops.tt_tt_matmul(chol, ops.transpose(chol))
        var_cov_logdet = _kron_logdet(chol)

        prior_cov = self._K_mms()
        prior_prec = kron.inv(prior_cov)
        prior_logdet = kron.slog_determinant(prior_cov)[1]

        # tr(S K^-1) expressed as a flat inner product of two TT-matrices.
        trace_term = ops.tt_tt_flat_inner(var_cov, prior_prec)
        # Quadratic form mus^T K^-1 mus.
        quad_term = ops.tt_tt_flat_inner(
            means, ops.tt_tt_matmul(prior_prec, means))

        total = -prior_logdet + var_cov_logdet - trace_term - quad_term
        return total / 2
Example #3
0
  def complexity_penalty(self):
    """Returns the complexity penalty term for the ELBO of the GP model.

    With L = _kron_tril(self.sigma_l), S = L L^T and K = self.K_mm(), the
    returned value is

        (-log|K| + _kron_logdet(L) - tr(S K^-1) - mu^T K^-1 mu) / 2,

    i.e. the negative KL(q || p) between N(mu, S) and N(0, K) up to an
    additive constant (assuming _kron_logdet(L) yields log|S| — verify).
    """
    q_mean = self.mu
    chol = _kron_tril(self.sigma_l)
    q_cov = ops.tt_tt_matmul(chol, ops.transpose(chol))
    q_cov_logdet = _kron_logdet(chol)

    prior_cov = self.K_mm()
    prior_prec = kron.inv(prior_cov)
    prior_logdet = kron.slog_determinant(prior_cov)[1]

    # tr(S K^-1) expressed as a flat inner product of two TT-matrices.
    trace_term = ops.tt_tt_flat_inner(q_cov, prior_prec)
    # Quadratic form mu^T K^-1 mu.
    quad_term = ops.tt_tt_flat_inner(
        q_mean, ops.tt_tt_matmul(prior_prec, q_mean))

    total = -prior_logdet + q_cov_logdet - trace_term - quad_term
    return total / 2
Example #4
0
  def predict_process_value(self, x, with_variance=False):
    """Predicts the value of the process at point x.

    With w the interpolation weights of x on the inducing inputs, the mean
    is w^T mu and the variance is cov_0 + w^T S w - w^T K_mm w, where
    S = L L^T is the variational covariance.

    Args:
      x: data features
      with_variance: if True, also returns process variance at x

    Returns:
      The predicted mean if `with_variance` is False; otherwise a
      `(mean, variance)` tuple.
    """
    mu = self.mu
    w = self.inputs.interpolate_on_batch(self.cov.project(x))

    mean = ops.tt_tt_flat_inner(w, mu)
    if not with_variance:
      return mean
    K_mm = self.K_mm()
    # Build the covariance factor exactly as complexity_penalty does
    # (_kron_tril(self.sigma_l)), so prediction uses the same S = L L^T the
    # model was trained with. Presumably _kron_tril keeps only the lower
    # triangle of each Kronecker factor — verify against its definition.
    sigma_l = _kron_tril(self.sigma_l)
    variance = self.cov.cov_0()
    # w^T S w computed as ||L^T w||^2.
    sigma_l_w = ops.tt_tt_matmul(ops.transpose(sigma_l), w)
    variance += ops.tt_tt_flat_inner(sigma_l_w, sigma_l_w)
    variance -= ops.tt_tt_flat_inner(w, ops.tt_tt_matmul(K_mm, w))
    return mean, variance
Example #5
0
  def elbo(self, w, y):
    '''Negated evidence lower bound for a mini-batch, per data point.

    The accumulated quantity is the batch estimate of ELBO / N (N = self.N,
    the total number of training points); its negation is returned so the
    result can be minimized directly.

    Args:
      w: interpolation vector for the current batch.
      y: target values for the current batch.

    Returns:
      The negated per-data-point ELBO estimate, i.e. element 0 of the
      accumulated tensor (presumably of shape [1] because of cov_0() —
      TODO confirm).
    '''
    l = tf.cast(tf.shape(y)[0], tf.float64)  # batch size
    N = tf.cast(self.N, dtype=tf.float64)

    y = tf.reshape(y, [-1])

    mu = self.gp.mu
    sigma_l = _kron_tril(self.gp.sigma_l)
    sigma_n = self.gp.cov.noise_variance()
    K_mm = self.gp.K_mm()

    # Sum over the batch of K_ii - w_i^T K_mm w_i (interpolation residual).
    tilde_K_ii = l * self.gp.cov.cov_0()
    tilde_K_ii -= tf.reduce_sum(ops.tt_tt_flat_inner(w,
                                         ops.tt_tt_matmul(K_mm, w)))

    # Sum over the batch of w_i^T S w_i with S = L L^T, computed as
    # ||L^T w_i||^2 so the TT-matrix product L L^T is never formed.
    sigma_l_w = ops.tt_tt_matmul(ops.transpose(sigma_l), w)

    elbo = 0
    elbo -= tf.reduce_sum(tf.square(y - ops.tt_tt_flat_inner(w, mu)))
    elbo -= tilde_K_ii
    elbo -= tf.reduce_sum(ops.tt_tt_flat_inner(sigma_l_w, sigma_l_w))
    # NOTE(review): noise_variance() is squared here; if it already returns
    # a variance this divides by sigma^4 rather than sigma^2 — confirm.
    elbo /= 2 * sigma_n**2 * l
    elbo += self.gp.complexity_penalty() / N
    # NOTE(review): the Gaussian log-normalizer term (previously a
    # commented-out `-log|sigma_n|`) is not included. If sigma_n is trained,
    # nothing in this objective penalizes a large noise level — confirm.
    return -elbo[0]