Example #1
def matrix_rank(a, tol=None, validate_args=False, name=None):
    """Compute the matrix rank; the number of non-zero SVD singular values.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
    with tf.name_scope(name or 'matrix_rank'):
        a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
        assertions = _maybe_validate_matrix(a, validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                a = tf.identity(a)
        s = tf.linalg.svd(a, compute_uv=False)
        if tol is None:
            if tensorshape_util.is_fully_defined(a.shape[-2:]):
                m = np.max(a.shape[-2:].as_list())
            else:
                m = tf.reduce_max(tf.shape(a)[-2:])
            eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
            tol = (eps * tf.cast(m, a.dtype) *
                   tf.reduce_max(s, axis=-1, keepdims=True))
        return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
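
A minimal sketch of the same recipe with plain TF ops (hypothetical matrix), mirroring the default tolerance `eps * max(rows, cols) * max(singular value)` described in the docstring:

import numpy as np
import tensorflow as tf

# The third row is the sum of the first two, so the true rank is 2.
a = tf.constant([[1., 0., 2.],
                 [0., 1., 3.],
                 [1., 1., 5.]])
s = tf.linalg.svd(a, compute_uv=False)
tol = np.finfo(np.float32).eps * 3. * tf.reduce_max(s)
print(tf.reduce_sum(tf.cast(s > tol, tf.int32)).numpy())  # expected: 2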
Example #2
 def _make_perm(self, x_rank, perm):
     sample_batch_ndims = (distribution_util.prefer_static_value(x_rank) -
                           distribution_util.prefer_static_value(
                               self.rightmost_transposed_ndims))
     dtype = perm.dtype
     perm = tf.concat([
         tf.range(tf.cast(sample_batch_ndims, dtype)),
         tf.cast(
             sample_batch_ndims +
             distribution_util.prefer_static_value(perm), dtype),
     ],
                      axis=0)
     return perm
Example #3
 def _prob(self, event):
   samples = tf.convert_to_tensor(self._samples)
   num_samples = self._compute_num_samples(samples)
   event = tf.convert_to_tensor(event, name='event', dtype=self.dtype)
   event, samples = _broadcast_event_and_samples(event, samples,
                                                 event_ndims=self._event_ndims)
   prob = tf.reduce_sum(
       tf.cast(
           tf.reduce_all(
               tf.equal(samples, event), axis=tf.range(-self._event_ndims, 0)),
           dtype=tf.int32),
       axis=-1) / num_samples
   if dtype_util.is_floating(self.dtype):
     prob = tf.cast(prob, self.dtype)
   return prob
Example #4
    def _cdf(self, k):
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # Since the lowest number in the support is 0, any k < 0 should be zero in
        # the output.
        should_be_zero = k < 0

        # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

        batch_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
        # to handle the case where the batch shape is dynamic, flatten the batch
        # dims (so we know batch_dims=1).
        k_flat_batch = tf.reshape(k, [-1])
        probs_flat_batch = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))

        cdf_flat = tf.gather(tf.cumsum(probs_flat_batch, axis=-1),
                             k_flat_batch[..., tf.newaxis],
                             batch_dims=1)

        cdf = tf.reshape(cdf_flat, shape=batch_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(should_be_zero, zero, cdf)
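
A small self-contained sketch of the gather-from-cumsum idea above (hypothetical probabilities, plain NumPy): the CDF at k is the cumulative sum of the category probabilities at index k, with negative k mapped to zero and k clipped to the support.

import numpy as np

probs = np.array([0.25, 0.25, 0.25, 0.25])
cumulative = np.cumsum(probs, axis=-1)                   # [0.25, 0.5, 0.75, 1.0]

def toy_cdf(k):
  if k < 0:                                              # below the support
    return 0.0
  return float(cumulative[min(int(k), len(probs) - 1)])  # clip k to {0, ..., K-1}

print([toy_cdf(k) for k in (-1, 0, 2, 10)])              # [0.0, 0.25, 0.75, 1.0]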
Example #5
 def _multi_gamma_sequence(self, a, p, name="multi_gamma_sequence"):
     """Creates sequence used in multivariate (di)gamma; shape = shape(a)+[p]."""
     with self._name_and_control_scope(name):
         # Linspace only takes scalars, so we'll add in the offset afterwards.
         seq = tf.linspace(tf.constant(0., dtype=self.dtype), 0.5 - 0.5 * p,
                           tf.cast(p, tf.int32))
         return seq + tf.expand_dims(a, [-1])
Example #6
def reduce_logmeanexp(input_tensor, axis=None, keepdims=False, name=None):
  """Computes `log(mean(exp(input_tensor)))`.

  Reduces `input_tensor` along the dimensions given in `axis`.  Unless
  `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
  `axis`. If `keepdims` is true, the reduced dimensions are retained with length
  1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than `log(reduce_mean(exp(input)))`.
  It avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keepdims:  Boolean.  Whether to keep the axis as singleton dimensions.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'reduce_logmeanexp'`).

  Returns:
    log_mean_exp: The reduced tensor.
  """
  with tf.name_scope(name or 'reduce_logmeanexp'):
    lse = tf.reduce_logsumexp(input_tensor, axis=axis, keepdims=keepdims)
    n = prefer_static.size(input_tensor) // prefer_static.size(lse)
    log_n = tf.math.log(tf.cast(n, lse.dtype))
    return lse - log_n
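
A short usage sketch (hypothetical values): with large inputs the naive form overflows in float32, while the log-sum-exp form used above stays finite.

import tensorflow as tf

x = tf.constant([1000., 1001., 1002.])
naive = tf.math.log(tf.reduce_mean(tf.exp(x)))      # inf: exp(1000.) overflows
stable = tf.reduce_logsumexp(x) - tf.math.log(3.)   # same math as reduce_logmeanexp
print(naive.numpy(), stable.numpy())                # inf  ~1001.31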
Example #7
  def _mode(self, samples=None):
    # The number of samples can vary by batch member. Use map_fn to compute the
    # mode for each batch member separately.
    def _get_mode(samples):
      # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
      count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
      return tf.argmax(count)

    if samples is None:
      samples = tf.convert_to_tensor(self._samples)
    num_samples = self._compute_num_samples(samples)

    # Flatten samples for each batch.
    if self._event_ndims == 0:
      flattened_samples = tf.reshape(samples, [-1, num_samples])
      mode_shape = self._batch_shape_tensor(samples)
    else:
      event_size = tf.reduce_prod(self._event_shape_tensor(samples))
      mode_shape = tf.concat(
          [self._batch_shape_tensor(samples),
           self._event_shape_tensor(samples)],
          axis=0)
      flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

    indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
    full_indices = tf.stack(
        [tf.range(tf.shape(indices)[0]),
         tf.cast(indices, tf.int32)], axis=1)

    mode = tf.gather_nd(flattened_samples, full_indices)
    return tf.reshape(mode, mode_shape)
Example #8
    def _log_prob(self, x):
        logits = self._logits_parameter_no_checks()
        event_size = self._event_size(logits)

        x = tf.cast(x, logits.dtype)
        x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

        # broadcast logits or x if need be.
        if (not tensorshape_util.is_fully_defined(x.shape)
                or not tensorshape_util.is_fully_defined(logits.shape)
                or x.shape != logits.shape):
            broadcast_shape = tf.broadcast_dynamic_shape(
                tf.shape(logits), tf.shape(x))
            logits = tf.broadcast_to(logits, broadcast_shape)
            x = tf.broadcast_to(x, broadcast_shape)

        logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
        logits_2d = tf.reshape(logits, [-1, event_size])
        x_2d = tf.reshape(x, [-1, event_size])
        ret = -tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(x_2d), logits=logits_2d)

        # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
        ret = tf.reshape(ret, logits_shape)
        return ret
Example #9
    def _log_prob(self, x, power=None):
        # The log probability at positive integer points x is log(x^(-power) / Z)
        # where Z is the normalization constant. For x < 1 and non-integer points,
        # the log-probability is -inf.
        #
        # However, if interpolate_nondiscrete is True, we return the natural
        # continuous relaxation for x >= 1 which agrees with the log probability at
        # positive integer points.
        #
        # If interpolate_nondiscrete is False and validate_args is True, we check
        # that the sample point x is in the support. That is, x is equivalent to a
        # positive integer.
        power = power if power is not None else tf.convert_to_tensor(
            self.power)
        x = tf.cast(x, power.dtype)
        if self.validate_args and not self.interpolate_nondiscrete:
            x = distribution_util.embed_check_integer_casting_closed(
                x, target_dtype=self.dtype, assert_positive=True)
        log_normalization = tf.math.log(tf.math.zeta(power, 1.))

        safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x),
                            1.)
        y = -power * tf.math.log(safe_x)
        log_unnormalized_prob = tf.where(
            tf.equal(x, safe_x), y,
            dtype_util.as_numpy_dtype(y.dtype)(-np.inf))

        return log_unnormalized_prob - log_normalization
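
A quick numerical check of the formula in the comment above, log p(x) = -power * log(x) - log(zeta(power, 1)), using SciPy's Hurwitz zeta (hypothetical parameter values):

import numpy as np
from scipy.special import zeta
from scipy.stats import zipf

power, x = 2.5, 3
log_prob = -power * np.log(x) - np.log(zeta(power, 1.))
print(np.isclose(log_prob, zipf(power).logpmf(x)))   # True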
Example #10
 def _sample_n(self, n, seed=None):
     n_draws = tf.cast(self.total_count, dtype=tf.int32)
     logits = self._logits_parameter_no_checks()
     k = tf.compat.dimension_value(logits.shape[-1])
     if k is None:
         k = tf.shape(logits)[-1]
     return draw_sample(n, k, logits, n_draws, self.dtype, seed)
Example #11
 def _entropy(self):
   concentration = tf.convert_to_tensor(self.concentration)
   k = tf.cast(tf.shape(concentration)[-1], self.dtype)
   total_concentration = tf.reduce_sum(concentration, axis=-1)
   return (tf.math.lbeta(concentration) +
           ((total_concentration - k) * tf.math.digamma(total_concentration)) -
           tf.reduce_sum((concentration - 1.) * tf.math.digamma(concentration),
                         axis=-1))
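
The closed form above (log-beta plus digamma terms) can be sanity-checked against SciPy with hypothetical concentration values:

import numpy as np
from scipy.special import digamma, gammaln
from scipy.stats import dirichlet

conc = np.array([1.5, 2.0, 3.0])
k, total = conc.size, conc.sum()
lbeta = gammaln(conc).sum() - gammaln(total)            # same as tf.math.lbeta
entropy = (lbeta + (total - k) * digamma(total)
           - ((conc - 1.) * digamma(conc)).sum())
print(np.allclose(entropy, dirichlet(conc).entropy()))  # True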
Example #12
  def _log_prob(self, event):
    if self.validate_args:
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=tf.bool)

    log_probs0, log_probs1 = self._outcome_log_probs()
    event = tf.cast(event, log_probs0.dtype)
    return event * (log_probs1 - log_probs0) + log_probs0
Example #13
    def _entropy(self, **kwargs):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("entropy is not implemented")
        if not self.bijector._is_injective:  # pylint: disable=protected-access
            raise NotImplementedError("entropy is not implemented when "
                                      "bijector is not injective.")
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
        # can be shown that:
        #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
        # If is_constant_jacobian then:
        #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
        # where c can be anything.
        entropy = self.distribution.entropy(**distribution_kwargs)
        if self._is_maybe_event_override:
            # H[X] = sum_i H[X_i] if X_i are mutually independent.
            # This means that a reduce_sum is a simple rescaling.
            entropy = entropy * tf.cast(
                tf.reduce_prod(self._override_event_shape),
                dtype=dtype_util.base_dtype(entropy.dtype))
        if self._is_maybe_batch_override:
            new_shape = tf.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor()
            ], 0)
            entropy = tf.reshape(entropy, new_shape)
            multiples = tf.concat([
                self._override_batch_shape,
                prefer_static.ones_like(self.distribution.batch_shape_tensor())
            ], 0)
            entropy = tf.tile(entropy, multiples)
        dummy = prefer_static.zeros(shape=tf.concat(
            [self.batch_shape_tensor(),
             self.event_shape_tensor()], 0),
                                    dtype=self.dtype)
        event_ndims = (tensorshape_util.rank(self.event_shape)
                       if tensorshape_util.rank(self.event_shape) is not None
                       else tf.size(self.event_shape_tensor()))
        ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)

        entropy = entropy - tf.cast(ildj, entropy.dtype)
        tensorshape_util.set_shape(entropy, self.batch_shape)
        return entropy
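
For a constant-Jacobian bijector, the comment above reduces to H[Y] = H[X] + log|det J|. A minimal NumPy check with a hypothetical scalar affine transform Y = 3 * X of a standard normal:

import numpy as np

base_entropy = 0.5 * np.log(2. * np.pi * np.e)      # entropy of Normal(0, 1)
shifted = base_entropy + np.log(3.)                 # add the constant log|det J|
target = 0.5 * np.log(2. * np.pi * np.e * 3.**2)    # entropy of Normal(0, 3)
print(np.isclose(shifted, target))                  # True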
Example #14
        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag
Example #15
 def _forward_log_det_jacobian(self, x):
   # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
   # sections leading up to it, for context) in
   # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
   with tf.control_dependencies(self._assertions(x)):
     matrix_dim = tf.cast(tf.shape(x)[-1],
                          dtype_util.base_dtype(x.dtype))
     return -(matrix_dim + 1) * tf.reduce_sum(
         tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)
Example #16
 def _reshape_part(part, dtype, event_shape):
     part = tf.cast(part, dtype)
     static_rank = tf.get_static_value(
         ps.rank_from_shape(event_shape))
     if static_rank == 1:
         return part
     new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                           axis=-1)
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
Example #17
def normal_conjugates_known_scale_posterior(prior, scale, s, n):
    """Posterior Normal distribution with conjugate prior on the mean.

  This model assumes that `n` observations (with sum `s`) come from a
  Normal with unknown mean `loc` (described by the Normal `prior`)
  and known variance `scale**2`. The "known scale posterior" is
  the distribution of the unknown `loc`.

  Accepts a prior Normal distribution object, having parameters
  `loc0` and `scale0`, as well as known `scale` values of the predictive
  distribution(s) (also assumed Normal),
  and statistical estimates `s` (the sum(s) of the observations) and
  `n` (the number(s) of observations).

  Returns a posterior (also Normal) distribution object, with parameters
  `(loc', scale'**2)`, where:

  ```
  mu ~ N(mu', sigma'**2)
  sigma'**2 = 1/(1/sigma0**2 + n/sigma**2),
  mu' = (mu0/sigma0**2 + s/sigma**2) * sigma'**2.
  ```

  Distribution parameters from `prior`, as well as `scale`, `s`, and `n`, will
  broadcast in the case of multidimensional sets of parameters.

  Args:
    prior: `Normal` object of type `dtype`:
      the prior distribution having parameters `(loc0, scale0)`.
    scale: tensor of type `dtype`, taking values `scale > 0`.
      The known stddev parameter(s).
    s: Tensor of type `dtype`. The sum(s) of observations.
    n: Tensor of type `int`. The number(s) of observations.

  Returns:
    A new Normal posterior distribution object for the unknown observation
    mean `loc`.

  Raises:
    TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
      Normal object.
  """
    if not isinstance(prior, normal.Normal):
        raise TypeError("Expected prior to be an instance of type Normal")

    if s.dtype != prior.dtype:
        raise TypeError(
            "Observation sum s.dtype does not match prior dtype: %s vs. %s" %
            (s.dtype, prior.dtype))

    n = tf.cast(n, prior.dtype)
    scale0_2 = tf.square(prior.scale)
    scale_2 = tf.square(scale)
    scalep_2 = 1.0 / (1 / scale0_2 + n / scale_2)
    return normal.Normal(loc=(prior.loc / scale0_2 + s / scale_2) * scalep_2,
                         scale=tf.sqrt(scalep_2))
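
A worked instance of the update formulas in the docstring (hypothetical numbers), mirroring the computation above with plain NumPy:

import numpy as np

loc0, scale0 = 0., 1.            # prior N(loc0, scale0**2) on the unknown mean
scale = 2.                       # known observation stddev
obs = np.array([1.2, 0.8, 1.5, 0.9])
s, n = obs.sum(), obs.size

scalep_2 = 1. / (1. / scale0**2 + n / scale**2)
locp = (loc0 / scale0**2 + s / scale**2) * scalep_2
print(locp, np.sqrt(scalep_2))   # ~0.55  ~0.707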
Example #18
 def _mean(self, probs=None):
     if probs is None:
         probs = self._categorical.probs_parameter()
     outcomes = self.outcomes
     if dtype_util.is_integer(outcomes.dtype):
         if self._validate_args:
             outcomes = dist_util.embed_check_integer_casting_closed(
                 outcomes, target_dtype=probs.dtype)
         outcomes = tf.cast(outcomes, dtype=probs.dtype)
     return tf.tensordot(outcomes, probs, axes=[[0], [-1]])
Example #19
 def _inverse(self, y):
   # As specified in the Stan reference manual, the procedure is as follows:
   # N = y.shape[-1]
   # z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
   # x_k = logit(z_k) - log(1 / (N - k))
   offset = tf.math.log(
       tf.cast(
           tf.range(tf.shape(y)[-1] - 1, 0, delta=-1),
           dtype=dtype_util.base_dtype(y.dtype)))
   z = y / (1. - tf.math.cumsum(y, axis=-1, exclusive=True))
   return tf.math.log(z[..., :-1]) - tf.math.log1p(-z[..., :-1]) + offset
Example #20
 def _sample_one_batch_member(args):
     logits, num_cat_samples = args[0], args[1]  # [K], []
     # x has shape [1, num_cat_samples = num_samples * num_trials]
     x = tf.random.categorical(logits[tf.newaxis, ...],
                               num_cat_samples,
                               seed=seed)
     x = tf.reshape(x, shape=[num_samples,
                              -1])  # [num_samples, num_trials]
     x = tf.one_hot(
         x, depth=num_classes)  # [num_samples, num_trials, num_classes]
     x = tf.reduce_sum(x, axis=-2)  # [num_samples, num_classes]
     return tf.cast(x, dtype=dtype)
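
The one_hot / reduce_sum step above just counts how many categorical draws fell into each class. A tiny NumPy illustration with hypothetical draws (num_samples=2, num_trials=3, num_classes=4):

import numpy as np

x = np.array([[0, 2, 2],
              [1, 1, 3]])        # categorical draws, shape [num_samples, num_trials]
one_hot = np.eye(4)[x]           # shape [num_samples, num_trials, num_classes]
counts = one_hot.sum(axis=-2)    # shape [num_samples, num_classes]
print(counts)                    # [[1. 0. 2. 0.]  [0. 2. 0. 1.]]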
Example #21
 def _forward_log_det_jacobian(self, x):
     # is_constant_jacobian = True for this bijector, hence the
     # `log_det_jacobian` need only be specified for a single input, as this will
     # be tiled to match `event_ndims`.
     if self._is_only_identity_multiplier:
         # We don't pad in this case and instead let the fldj be applied
         # via broadcast.
         log_abs_diag = tf.math.log(tf.abs(self._scale))
         event_size = tf.shape(x)[-1]
         event_size = tf.cast(event_size, dtype=log_abs_diag.dtype)
         return log_abs_diag * event_size
     return self.scale.log_abs_determinant()
Example #22
 def _inverse_log_det_jacobian(self, y):
     # Could also do:
     #   ildj = tf.reduce_sum(y - tfp.math.softplus_inverse(y),
     #                              axis=event_dims)
     # but the following is more numerically stable. I.e.,
     # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1]
     # ==> dX/dY = exp{Y} / (exp{Y} - 1)
     #           = 1 / (1 - exp{-Y}),
     # which is the most stable for large Y > 0. For small Y, we use
     # 1 - exp{-Y} approx Y.
     if self.hinge_softness is not None:
         y = y / tf.cast(self.hinge_softness, y.dtype)
     return -tf.math.log(-tf.math.expm1(-y))
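
A quick numerical check of the identity in the comment above (hypothetical point): for Y = softplus(X), dX/dY = 1 / (1 - exp(-Y)), so the ILDJ equals -log(-expm1(-y)), which is the same as -log(sigmoid(x)).

import numpy as np

x = 0.7                          # arbitrary point
y = np.log1p(np.exp(x))          # y = softplus(x)
ildj = -np.log(-np.expm1(-y))    # formula used above
print(np.isclose(ildj, -np.log(1. / (1. + np.exp(-x)))))   # True: -log(sigmoid(x))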
Example #23
 def _sample_n(self, n, seed=None):
     logits = self._logits_parameter_no_checks()
     logits_2d = tf.reshape(logits, [-1, self._num_categories(logits)])
     sample_dtype = tf.int64 if dtype_util.size(
         self.dtype) > 4 else tf.int32
     draws = tf.random.categorical(logits_2d,
                                   n,
                                   dtype=sample_dtype,
                                   seed=seed)
     draws = tf.cast(draws, self.dtype)
     return tf.reshape(tf.transpose(draws),
                       shape=tf.concat(
                           [[n], self._batch_shape_tensor(logits)], axis=0))
Example #24
 def _variance(self):
     probs = self._categorical.probs_parameter()
     outcomes = tf.broadcast_to(self.outcomes,
                                shape=dist_util.prefer_static_shape(probs))
     if dtype_util.is_integer(outcomes.dtype):
         if self._validate_args:
             outcomes = dist_util.embed_check_integer_casting_closed(
                 outcomes, target_dtype=probs.dtype)
         outcomes = tf.cast(outcomes, dtype=probs.dtype)
     square_d = tf.math.squared_difference(
         outcomes,
         self._mean(probs)[..., tf.newaxis])
     return tf.reduce_sum(probs * square_d, axis=-1)
Example #25
 def _forward_log_det_jacobian(self, x):
     # This code is similar to tf.math.log_softmax but different because we have
     # an implicit zero column to handle. I.e., instead of:
     #   reduce_sum(logits - reduce_sum(exp(logits), dim))
     # we must do:
     #   log_normalization = 1 + reduce_sum(exp(logits))
     #   -log_normalization + reduce_sum(logits - log_normalization)
     n = prefer_static.shape(x)[-1]
     log_normalization = tf.math.softplus(
         tf.reduce_logsumexp(x, axis=-1, keepdims=True))
     return tf.squeeze(
         (-log_normalization +
          tf.reduce_sum(x - log_normalization, axis=-1, keepdims=True)),
         axis=-1) + 0.5 * tf.math.log(tf.cast(n + 1, dtype=x.dtype))
Example #26
    def _cdf(self, x):
        # CDF(x) at positive integer x is the probability that the Zipf variable is
        # less than or equal to x, given by the formula:
        #     CDF(x) = 1 - (zeta(power, x + 1) / Z)
        # For fractional x, the CDF is equal to the CDF at n = floor(x).
        # For x < 1, the CDF is zero.

        # If interpolate_nondiscrete is True, we return a continuous relaxation
        # which agrees with the CDF at integer points.
        power = tf.convert_to_tensor(self.power)
        x = tf.cast(x, power.dtype)
        safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x),
                            0.)

        cdf = 1. - (tf.math.zeta(power, safe_x + 1.) / tf.math.zeta(power, 1.))
        return tf.where(x < 1., tf.zeros_like(cdf), cdf)
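
The closed form above, CDF(x) = 1 - zeta(power, x + 1) / zeta(power, 1), can be checked against a direct sum of the pmf using SciPy's Hurwitz zeta (hypothetical values):

import numpy as np
from scipy.special import zeta

power, x = 2.5, 4
closed_form = 1. - zeta(power, x + 1.) / zeta(power, 1.)
direct = sum(k**-power for k in range(1, x + 1)) / zeta(power, 1.)
print(np.isclose(closed_form, direct))   # True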
Example #27
 def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims,
                                **distribution_kwargs):
     """Finish computation of prob on one element of the inverse image."""
     x = self._maybe_rotate_dims(x, rotate_right=True)
     prob = self.distribution.prob(x, **distribution_kwargs)
     if self._is_maybe_event_override:
         prob = tf.reduce_prod(prob, axis=self._reduce_event_indices)
     prob = prob * tf.exp(tf.cast(ildj, prob.dtype))
     if self._is_maybe_event_override and isinstance(event_ndims, int):
         tensorshape_util.set_shape(
             prob,
             tf.broadcast_static_shape(
                 tensorshape_util.with_rank_at_least(y.shape,
                                                     1)[:-event_ndims],
                 self.batch_shape))
     return prob
Example #28
 def _forward(self, x):
   # As specified in the Stan reference manual, the procedure is as follows:
   # N = x.shape[-1] + 1
   # z_k = sigmoid(x_k + log(1 / (N - k)))
   # y_1 = z_1
   # y_k = (1 - sum_{i=1 to k-1} y_i) * z_k
   # y_N = 1 - sum_{i=1 to N-1} y_i
   # TODO(b/128857065): The numerics can possibly be improved here with a
   # log-space computation.
   offset = -tf.math.log(
       tf.cast(
           tf.range(tf.shape(x)[-1], 0, delta=-1),
           dtype=dtype_util.base_dtype(x.dtype)))
   z = tf.math.sigmoid(x + offset)
   y = z * tf.math.cumprod(1 - z, axis=-1, exclusive=True)
   return tf.concat([y, 1. - tf.reduce_sum(y, axis=-1, keepdims=True)],
                    axis=-1)
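
A NumPy sketch of the stick-breaking recursion spelled out in the comment above (hypothetical input); the result lies on the simplex, and an all-zeros input maps to the uniform vector:

import numpy as np

def stick_breaking_forward(x):
  n = x.shape[-1]
  offset = -np.log(np.arange(n, 0, -1))                      # log(1 / (N - k)), N = n + 1
  z = 1. / (1. + np.exp(-(x + offset)))                      # sigmoid(x_k + offset_k)
  y = z * np.cumprod(np.concatenate([[1.], 1. - z[:-1]]))    # exclusive cumprod of (1 - z)
  return np.concatenate([y, [1. - y.sum()]])

print(stick_breaking_forward(np.zeros(3)))   # ~[0.25 0.25 0.25 0.25]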
Example #29
 def _log_normalization(self):
     """Computes the log-normalizer of the distribution."""
     event_dim = tf.compat.dimension_value(self.event_shape[0])
     if event_dim is None:
         raise ValueError('vMF _log_normalizer currently only supports '
                          'statically known event shape')
     safe_conc = tf.where(self.concentration > 0, self.concentration,
                          tf.ones_like(self.concentration))
     safe_lognorm = ((event_dim / 2 - 1) * tf.math.log(safe_conc) -
                     (event_dim / 2) * np.log(2 * np.pi) - tf.math.log(
                         _bessel_ive(event_dim / 2 - 1, safe_conc)) -
                     tf.abs(safe_conc))
     log_nsphere_surface_area = (
         np.log(2.) + (event_dim / 2) * np.log(np.pi) -
         tf.math.lgamma(tf.cast(event_dim / 2, self.dtype)))
     return tf.where(self.concentration > 0, -safe_lognorm,
                     log_nsphere_surface_area)
Example #30
 def _prob(self, x):
     if self.validate_args:
         is_vector_check = assert_util.assert_rank_at_least(x, 1)
         right_vec_space_check = assert_util.assert_equal(
             self.event_shape_tensor(),
             tf.gather(tf.shape(x), tf.rank(x) - 1),
             message="Argument 'x' not defined in the same space R^k as this "
                     "distribution")
         with tf.control_dependencies([is_vector_check, right_vec_space_check]):
             x = tf.identity(x)
     loc = tf.convert_to_tensor(self.loc)
     return tf.cast(tf.reduce_all(tf.abs(x - loc) <= self._slack(loc),
                                  axis=-1),
                    dtype=self.dtype)