def _log_prob(self, given):
    mean, u_tril, v_tril = (self.path_param(self.mean),
                            self.path_param(self.u_tril),
                            self.path_param(self.v_tril))
    # log|U| and log|V| from the diagonals of their Cholesky factors.
    log_det_u = 2 * tf.reduce_sum(
        tf.log(tf.matrix_diag_part(u_tril)), axis=-1)
    log_det_v = 2 * tf.reduce_sum(
        tf.log(tf.matrix_diag_part(v_tril)), axis=-1)
    n_row = tf.cast(self._n_row, self.dtype)
    n_col = tf.cast(self._n_col, self.dtype)
    # Log-normalizer of the matrix variate normal density.
    logZ = - (n_row * n_col) / 2. * \
        tf.log(2. * tf.constant(np.pi, dtype=self.dtype)) - \
        n_row / 2. * log_det_v - n_col / 2. * log_det_u
    # logZ.shape == batch_shape
    if self._check_numerics:
        logZ = tf.check_numerics(logZ, "log[det(Cov)]")
    y = given - mean
    # Dummy tensors with the right trailing dims, used only to broadcast
    # the Cholesky factors against the batch shape of `given`.
    y_with_last_dim_changed = tf.expand_dims(tf.ones(tf.shape(y)[:-1]), -1)
    Lu, _ = maybe_explicit_broadcast(
        u_tril, y_with_last_dim_changed,
        'MatrixVariateNormalCholesky.u_tril', 'expand_dims(given, -1)')
    y_with_sec_last_dim_changed = tf.expand_dims(
        tf.ones(tf.concat([tf.shape(y)[:-2], tf.shape(y)[-1:]], axis=0)), -1)
    Lv, _ = maybe_explicit_broadcast(
        v_tril, y_with_sec_last_dim_changed,
        'MatrixVariateNormalCholesky.v_tril', 'expand_dims(given, -1)')
    # tr(V^{-1} Y' U^{-1} Y) = |Lv^{-1} (Lu^{-1} Y)'|_F^2 via two
    # triangular solves.
    x_Lb_inv_t = tf.matrix_triangular_solve(Lu, y, lower=True)
    x_t = tf.matrix_triangular_solve(Lv, tf.matrix_transpose(x_Lb_inv_t),
                                     lower=True)
    stoc_dist = -0.5 * tf.reduce_sum(tf.square(x_t), [-1, -2])
    return logZ + stoc_dist
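# A minimal NumPy/SciPy cross-check of the same matrix variate normal density
# (an illustrative sketch, not part of the class above; the helper name
# `matrix_normal_logpdf` and all inputs are made up for the example):
# log p(X) = -nm/2 log(2*pi) - n/2 log|V| - m/2 log|U|
#            - 1/2 tr(V^{-1} (X-M)' U^{-1} (X-M)).
import numpy as np
from scipy.linalg import solve_triangular
from scipy.stats import matrix_normal

def matrix_normal_logpdf(x, mean, u_tril, v_tril):
    n, m = x.shape
    y = x - mean
    z = solve_triangular(u_tril, y, lower=True)    # Lu^{-1} (X - M)
    w = solve_triangular(v_tril, z.T, lower=True)  # Lv^{-1} (Lu^{-1} Y)'
    log_det_u = 2 * np.sum(np.log(np.diag(u_tril)))
    log_det_v = 2 * np.sum(np.log(np.diag(v_tril)))
    log_z = (-(n * m) / 2 * np.log(2 * np.pi)
             - n / 2 * log_det_v - m / 2 * log_det_u)
    return log_z - 0.5 * np.sum(w ** 2)

rng = np.random.default_rng(0)
n, m = 3, 2
a = rng.normal(size=(n, n)); u = a @ a.T + n * np.eye(n)
b = rng.normal(size=(m, m)); v = b @ b.T + m * np.eye(m)
mean, x = rng.normal(size=(n, m)), rng.normal(size=(n, m))
print(matrix_normal_logpdf(x, mean, np.linalg.cholesky(u),
                           np.linalg.cholesky(v)))
print(matrix_normal.logpdf(x, mean=mean, rowcov=u, colcov=v))  # should match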
def _log_prob(self, given):
    given = tf.cast(given, self.param_dtype)
    given, logits = maybe_explicit_broadcast(
        given, self.logits, 'given', 'logits')
    if self.normalize_logits:
        # Turn logits into log-probabilities via log-softmax.
        logits = logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
    log_p = tf.reduce_sum(given * logits, -1)
    return log_p
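# Illustrative NumPy equivalent (a sketch; `counts` and `logits` are made-up
# inputs): with normalize_logits=True the method computes
# sum_k counts_k * log softmax(logits)_k, i.e. the multinomial log-likelihood
# without the log-combination term.
import numpy as np
from scipy.special import logsumexp

logits = np.array([0.2, -1.0, 0.5])
counts = np.array([2., 0., 3.])
log_softmax = logits - logsumexp(logits)
print(np.sum(counts * log_softmax))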
def _log_prob(self, given):
    given = tf.cast(given, self.param_dtype)
    given, logits = maybe_explicit_broadcast(given, self.logits,
                                             'given', 'logits')
    normalized_logits = logits - tf.reduce_logsumexp(
        logits, axis=-1, keepdims=True)
    n = tf.cast(self.n_experiments, self.param_dtype)
    # log n! / (x_1! ... x_k!) plus the sum of count-weighted log-probs.
    log_p = log_combination(n, given) + \
        tf.reduce_sum(given * normalized_logits, -1)
    return log_p
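# Hypothetical cross-check against SciPy (not from the original source): the
# method above should agree with scipy.stats.multinomial.logpmf, with
# log_combination(n, x) = log(n! / (x_1! ... x_k!)) written via gammaln.
import numpy as np
from scipy.special import gammaln, logsumexp
from scipy.stats import multinomial

logits = np.array([0.3, -0.2, 1.1])
counts = np.array([1., 4., 0.])
n = counts.sum()
log_p = logits - logsumexp(logits)
log_comb = gammaln(n + 1) - gammaln(counts + 1).sum()
print(log_comb + np.sum(counts * log_p))
print(multinomial.logpmf(counts, n=int(n), p=np.exp(log_p)))  # should match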
def _log_prob(self, given):
    given, alpha = maybe_explicit_broadcast(given, self.alpha,
                                            'given', 'alpha')
    lbeta_alpha = tf.lbeta(alpha)
    # Workaround: tf.lbeta does no static shape inference, so restore the
    # batch shape by hand.
    if alpha.get_shape():
        lbeta_alpha.set_shape(alpha.get_shape()[:-1])
    log_given = tf.log(given)
    if self._check_numerics:
        lbeta_alpha = tf.check_numerics(lbeta_alpha, "lbeta(alpha)")
        log_given = tf.check_numerics(log_given, "log(given)")
    log_p = -lbeta_alpha + tf.reduce_sum((alpha - 1) * log_given, -1)
    return log_p
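# Hypothetical SciPy cross-check (illustrative inputs): the Dirichlet
# log-density is -log B(alpha) + sum_k (alpha_k - 1) * log x_k, with
# log B(alpha) = sum_k lgamma(alpha_k) - lgamma(sum_k alpha_k).
import numpy as np
from scipy.special import gammaln
from scipy.stats import dirichlet

alpha = np.array([1.5, 2.0, 0.7])
x = np.array([0.2, 0.5, 0.3])
lbeta_alpha = gammaln(alpha).sum() - gammaln(alpha.sum())
print(-lbeta_alpha + np.sum((alpha - 1) * np.log(x)))
print(dirichlet.logpdf(x, alpha))  # should match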
def _log_prob(self, given):
    given = tf.cast(given, self.param_dtype)
    given, logits = maybe_explicit_broadcast(
        given, self.logits, 'given', 'logits')
    # softmax_cross_entropy_with_logits expects 2-D inputs, so flatten any
    # extra leading dimensions first and restore them afterwards.
    if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2):
        given_flat = given
        logits_flat = logits
    else:
        given_flat = tf.reshape(given, [-1, self.n_categories])
        logits_flat = tf.reshape(logits, [-1, self.n_categories])
    log_p_flat = -tf.nn.softmax_cross_entropy_with_logits(
        labels=given_flat, logits=logits_flat)
    if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2):
        log_p = log_p_flat
    else:
        log_p = tf.reshape(log_p_flat, tf.shape(logits)[:-1])
    if given.get_shape() and logits.get_shape():
        log_p.set_shape(tf.broadcast_static_shape(
            given.get_shape(), logits.get_shape())[:-1])
    return log_p
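# A small NumPy sketch (made-up inputs) of the identity the method relies on:
# for a one-hot `given`, -softmax_cross_entropy(given, logits) equals the
# log-softmax entry of the selected category.
import numpy as np
from scipy.special import logsumexp

logits = np.array([0.1, 2.0, -0.5])
onehot = np.array([0., 1., 0.])
log_softmax = logits - logsumexp(logits)
print(np.sum(onehot * log_softmax))  # == log_softmax[1]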
def _log_prob(self, given):
    mean, cov_tril = (self.path_param(self.mean),
                      self.path_param(self.cov_tril))
    log_det = 2 * tf.reduce_sum(
        tf.log(tf.matrix_diag_part(cov_tril)), axis=-1)
    n_dim = tf.cast(self._n_dim, self.dtype)
    log_z = - n_dim / 2 * tf.log(
        2 * tf.constant(np.pi, dtype=self.dtype)) - log_det / 2
    # log_z.shape == batch_shape
    if self._check_numerics:
        log_z = tf.check_numerics(log_z, "log[det(Cov)]")
    # (given-mean)' Sigma^{-1} (given-mean) =
    # (g-m)' L^{-T} L^{-1} (g-m) = |x|^2, where Lx = g-m =: y.
    y = tf.expand_dims(given - mean, -1)
    L, _ = maybe_explicit_broadcast(
        cov_tril, y, 'MultivariateNormalCholesky.cov_tril',
        'expand_dims(given, -1)')
    x = tf.matrix_triangular_solve(L, y, lower=True)
    x = tf.squeeze(x, -1)
    stoc_dist = -0.5 * tf.reduce_sum(tf.square(x), axis=-1)
    return log_z + stoc_dist
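# Hypothetical SciPy cross-check (illustrative inputs): the Cholesky-based
# computation above should agree with scipy.stats.multivariate_normal, since
# (g-m)' Sigma^{-1} (g-m) = |x|^2 with L x = g - m and Sigma = L L'.
import numpy as np
from scipy.linalg import solve_triangular
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
d = 4
a = rng.normal(size=(d, d))
cov = a @ a.T + d * np.eye(d)
mean, x = rng.normal(size=d), rng.normal(size=d)
L = np.linalg.cholesky(cov)
z = solve_triangular(L, x - mean, lower=True)
log_det = 2 * np.sum(np.log(np.diag(L)))
print(-d / 2 * np.log(2 * np.pi) - log_det / 2 - 0.5 * z @ z)
print(multivariate_normal.logpdf(x, mean=mean, cov=cov))  # should match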
def _sample(self, n_samples):
    alpha, beta = maybe_explicit_broadcast(self.alpha, self.beta,
                                           'alpha', 'beta')
    # Beta(alpha, beta) via the gamma ratio: X / (X + Y) with
    # X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1).
    x = tf.random_gamma([n_samples], alpha, beta=1, dtype=self.dtype)
    y = tf.random_gamma([n_samples], beta, beta=1, dtype=self.dtype)
    return x / (x + y)
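# A quick NumPy sanity check of the gamma-ratio construction (a sketch with
# made-up parameters): X/(X+Y) should have mean alpha / (alpha + beta).
import numpy as np

rng = np.random.default_rng(0)
alpha, beta = 2.0, 5.0
x = rng.gamma(alpha, 1.0, size=100000)
y = rng.gamma(beta, 1.0, size=100000)
samples = x / (x + y)
print(samples.mean(), alpha / (alpha + beta))  # both approx 0.2857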
def _log_prob(self, given):
    given = tf.cast(given, self.param_dtype)
    given, logits = maybe_explicit_broadcast(given, self.logits,
                                             'given', 'logits')
    # The negated sigmoid cross-entropy is exactly the Bernoulli
    # log-likelihood for `given` in {0, 1}.
    return -tf.nn.sigmoid_cross_entropy_with_logits(labels=given,
                                                    logits=logits)
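# A NumPy sketch (made-up inputs) of why this works: the negated sigmoid
# cross-entropy equals given * log sigma(l) + (1 - given) * log(1 - sigma(l)),
# and log(1 - sigma(l)) = log sigma(-l).
import numpy as np

def log_sigmoid(z):
    return -np.logaddexp(0.0, -z)  # numerically stable log sigma(z)

logits, given = 1.3, 1.0
print(given * log_sigmoid(logits) + (1 - given) * log_sigmoid(-logits))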