Example #1
    def _log_prob(self, given):
        mean, u_tril, v_tril = (self.path_param(self.mean),
                                self.path_param(self.u_tril),
                                self.path_param(self.v_tril))

        log_det_u = 2 * tf.reduce_sum(tf.log(tf.matrix_diag_part(u_tril)),
                                      axis=-1)
        log_det_v = 2 * tf.reduce_sum(tf.log(tf.matrix_diag_part(v_tril)),
                                      axis=-1)
        n_row = tf.cast(self._n_row, self.dtype)
        n_col = tf.cast(self._n_col, self.dtype)
        logZ = - (n_row * n_col) / 2. * \
            tf.log(2. * tf.constant(np.pi, dtype=self.dtype)) - \
            n_row / 2. * log_det_v - n_col / 2. * log_det_u
        # logZ.shape == batch_shape
        if self._check_numerics:
            logZ = tf.check_numerics(logZ, "log[det(Cov)]")

        y = given - mean
        y_with_last_dim_changed = tf.expand_dims(tf.ones(tf.shape(y)[:-1]), -1)
        Lu, _ = maybe_explicit_broadcast(u_tril, y_with_last_dim_changed,
                                         'MatrixVariateNormalCholesky.u_tril',
                                         'expand_dims(given, -1)')
        y_with_sec_last_dim_changed = tf.expand_dims(
            tf.ones(tf.concat(
                [tf.shape(y)[:-2], tf.shape(y)[-1:]], axis=0)), -1)
        Lv, _ = maybe_explicit_broadcast(v_tril, y_with_sec_last_dim_changed,
                                         'MatrixVariateNormalCholesky.v_tril',
                                         'expand_dims(given, -1)')
        # tr[V^{-1}(X-M)' U^{-1}(X-M)] = |Lv^{-1}(Lu^{-1} Y)'|_F^2,
        # computed with two triangular solves against the Cholesky factors
        x_Lb_inv_t = tf.matrix_triangular_solve(Lu, y, lower=True)
        x_t = tf.matrix_triangular_solve(Lv,
                                         tf.matrix_transpose(x_Lb_inv_t),
                                         lower=True)
        stoc_dist = -0.5 * tf.reduce_sum(tf.square(x_t), [-1, -2])
        return logZ + stoc_dist
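The two triangular solves evaluate the quadratic form tr[V^{-1}(X-M)' U^{-1}(X-M)] without ever inverting a covariance matrix. Below is a minimal NumPy/SciPy sketch of the same density, not part of the original snippet; shapes and values are illustrative, and scipy.stats.matrix_normal serves as the reference implementation.

    import numpy as np
    from scipy.linalg import solve_triangular
    from scipy.stats import matrix_normal

    rng = np.random.default_rng(0)
    n_row, n_col = 3, 2
    A = rng.standard_normal((n_row, n_row))
    U = A @ A.T + n_row * np.eye(n_row)        # row covariance, U = Lu Lu'
    B = rng.standard_normal((n_col, n_col))
    V = B @ B.T + n_col * np.eye(n_col)        # column covariance, V = Lv Lv'
    Lu, Lv = np.linalg.cholesky(U), np.linalg.cholesky(V)
    M = rng.standard_normal((n_row, n_col))
    X = rng.standard_normal((n_row, n_col))

    Y = X - M
    T = solve_triangular(Lu, Y, lower=True)    # Lu^{-1} Y
    T = solve_triangular(Lv, T.T, lower=True)  # Lv^{-1} (Lu^{-1} Y)'
    log_z = (-n_row * n_col / 2 * np.log(2 * np.pi)
             - n_row / 2 * np.linalg.slogdet(V)[1]
             - n_col / 2 * np.linalg.slogdet(U)[1])
    log_p = log_z - 0.5 * np.sum(T ** 2)
    assert np.isclose(log_p, matrix_normal.logpdf(X, M, U, V))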
Example #2
    def _log_prob(self, given):
        given = tf.cast(given, self.param_dtype)
        given, logits = maybe_explicit_broadcast(
            given, self.logits, 'given', 'logits')
        if self.normalize_logits:
            logits = logits - tf.reduce_logsumexp(
                logits, axis=-1, keepdims=True)
        # sum_k n_k * log p_k, without the multinomial coefficient
        log_p = tf.reduce_sum(given * logits, -1)
        return log_p
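When normalize_logits is set, subtracting the log-sum-exp over the last axis turns raw logits into normalized log-probabilities, so the returned value is sum_k n_k log p_k: the multinomial log-likelihood without its combinatorial coefficient. A minimal NumPy sketch of that normalization step, with illustrative values (not part of the original snippet):

    import numpy as np
    from scipy.special import logsumexp

    logits = np.array([[2.0, 0.5, -1.0]])
    log_probs = logits - logsumexp(logits, axis=-1, keepdims=True)  # log-softmax
    assert np.allclose(np.exp(log_probs).sum(axis=-1), 1.0)
    counts = np.array([[3., 1., 0.]])
    log_p = np.sum(counts * log_probs, axis=-1)  # sum_k n_k * log p_k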
Example #3
    def _log_prob(self, given):
        given = tf.cast(given, self.param_dtype)
        given, logits = maybe_explicit_broadcast(given, self.logits, 'given',
                                                 'logits')
        normalized_logits = logits - tf.reduce_logsumexp(
            logits, axis=-1, keepdims=True)
        n = tf.cast(self.n_experiments, self.param_dtype)
        # log n!/(x_1! ... x_K!) + sum_k x_k * log p_k
        log_p = log_combination(n, given) + \
            tf.reduce_sum(given * normalized_logits, -1)
        return log_p
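Here log_combination(n, given) supplies the multinomial coefficient, which can be written as log n! - sum_k log x_k! via gammaln. A small SciPy check of the full log-pmf, not part of the original snippet (illustrative values; scipy.special.log_softmax needs SciPy >= 1.5, and the gammaln identity below is an assumed equivalent of log_combination):

    import numpy as np
    from scipy.special import gammaln, log_softmax
    from scipy.stats import multinomial

    logits = np.array([0.2, -1.0, 0.5])
    counts = np.array([2., 0., 1.])
    n = counts.sum()
    log_comb = gammaln(n + 1) - gammaln(counts + 1).sum()  # log n!/(x_1!...x_K!)
    log_p = log_comb + np.sum(counts * log_softmax(logits))
    p = np.exp(log_softmax(logits))
    assert np.isclose(log_p, multinomial.logpmf(counts, n=int(n), p=p))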
Example #4
    def _log_prob(self, given):
        given, alpha = maybe_explicit_broadcast(given, self.alpha, 'given',
                                                'alpha')
        lbeta_alpha = tf.lbeta(alpha)
        # workaround: tf.lbeta does not provide static shape inference
        if alpha.get_shape():
            lbeta_alpha.set_shape(alpha.get_shape()[:-1])
        log_given = tf.log(given)
        if self._check_numerics:
            lbeta_alpha = tf.check_numerics(lbeta_alpha, "lbeta(alpha)")
            log_given = tf.check_numerics(log_given, "log(given)")
        # log p(x) = -log B(alpha) + sum_k (alpha_k - 1) * log x_k
        log_p = -lbeta_alpha + tf.reduce_sum((alpha - 1) * log_given, -1)
        return log_p
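The return value is the Dirichlet log-density log p(x) = -log B(alpha) + sum_k (alpha_k - 1) log x_k. A minimal NumPy/SciPy check against scipy.stats.dirichlet, with illustrative values (not part of the original snippet):

    import numpy as np
    from scipy.special import gammaln
    from scipy.stats import dirichlet

    alpha = np.array([2.0, 3.0, 4.0])
    x = np.array([0.2, 0.3, 0.5])                        # a point on the simplex
    lbeta = gammaln(alpha).sum() - gammaln(alpha.sum())  # log B(alpha), cf. tf.lbeta
    log_p = -lbeta + np.sum((alpha - 1) * np.log(x))
    assert np.isclose(log_p, dirichlet.logpdf(x, alpha))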
Example #5
    def _log_prob(self, given):
        given = tf.cast(given, self.param_dtype)
        given, logits = maybe_explicit_broadcast(
            given, self.logits, 'given', 'logits')
        # flatten any extra leading dimensions into one batch dimension
        # before calling softmax_cross_entropy_with_logits
        if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2):
            given_flat = given
            logits_flat = logits
        else:
            given_flat = tf.reshape(given, [-1, self.n_categories])
            logits_flat = tf.reshape(logits, [-1, self.n_categories])
        log_p_flat = -tf.nn.softmax_cross_entropy_with_logits(
            labels=given_flat, logits=logits_flat)
        if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2):
            log_p = log_p_flat
        else:
            # restore the original batch shape and recover static shape info
            log_p = tf.reshape(log_p_flat, tf.shape(logits)[:-1])
            if given.get_shape() and logits.get_shape():
                log_p.set_shape(tf.broadcast_static_shape(
                    given.get_shape(), logits.get_shape())[:-1])
        return log_p
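Because given is one-hot, -softmax_cross_entropy_with_logits reduces to the log-softmax entry of the observed category, which is exactly the categorical log-probability. A small NumPy sketch of that identity (illustrative values, not part of the original snippet):

    import numpy as np
    from scipy.special import log_softmax

    logits = np.array([1.0, 2.0, 0.3])
    one_hot = np.array([0., 1., 0.])
    # -sum_k y_k * log softmax(logits)_k collapses to a single term
    log_p = np.sum(one_hot * log_softmax(logits))
    assert np.isclose(log_p, log_softmax(logits)[1])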
Example #6
    def _log_prob(self, given):
        mean, cov_tril = (self.path_param(self.mean),
                          self.path_param(self.cov_tril))
        # log|Cov| from the Cholesky factor: 2 * sum(log(diag(L)))
        log_det = 2 * tf.reduce_sum(
            tf.log(tf.matrix_diag_part(cov_tril)), axis=-1)
        n_dim = tf.cast(self._n_dim, self.dtype)
        log_z = - n_dim / 2 * tf.log(
            2 * tf.constant(np.pi, dtype=self.dtype)) - log_det / 2
        # log_z.shape == batch_shape
        if self._check_numerics:
            log_z = tf.check_numerics(log_z, "log[det(Cov)]")
        # (given-mean)' Sigma^{-1} (given-mean) =
        # (g-m)' L^{-T} L^{-1} (g-m) = |x|^2, where Lx = g-m =: y.
        y = tf.expand_dims(given - mean, -1)
        L, _ = maybe_explicit_broadcast(
            cov_tril, y, 'MultivariateNormalCholesky.cov_tril',
            'expand_dims(given, -1)')
        x = tf.matrix_triangular_solve(L, y, lower=True)
        x = tf.squeeze(x, -1)
        stoc_dist = -0.5 * tf.reduce_sum(tf.square(x), axis=-1)
        return log_z + stoc_dist
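A minimal NumPy/SciPy check of the same computation, not part of the original snippet: the log-determinant comes from the Cholesky diagonal, the quadratic form from a single triangular solve, and scipy.stats.multivariate_normal serves as the reference.

    import numpy as np
    from scipy.linalg import solve_triangular
    from scipy.stats import multivariate_normal

    rng = np.random.default_rng(0)
    d = 3
    A = rng.standard_normal((d, d))
    cov = A @ A.T + d * np.eye(d)
    L = np.linalg.cholesky(cov)
    mean, x = rng.standard_normal(d), rng.standard_normal(d)

    log_det = 2 * np.sum(np.log(np.diag(L)))       # log|Cov|
    z = solve_triangular(L, x - mean, lower=True)  # L z = x - mean
    log_p = -d / 2 * np.log(2 * np.pi) - log_det / 2 - 0.5 * z @ z
    assert np.isclose(log_p, multivariate_normal.logpdf(x, mean, cov))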
Example #7
    def _sample(self, n_samples):
        alpha, beta = maybe_explicit_broadcast(self.alpha, self.beta, 'alpha',
                                               'beta')
        # X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1) give
        # X / (X + Y) ~ Beta(alpha, beta)
        x = tf.random_gamma([n_samples], alpha, beta=1, dtype=self.dtype)
        y = tf.random_gamma([n_samples], beta, beta=1, dtype=self.dtype)
        return x / (x + y)
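This is the standard Gamma-ratio construction: if X ~ Gamma(a, 1) and Y ~ Gamma(b, 1) are independent, then X / (X + Y) ~ Beta(a, b). A quick NumPy sanity check on the sample mean, not part of the original snippet (values illustrative, tolerance deliberately loose):

    import numpy as np

    rng = np.random.default_rng(0)
    a, b, n = 2.0, 5.0, 200_000
    x = rng.gamma(a, size=n)     # X ~ Gamma(a, 1)
    y = rng.gamma(b, size=n)     # Y ~ Gamma(b, 1)
    samples = x / (x + y)        # ~ Beta(a, b)
    # Beta(a, b) has mean a / (a + b)
    assert abs(samples.mean() - a / (a + b)) < 1e-2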
Example #8
    def _log_prob(self, given):
        given = tf.cast(given, self.param_dtype)
        given, logits = maybe_explicit_broadcast(given, self.logits, 'given',
                                                 'logits')
        # the Bernoulli log-prob is exactly minus the sigmoid cross entropy
        return -tf.nn.sigmoid_cross_entropy_with_logits(labels=given,
                                                        logits=logits)
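Minus the sigmoid cross entropy equals the Bernoulli log-probability y log p + (1 - y) log(1 - p) with p = sigmoid(logits). A small NumPy sketch of that identity, not part of the original snippet, using the numerically stable form documented for tf.nn.sigmoid_cross_entropy_with_logits:

    import numpy as np

    logit, y = 0.7, 1.0
    p = 1.0 / (1.0 + np.exp(-logit))                 # sigmoid(logit)
    log_p = y * np.log(p) + (1 - y) * np.log(1 - p)  # Bernoulli log-prob
    # stable cross entropy: max(x, 0) - x*y + log(1 + exp(-|x|))
    xent = max(logit, 0) - logit * y + np.log1p(np.exp(-abs(logit)))
    assert np.isclose(log_p, -xent)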