Ejemplo n.º 1
0
  def _inverse_log_det_jacobian(self, y, use_saved_statistics=False):
    """Inverse log-det-Jacobian of the batch-normalization transform.

    Args:
      y: `Tensor` in the bijector's output space. Its shape is used to build
        `self.batchnorm` if the layer has not been built yet.
      use_saved_statistics: Python `bool`. If `True`, the layer's
        `moving_variance` is used regardless of `self._training`.

    Returns:
      ildj: Scalar `Tensor` equal to
        `sum(log(gamma)) - 0.5 * sum(log(variance + epsilon))`. It is constant
        across minibatch elements.
    """
    if not self.batchnorm.built:
      # Create variables.
      self.batchnorm.build(y.shape)

    event_dims = self.batchnorm.axis
    # NOTE(review): `i` is always non-negative here, so this assumes
    # `batchnorm.axis` holds non-negative indices (e.g. normalized during
    # `build`); a negative axis entry would never be excluded.
    reduction_axes = [i for i in range(len(y.shape)) if i not in event_dims]

    # At training-time (unless `use_saved_statistics` is set), ildj is computed
    # from the variance across the current minibatch; otherwise the layer's
    # moving variance is used. `tf.where` broadcasts the two candidates against
    # each other and evaluates both, so the minibatch moments are always built.
    log_variance = tf.math.log(
        tf.where(
            tf.logical_or(use_saved_statistics, tf.logical_not(self._training)),
            self.batchnorm.moving_variance,
            tf.nn.moments(x=y, axes=reduction_axes, keepdims=True)[1]) +
        self.batchnorm.epsilon)

    # TODO(b/137216713): determine whether it's unsafe for the reduce_sums below
    # to happen across all axes.
    # `gamma` and `log Var(y)` reductions over event_dims.
    # Log(total change in area from gamma term).
    log_total_gamma = tf.reduce_sum(tf.math.log(self.batchnorm.gamma))

    # Log(total change in area from log-variance term).
    log_total_variance = tf.reduce_sum(log_variance)
    # The ildj is scalar: it does not depend on the individual values of `y`
    # and is constant across minibatch elements.
    return log_total_gamma - 0.5 * log_total_variance
    def _entropy(self):
        """Shannon entropy of the categorical distribution over the last axis.

        If only `probs` are stored, the entropy is `-sum(p * log p)`;
        `multiply_no_nan` makes entries with `p == 0` contribute 0 rather than
        NaN. If `logits` `s` are stored, a max-shifted form is used for
        numerical precision. With `m = max(s)` and `x = s - m`:

          H = m + lse(x) - sum_i s[i] exp(x[i]) / sum_i exp(x[i])

        which follows from writing log(p[i]) = s[i] - m - lse(s - m).
        """
        if self._logits is not None:
            logits = tf.convert_to_tensor(self._logits)
            max_logit = tf.reduce_max(logits, axis=-1, keepdims=True)
            shifted = logits - max_logit
            unnormalized = tf.math.exp(shifted)
            log_normalizer = (max_logit[..., 0] +
                              tf.reduce_logsumexp(shifted, axis=-1))
            normalizer = tf.reduce_sum(unnormalized, axis=-1)
            weighted_logits = tf.reduce_sum(
                tf.math.multiply_no_nan(logits, unnormalized), axis=-1)
            return log_normalizer - weighted_logits / normalizer

        # Only `probs` are available; little can be done for extra precision.
        probs = tf.convert_to_tensor(self._probs)
        return -tf.reduce_sum(
            tf.math.multiply_no_nan(tf.math.log(probs), probs), axis=-1)
Ejemplo n.º 3
0
 def _entropy(self):
   """Entropy from `concentration` via the closed form.

     H = lbeta(alpha) + (alpha_0 - K) * digamma(alpha_0)
         - sum_j (alpha_j - 1) * digamma(alpha_j)

   where `alpha_0 = sum_j alpha_j` and `K` is the size of the last axis.
   """
   alpha = tf.convert_to_tensor(self.concentration)
   num_categories = tf.cast(tf.shape(alpha)[-1], self.dtype)
   alpha_0 = tf.reduce_sum(alpha, axis=-1)
   log_beta = tf.math.lbeta(alpha)
   total_term = (alpha_0 - num_categories) * tf.math.digamma(alpha_0)
   per_category_term = tf.reduce_sum(
       (alpha - 1.) * tf.math.digamma(alpha), axis=-1)
   return log_beta + total_term - per_category_term
 def _variance(self):
     """Mixture variance via the law of total variance.

     Var(Y) = E[Var(Y|X)] + Var(E[Y|X]), where X indexes the component. Both
     terms are mixture-weighted sums over the component axis, which sits just
     before the event dims (axis `-1 - event_ndims`).
     """
     with tf.control_dependencies(self._runtime_assertions):
         # Mixture weights padded with singleton event dims: [B, k, [1]*e].
         weights = distribution_utils.pad_mixture_dimensions(
             self.mixture_distribution.probs_parameter(), self,
             self.mixture_distribution, self._event_ndims)
         component_axis = -1 - self._event_ndims
         # E[Var(Y|X)]: [B, E].
         expected_conditional_variance = tf.reduce_sum(
             weights * self.components_distribution.variance(),
             axis=component_axis)
         # Var(E[Y|X]): [B, E].
         component_means = self.components_distribution.mean()
         mixture_mean = self._pad_sample_dims(self._mean())
         variance_of_conditional_mean = tf.reduce_sum(
             weights * tf.math.squared_difference(component_means,
                                                  mixture_mean),
             axis=component_axis)
         return expected_conditional_variance + variance_of_conditional_mean
 def _mean(self):
     """Mixture mean: the mixture-weighted sum of the component means."""
     with tf.control_dependencies(self._runtime_assertions):
         # [B, k, [1]*e] weights, broadcastable against the component means.
         weights = distribution_utils.pad_mixture_dimensions(
             self.mixture_distribution.probs_parameter(), self,
             self.mixture_distribution, self._event_ndims)
         weighted_means = weights * self.components_distribution.mean()
         # Collapse the component axis -> [B, E].
         return tf.reduce_sum(weighted_means, axis=-1 - self._event_ndims)
Ejemplo n.º 6
0
 def _log_prob(self, x, **kwargs):
     """Log-probability of `x`, summed over the `sample_shape` dims.

     `x` is expected to be laid out as `[S, B, sample_shape, E]`, where `S` is
     any number of leading sample dims, `B` the batch dims and `E` the event
     dims of `self.distribution`. If `x` has fewer dims than
     `B + sample_shape + E`, its shape is left-padded with 1s first. `x` is
     transposed to `[S, sample_shape, B, E]`, scored by
     `self.distribution.log_prob`, and the result is summed over the
     `sample_shape` axes.

     Args:
       x: Sample `Tensor`.
       **kwargs: Forwarded to `self.distribution.log_prob`.

     Returns:
       `Tensor` of log-probabilities with the `sample_shape` axes reduced.
     """
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     ndims = prefer_static.rank(x)
     # (1) Expand x's dims: left-pad with 1s when x has fewer dims than
     # batch + sample_shape + event, so every dim group below exists.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=tf.pad(
                        tf.shape(x),
                        paddings=[[prefer_static.maximum(0, -d), 0]],
                        constant_values=1))
     # The reshape may have added leading dims; recompute the rank so that
     # `event_dims` below spans all of x's trailing event dims (otherwise
     # `perm` would be too short whenever d < 0).
     ndims = prefer_static.rank(x)
     sample_ndims = prefer_static.maximum(0, d)
     # (2) Transpose x's dims: [S, B, sample_shape, E] -> [S, sample_shape, B, E].
     sample_dims = prefer_static.range(0, sample_ndims)
     batch_dims = prefer_static.range(sample_ndims,
                                      sample_ndims + batch_ndims)
     extra_sample_dims = prefer_static.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = prefer_static.range(
         sample_ndims + batch_ndims + extra_sample_ndims, ndims)
     perm = prefer_static.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Make the final reduction in x: sum over the sample_shape axes.
     axis = prefer_static.range(sample_ndims,
                                sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
Ejemplo n.º 7
0
def log_combinations(n, counts, name='log_combinations'):
  """Log of the multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the log of the multinomial coefficient

  ```n! / prod_i n_i!```

  where `i` runs over all `k` classes and `n_i = counts[..., i]`. That is,
  the result is `lgamma(n + 1) - sum_i lgamma(n_i + 1)`.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
      outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    log_combinations: `Tensor` representing the log of the multinomial
      coefficient between `n` and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The product (a sum in log-space) is along the last dimension of counts.
  # This is the 'distribution' dimension. Here n a priori represents the sum
  # of counts.
  with tf.name_scope(name):
    n = tf.convert_to_tensor(n, name='n')
    counts = tf.convert_to_tensor(counts, name='counts')
    # lgamma(z + 1) == log(z!).
    total_permutations = tf.math.lgamma(n + 1)
    counts_factorial = tf.math.lgamma(counts + 1)
    redundant_permutations = tf.reduce_sum(counts_factorial, axis=-1)
    return total_permutations - redundant_permutations
Ejemplo n.º 8
0
 def _variance(self):
   """Elementwise variance from `concentration`.

   Evaluates `mean * (1 - mean) / (1 + total_concentration)` in the factored
   form `x * (scale - x)`, with `scale = rsqrt(1 + total_concentration)` and
   `x = scale * mean`.
   """
   alpha = tf.convert_to_tensor(self.concentration)
   alpha_0 = tf.reduce_sum(alpha, axis=-1, keepdims=True)
   scale = tf.math.rsqrt(1. + alpha_0)
   scaled_mean = scale * (alpha / alpha_0)
   return scaled_mean * (scale - scaled_mean)
 def squared_frobenius_norm(x):
     """Squared Frobenius norm over the last two axes (for KL readability).

     See http://mathworld.wolfram.com/FrobeniusNorm.html. This is written as
     an explicit sum of squares rather than
     `tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))` because the gradient
     of KL[p,q] through `tf.norm` is not defined when p==q.
     """
     squared_entries = tf.square(x)
     return tf.reduce_sum(squared_entries, axis=[-2, -1])
Ejemplo n.º 10
0
 def _log_prob(self, counts):
     """Log-probability mass of `counts`.

     Computes `sum_j counts[j] * log p[j] + log_combinations(total_count,
     counts)`, after any `_maybe_assert_valid_sample` checks.
     """
     with tf.control_dependencies(self._maybe_assert_valid_sample(counts)):
         log_p = (tf.math.log(self._probs) if self._logits is None else
                  tf.math.log_softmax(self._logits))
         k = tf.convert_to_tensor(self.total_count)
         # `multiply_no_nan` makes a class with zero count contribute 0 even
         # when its probability is 0 (log p == -inf); a plain product would
         # yield 0 * -inf = NaN.
         log_unnorm_prob = tf.reduce_sum(
             tf.math.multiply_no_nan(log_p, counts), axis=-1)
         neg_log_normalization = tfp_math.log_combinations(k, counts)
         return log_unnorm_prob + neg_log_normalization
Ejemplo n.º 11
0
def matrix_rank(a, tol=None, validate_args=False, name=None):
    """Compute the matrix rank; the number of non-zero SVD singular values.

    Arguments:
      a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to
        be computed.
      tol: Threshold below which the singular value is counted as 'zero'.
        Singular values are compared with a strict `>`, so a value equal to
        `tol` also counts as zero.
        Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`,
        where `max(singular_val)` is taken separately for each matrix in the
        batch).
      validate_args: When `True`, additional assertions might be embedded in the
        graph.
        Default value: `False` (i.e., no graph assertions are added).
      name: Python `str` prefixed to ops created by this function.
        Default value: 'matrix_rank'.

    Returns:
      matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
        singular values.
    """
    with tf.name_scope(name or 'matrix_rank'):
        a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
        assertions = _maybe_validate_matrix(a, validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                a = tf.identity(a)
        s = tf.linalg.svd(a, compute_uv=False)
        if tol is None:
            # max(rows, cols): static when known, dynamic otherwise.
            if tensorshape_util.is_fully_defined(a.shape[-2:]):
                m = np.max(a.shape[-2:].as_list())
            else:
                m = tf.reduce_max(tf.shape(a)[-2:])
            eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
            # keepdims so `tol` broadcasts against `s` per batch member.
            tol = (eps * tf.cast(m, a.dtype) *
                   tf.reduce_max(s, axis=-1, keepdims=True))
        return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
    def _log_prob(self, x):
        """Log-probability of samples `x` given the logits.

        `x` is cast to the logits' dtype, optionally validated, and broadcast
        against the logits. The result is the negative softmax cross-entropy
        along the last (event) axis, reshaped to the non-event dims.
        """
        logits = self._logits_parameter_no_checks()
        event_size = self._event_size(logits)

        x = tf.cast(x, logits.dtype)
        x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

        # Broadcast logits and x to a common shape unless both are statically
        # known and already equal.
        if (not tensorshape_util.is_fully_defined(x.shape)
                or not tensorshape_util.is_fully_defined(logits.shape)
                or x.shape != logits.shape):
            broadcast_shape = tf.broadcast_dynamic_shape(
                tf.shape(logits), tf.shape(x))
            logits = tf.broadcast_to(logits, broadcast_shape)
            x = tf.broadcast_to(x, broadcast_shape)

        # Every dim except the trailing event dim. Read it straight off
        # `tf.shape` instead of materializing a reduction over `logits`.
        logits_shape = tf.shape(logits)[:-1]
        logits_2d = tf.reshape(logits, [-1, event_size])
        x_2d = tf.reshape(x, [-1, event_size])
        # No gradient flows into the sample `x`.
        ret = -tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(x_2d), logits=logits_2d)

        # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
        ret = tf.reshape(ret, logits_shape)
        return ret
 def backward_step(most_likely_successor,
                   most_likely_given_successor):
     """Select, along the last axis, the entry indexed by `most_likely_successor`.

     Expressed as a one-hot mask (of depth `self._num_states`, dtype int64)
     multiplied into `most_likely_given_successor` and summed over the last
     axis.
     """
     successor_mask = tf.one_hot(most_likely_successor,
                                 self._num_states,
                                 dtype=tf.int64)
     selected = most_likely_given_successor * successor_mask
     return tf.reduce_sum(input_tensor=selected, axis=-1)
 def _maybe_assert_valid_sample(self, x, dtype):
     """Return `x`, gated on validity checks when `validate_args` is set.

     With validation on, `x` is returned with control dependencies on
     assertions that it is non-negative, at most one, and that its last axis
     sums to (approximately) one.
     """
     if self.validate_args:
         one = tf.ones([], dtype=dtype)
         checks = [
             assert_util.assert_non_negative(x),
             assert_util.assert_less_equal(x, one),
             assert_util.assert_near(one, tf.reduce_sum(x, axis=[-1])),
         ]
         return distribution_util.with_dependencies(checks, x)
     return x
Ejemplo n.º 15
0
        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration.

            Performs one step of pivoted Cholesky. It picks the largest
            remaining (permuted) diagonal entry as the pivot and swaps it into
            position `m` of `perm`. It then computes row `m` of the partial
            factor `pchol` and downdates `matrix_diag` by that row squared.

            `matrix`, `batch_gather`, `_swap_m_with_i` and `_invert_permutation`
            come from the enclosing scope.

            Args:
              m: Index of the row of `pchol` computed in this iteration.
              pchol: Partial Cholesky factor so far; row `m` is replaced.
              perm: Current pivot permutation (along the last axis).
              matrix_diag: Running (downdated) diagonal of `matrix`.

            Returns:
              Tuple `(m + 1, pchol, perm, matrix_diag)` of updated loop state.
            """
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            # `maxi` was relative to perm[..., m:]; shift to an absolute index.
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4: row perm[m] of `matrix`, restricted to columns perm[m+1:].
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5: subtract contributions of the previously computed rows.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7: prepend the pivot and scale the rest of the row by it.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # (Hence the paddings below are built as int32.)
            # Step 8.
            # `paddings` leaves every axis unpadded except the last, which gets
            # `m` leading zeros -- presumably so `row` spans all columns.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            # Map from permuted column order back to the original order.
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            # Replace row `m` of `pchol`; restore the static shape that
            # `tf.concat` may lose.
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag
Ejemplo n.º 16
0
 def _rotate(self, samples):
     """Applies a Householder rotation to `samples`.

     With `u = l2_normalize(e_1 - self.mean_direction)` along the last axis,
     returns `samples - 2 * (samples . u) * u`. Assuming `mean_direction` is
     unit-norm, this reflection maps e_1 = [1, 0, ..., 0] onto it.
     """
     event_dim = (tf.compat.dimension_value(self.event_shape[0])
                  or self._event_shape_tensor()[0])
     # e_1 in the distribution's dtype. (No trailing comma here: that would
     # make `basis` a 1-tuple and silently add a leading size-1 dim to `u`.)
     basis = tf.concat(
         [tf.ones([1], dtype=self.dtype),
          tf.zeros([event_dim - 1], dtype=self.dtype)], axis=0)
     u = tf.math.l2_normalize(basis - self.mean_direction, axis=-1)
     return samples - 2 * tf.reduce_sum(samples * u, axis=-1,
                                        keepdims=True) * u
Ejemplo n.º 17
0
 def _forward_log_det_jacobian(self, x):
   """Forward log |det J| computed from the diagonal of `x`.

   Returns `-(n + 1) * sum_i log|x[..., i, i]|`, where `n` is the size of the
   last axis. Only the diagonal is read, so this assumes `x` is triangular.
   For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
   sections leading up to it, for context) in
   http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
   """
   with tf.control_dependencies(self._assertions(x)):
     n = tf.cast(tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
     log_abs_diag = tf.math.log(tf.abs(tf.linalg.diag_part(x)))
     return -(n + 1) * tf.reduce_sum(log_abs_diag, axis=-1)
Ejemplo n.º 18
0
 def _inverse_log_det_jacobian(self, y):
     """Inverse log |det J|: the sum of `y[..., 1:]` over the last axis."""
     # The Jacobian of the inverse mapping is lower
     # triangular, with the diagonal elements being:
     # J[i,i] = 1 if i=1, and
     #          exp(y_i) if 1<i<=K
     # which gives the absolute Jacobian determinant:
     # |det(Jac)| = prod_{i=2}^{K} exp(y[i]),
     # i.e. log|det(Jac)| = sum_{i=2}^{K} y[i]. With 0-based indexing that is
     # the sum of y[..., 1:] along the last axis.
     # (1) - Stan Modeling Language User's Guide and Reference Manual
     #       Version 2.17.0 session 35.2
     return tf.reduce_sum(y[..., 1:], axis=-1)
Ejemplo n.º 19
0
 def _covariance(self):
   """Covariance matrix over the last axis, from `concentration`.

   Off-diagonal entries are `-x_i * x_j` and diagonal entries are
   `x_i * (scale - x_i)`, with `scale = rsqrt(1 + total_concentration)` and
   `x = scale * mean`.
   """
   alpha = tf.convert_to_tensor(self.concentration)
   alpha_0 = tf.reduce_sum(alpha, axis=-1, keepdims=True)
   scale = tf.math.rsqrt(1. + alpha_0)
   scaled_mean = scale * (alpha / alpha_0)
   column = scaled_mean[..., tf.newaxis]
   row = scaled_mean[..., tf.newaxis, :]
   neg_outer = tf.matmul(-column, row)
   return tf.linalg.set_diag(neg_outer, scaled_mean * (scale - scaled_mean))
Ejemplo n.º 20
0
            def grad(dy):
                """Computes a derivative for the min and max parameters.

                This function implements the derivative wrt the truncation
                bounds, which get blocked by the sampler. We use a custom
                expression for numerical stability instead of automatic
                differentiation on CDF for implicit gradients.

                Args:
                  dy: output gradients

                Returns:
                  A list `[grad_l, grad_u]`: the gradients wrt the lower bound
                  and the upper bound, in that order, each summed over the
                  trailing sample dimension.
                """
                # std_samples has an extra dimension (the sample dimension), expand
                # lower and upper so they broadcast along this dimension.
                # See note above regarding parameterized_truncated_normal, the sample
                # dimension is the final dimension.
                lower_broadcast = lower[..., tf.newaxis]
                upper_broadcast = upper[..., tf.newaxis]

                cdf_samples = ((special_math.ndtr(std_samples) -
                                special_math.ndtr(lower_broadcast)) /
                               (special_math.ndtr(upper_broadcast) -
                                special_math.ndtr(lower_broadcast)))

                # tiny, eps are tolerance parameters to ensure we stay away from giving
                # a zero arg to the log CDF expression.

                tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
                eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
                cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

                du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                            tf.math.log(cdf_samples))
                dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                            tf.math.log1p(-cdf_samples))

                # Reduce the gradient across the samples
                grad_u = tf.reduce_sum(dy * du, axis=-1)
                grad_l = tf.reduce_sum(dy * dl, axis=-1)
                return [grad_l, grad_u]
Ejemplo n.º 21
0
 def _maybe_assert_valid_sample(self, counts):
     """Build runtime checks for `counts` and return them as a list.

     Returns `[]` when `validate_args` is `False`. Otherwise it returns the
     assertions from `assert_nonnegative_integer_form(counts)`, plus one
     asserting that `counts` sums to `self.total_count` along the last axis.
     `counts` itself is not returned; the list is meant to be used as control
     dependencies.
     """
     if not self.validate_args:
         return []
     assertions = distribution_util.assert_nonnegative_integer_form(counts)
     assertions.append(
         assert_util.assert_equal(
             self.total_count,
             tf.reduce_sum(counts, axis=-1),
             message='counts must sum to `self.total_count`'))
     return assertions
 def _maybe_assert_valid_sample(self, counts):
   """Return `counts`, gated on validity checks when `validate_args` is set.

   With validation on, `counts` is passed through
   `embed_check_nonnegative_integer_form`. The result is returned with a
   control dependency asserting that its last axis sums to `total_count`.
   """
   if self.validate_args:
     checked = distribution_util.embed_check_nonnegative_integer_form(counts)
     sums_to_total = assert_util.assert_equal(
         self.total_count,
         tf.reduce_sum(checked, axis=-1),
         message='counts last-dimension must sum to `self.total_count`')
     return distribution_util.with_dependencies([sums_to_total], checked)
   return counts
Ejemplo n.º 23
0
 def _maybe_assert_valid_sample(self, x):
   """Return the list of assertions validating sample `x`.

   The list is empty when `validate_args` is `False`. Otherwise it checks
   that `x` is strictly positive and that its last axis sums to
   (approximately) 1.
   """
   if self.validate_args:
     positive = assert_util.assert_positive(
         x, message='samples must be positive')
     sums_to_one = assert_util.assert_near(
         tf.ones([], dtype=self.dtype),
         tf.reduce_sum(x, axis=-1),
         message='sample last-dimension must sum to `1`')
     return [positive, sums_to_one]
   return []
Ejemplo n.º 24
0
 def _sample_one_batch_member(args):
     """Draw `num_samples` per-class count vectors for one batch member.

     `args` is `(logits, num_cat_samples)`. `num_cat_samples` categorical
     draws are taken from `logits` and grouped into `num_samples` rows. Each
     row is one-hot encoded over `num_classes` and summed into class counts.
     """
     logits, num_cat_samples = args[0], args[1]  # [K], []
     # [1, num_cat_samples], where num_cat_samples = num_samples * num_trials.
     draws = tf.random.categorical(logits[tf.newaxis, ...],
                                   num_cat_samples,
                                   seed=seed)
     # [num_samples, num_trials]
     draws = tf.reshape(draws, shape=[num_samples, -1])
     # [num_samples, num_trials, num_classes]
     one_hot_draws = tf.one_hot(draws, depth=num_classes)
     # [num_samples, num_classes]
     class_counts = tf.reduce_sum(one_hot_draws, axis=-2)
     return tf.cast(class_counts, dtype=dtype)
    def _covariance(self):
        """Mixture covariance via the law of total covariance.

        Cov(Y) = E[Cov(Y|X)] + Cov(E[Y|X]). Both terms are mixture-weighted
        sums over the component axis (axis -3).

        Raises:
          NotImplementedError: if the event rank is statically known and is
            not 1.
        """
        event_rank = tensorshape_util.rank(self.event_shape)
        if event_rank is not None and event_rank != 1:
            # Covariance is defined only for vector distributions.
            raise NotImplementedError("covariance is not implemented")

        with tf.control_dependencies(self._runtime_assertions):
            # Padding twice yields weights of shape [B, k, 1, 1], which
            # broadcast against per-component [e, e] matrices.
            weights = distribution_utils.pad_mixture_dimensions(
                self.mixture_distribution.probs_parameter(), self,
                self.mixture_distribution, self._event_ndims)
            weights = distribution_utils.pad_mixture_dimensions(
                weights, self, self.mixture_distribution, self._event_ndims)
            # E[Cov(Y|X)]: [B, e, e].
            expected_conditional_cov = tf.reduce_sum(
                weights * self.components_distribution.covariance(),
                axis=-3)
            # Cov(E[Y|X]): [B, e, e].
            component_means = self.components_distribution.mean()
            mixture_mean = self._pad_sample_dims(self._mean())
            cov_of_conditional_mean = tf.reduce_sum(
                weights * _outer_squared_difference(component_means,
                                                    mixture_mean),
                axis=-3)
            return expected_conditional_cov + cov_of_conditional_mean
Ejemplo n.º 26
0
    def _event_shape_tensor(self):
        """Total flattened event size of all parts, as a shape-`[1]` tensor.

        Each part's size comes from its static event shape when fully known,
        and from `tf.reduce_prod` of its dynamic event shape otherwise.
        """
        static_sizes = tf.nest.map_structure(tensorshape_util.num_elements,
                                             self._distribution.event_shape)

        sizes = static_sizes
        if any(size is None for size in tf.nest.flatten(static_sizes)):

            def _resolve(static_size, shape_tensor):
                if static_size is None:
                    return tf.reduce_prod(shape_tensor)
                return static_size

            sizes = tf.nest.map_structure(
                _resolve, static_sizes, self._distribution.event_shape_tensor())

        return tf.reduce_sum(tf.nest.flatten(sizes))[tf.newaxis]
Ejemplo n.º 27
0
 def _variance(self):
     """Probability-weighted squared deviation of `outcomes` from the mean.

     The weighted deviations are summed over the last axis.
     """
     probs = self._categorical.probs_parameter()
     outcomes = tf.broadcast_to(self.outcomes,
                                shape=dist_util.prefer_static_shape(probs))
     if dtype_util.is_integer(outcomes.dtype):
         # Integer outcomes are cast to the dtype of `probs`. With validation
         # on, `embed_check_integer_casting_closed` is applied first.
         if self._validate_args:
             outcomes = dist_util.embed_check_integer_casting_closed(
                 outcomes, target_dtype=probs.dtype)
         outcomes = tf.cast(outcomes, dtype=probs.dtype)
     mean = self._mean(probs)[..., tf.newaxis]
     squared_deviation = tf.math.squared_difference(outcomes, mean)
     return tf.reduce_sum(probs * squared_deviation, axis=-1)
Ejemplo n.º 28
0
def maybe_assert_categorical_param_correctness(is_init, validate_args, probs,
                                               logits):
    """Return assertions for `Categorical`-type distributions.

    When `is_init` is true, static dtype/rank problems raise immediately.
    Runtime assertions are returned only when `validate_args` is true. A
    parameter gets runtime checks when `is_init` differs from
    `tensor_util.is_ref(parameter)`.

    Args:
      is_init: Python `bool`; true for the init-time call.
      validate_args: Python `bool`; when false, `[]` is returned.
      probs: Probabilities, or `None`.
      logits: Logits, or `None`. When not `None`, the init-time dtype/rank
        checks are applied to `logits` instead of `probs`.

    Returns:
      List of assertion ops (empty when `validate_args` is false).

    Raises:
      TypeError: at init, if the checked parameter is not floating-point.
      ValueError: at init, if its static rank is known and less than 1.
    """
    assertions = []

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
        if logits is None:
            param, param_name = probs, 'probs'
        else:
            param, param_name = logits, 'logits'

        if not dtype_util.is_floating(param.dtype):
            raise TypeError(
                'Argument `{}` must having floating type.'.format(param_name))

        rank_msg = 'Argument `{}` must have rank at least 1.'.format(param_name)
        static_rank = tensorshape_util.rank(param.shape)
        if static_rank is not None:
            if static_rank < 1:
                raise ValueError(rank_msg)
        elif validate_args:
            # Rank is only known at runtime: convert once, and keep the
            # converted tensor as the sole parameter for the checks below.
            param = tf.convert_to_tensor(param)
            if logits is None:
                probs, logits = param, None
            else:
                probs, logits = None, param
            assertions.append(
                assert_util.assert_rank_at_least(param, 1, message=rank_msg))

    if not validate_args:
        assert not assertions  # Should never happen.
        return []

    if logits is not None and is_init != tensor_util.is_ref(logits):
        logits = tf.convert_to_tensor(logits)
        assertions.extend(
            distribution_util.assert_categorical_event_shape(logits))

    if probs is not None and is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        one = np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype))
        assertions.extend([
            assert_util.assert_non_negative(probs),
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1),
                one,
                message='Argument `probs` must sum to 1.')
        ])
        assertions.extend(
            distribution_util.assert_categorical_event_shape(probs))

    return assertions
Ejemplo n.º 29
0
 def _forward_log_det_jacobian(self, x):
     """Forward log |det J| of the Cholesky-to-inverse-Cholesky map at `x`.

     CholeskyToInvCholesky.forward(X) is equivalent to
     1) M = CholeskyOuterProduct.forward(X)
     2) N = invert(M)
     3) Y = CholeskyOuterProduct.inverse(N)

     For step 1,
       |Jac(outerprod(X))| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
     For step 2,
       |Jac(inverse(M))| = |M|^{-(p+1)} (because M is symmetric)
                         = |X|^{-2(p+1)} = (prod_{j=0}^{p-1} X[j,j])^{-2(p+1)}
       (see http://web.mit.edu/18.325/www/handouts/handout2.pdf sect 3.0.2)
     For step 3,
       |Jac(Cholesky(N))| = 1 / |Jac(outerprod(Y))|
                          = 1 / (2^p prod_{j=0}^{p-1} Y[j,j]^{p-j})

     In log-space the result is `g(X) - g(Y)`, where
     `g(Z) = log|Jac(outerprod(Z))| - (p + 1) * sum_j log Z[j,j]`. Because
     prod_j Y[j,j] = 1 / prod_j X[j,j], the two `(p + 1)` terms together give
     the `-2(p+1) sum_j log X[j,j]` of step 2.
     """
     p = tf.cast(tf.shape(x)[-1], x.dtype)
     y = self._forward(x)

     def _outer_product_term(z):
         # log|Jac(outerprod(z))| - (p + 1) * sum_j log z[j, j].
         return (self._cholesky.forward_log_det_jacobian(z, event_ndims=2) -
                 (p + 1.) *
                 tf.reduce_sum(tf.math.log(tf.linalg.diag_part(z)), axis=-1))

     return _outer_product_term(x) - _outer_product_term(y)
Ejemplo n.º 30
0
 def _forward_log_det_jacobian(self, x):
     """Forward log |det J| for `x`, with an implicit zero column appended.

     This is like `tf.math.log_softmax`, except that the normalizer includes
     an implicit zero logit:
       log_normalization = log(1 + sum(exp(x))) = softplus(logsumexp(x))
       result = -log_normalization + sum(x - log_normalization)
                + 0.5 * log(n + 1)
     where `n` is the size of the last axis of `x`.
     """
     n = prefer_static.shape(x)[-1]
     log_normalization = tf.math.softplus(
         tf.reduce_logsumexp(x, axis=-1, keepdims=True))
     summed = tf.reduce_sum(x - log_normalization, axis=-1, keepdims=True)
     core = tf.squeeze(-log_normalization + summed, axis=-1)
     correction = 0.5 * tf.math.log(tf.cast(n + 1, dtype=x.dtype))
     return core + correction