Example #1
def log1psquare(x, name=None):
    """Numerically stable calculation of `log(1 + x**2)` for small or large `|x|`.

  For sufficiently large `x` we use the following observation:

  ```none
  log(1 + x**2) =   2 log(|x|) + log(1 + 1 / x**2)
                --> 2 log(|x|)  as x --> inf
  ```

  Numerically, `log(1 + 1 / x**2)` is `0` when `1 / x**2` is small relative to
  machine epsilon.

  Args:
    x: Float `Tensor` input.
    name: Python string indicating the name of the TensorFlow operation.
      Default value: `'log1psquare'`.

  Returns:
    log1psq: Float `Tensor` representing `log(1. + x**2.)`.
  """
    with tf.name_scope(name or 'log1psquare'):
        x = tf.convert_to_tensor(x, dtype_hint=tf.float32, name='x')
        dtype = dtype_util.as_numpy_dtype(x.dtype)

        eps = np.finfo(dtype).eps.astype(np.float64)
        is_large = tf.abs(x) > (eps**-0.5).astype(dtype)

        # Mask out small x's so the gradient correctly propagates.
        abs_large_x = tf.where(is_large, tf.abs(x), tf.ones([], x.dtype))
        return tf.where(is_large, 2. * tf.math.log(abs_large_x),
                        tf.math.log1p(tf.square(x)))
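For reference, a minimal standalone sketch (assuming only TensorFlow, with `float32` inputs; not part of the library code) contrasting the naive computation with the large-`|x|` branch used above:

```python
import tensorflow as tf

x = tf.constant([1e-3, 1., 1e5, 1e20], dtype=tf.float32)

# Naive form: x**2 overflows float32 (max ~3.4e38) at the last entry, so the
# result is inf there.
naive = tf.math.log1p(tf.square(x))

# Large-|x| branch from the function above: log(1 + x**2) ~= 2 * log(|x|),
# which stays finite (~92.1 at x = 1e20).
stable_for_large = 2. * tf.math.log(tf.abs(x))
```

The `tf.where` in `log1psquare` stitches the two branches together, masking the small entries first so the unselected `log` branch cannot contribute `nan` to the gradient.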
Example #2
    def _survival_function(self, y):
        low = self._low
        high = self._high

        # Recall the promise:
        # survival_function(y) := P[Y > y]
        #                       = 0, if y >= high,
        #                       = 1, if y < low,
        #                       = P[X > y], otherwise.

        # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in
        # between.
        j = tf.math.ceil(y)

        # P[X > j], used when low < X < high.
        result_so_far = self.distribution.survival_function(j)

        # Re-define values at the cutoffs.
        if low is not None:
            result_so_far = tf.where(j < low, tf.ones_like(result_so_far),
                                     result_so_far)
        if high is not None:
            result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                                     result_so_far)

        return result_so_far
Example #3
  def _extend_support(self, x, scale, f, alt):
    """Returns `f(x)` if x is in the support, and `alt` otherwise.

    Given `f` which is defined on the support of this distribution
    (e.g. x > scale), extend the function definition to the real line
    by defining `f(x) = alt` for `x < scale`.

    Args:
      x: Floating-point Tensor to evaluate `f` at.
      scale: Floating-point Tensor by which to verify `x` validity.
      f: Lambda that takes in a tensor and returns a tensor. This represents
        the function whose domain of definition we want to extend.
      alt: Python or numpy literal representing the value to use for extending
        the domain.

    Returns:
      Tensor representing an extension of `f(x)`.
    """
    if self.validate_args:
      return f(x)
    scale = tf.convert_to_tensor(self.scale) if scale is None else scale
    is_invalid = x < scale
    # We need to do this to ensure gradients are sound.
    y = f(tf.where(is_invalid, scale, x))
    if alt == 0.:
      alt = tf.zeros([], dtype=y.dtype)
    elif alt == 1.:
      alt = tf.ones([], dtype=y.dtype)
    else:
      alt = dtype_util.as_numpy_dtype(self.dtype)(alt)
    return tf.where(is_invalid, alt, y)
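Both `log1psquare` above and `_extend_support` rely on the same "double `tf.where`" pattern: sanitize the inputs before applying the risky function, so the unselected branch never produces `inf`/`nan`, which would otherwise turn into `nan` gradients (the gradient of `tf.where` multiplies the unselected branch's gradient by zero, and `0 * inf = nan`). A minimal sketch of the pattern with `tf.sqrt` standing in for `f` (the function and threshold here are illustrative, not from the library):

```python
import tensorflow as tf

x = tf.constant([-4., 0.25, 9.])

with tf.GradientTape() as tape:
  tape.watch(x)
  # Problematic: tf.sqrt is still evaluated at -4.; its nan gradient leaks
  # through the unselected branch of tf.where.
  bad = tf.where(x >= 0., tf.sqrt(x), tf.zeros_like(x))
grad_bad = tape.gradient(bad, x)    # nan at the x = -4. entry

with tf.GradientTape() as tape:
  tape.watch(x)
  safe_x = tf.where(x >= 0., x, tf.ones_like(x))   # mask out invalid inputs
  good = tf.where(x >= 0., tf.sqrt(safe_x), tf.zeros_like(x))
grad_good = tape.gradient(good, x)  # finite everywhere, 0 at the invalid entry
```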
Example #4
    def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                        df_factor_fn):
        """Helper to compute stddev, covariance and variance."""
        df = tf.reshape(
            self.df,
            tf.concat([
                tf.shape(self.df),
                tf.ones([statistic_ndims], dtype=tf.int32)
            ], -1))
        # We need to put the tf.where inside the outer tf.where to ensure we never
        # hit a NaN in the gradient.
        denom = tf.where(df > 2., df - 2., tf.ones_like(df))
        statistic = statistic * df_factor_fn(df / denom)
        # When 1 < df <= 2, stddev/variance are infinite.
        result_where_defined = tf.where(
            df > 2., statistic,
            dtype_util.as_numpy_dtype(self.dtype)(np.inf))

        if self.allow_nan_stats:
            return tf.where(df > 1., result_where_defined,
                            dtype_util.as_numpy_dtype(self.dtype)(np.nan))
        else:
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.cast(1., self.dtype),
                        df,
                        message='{} not defined for components of df <= 1.'.
                        format(statistic_name.capitalize())),
            ]):
                return tf.identity(result_where_defined)
Example #5
    def _cdf(self, y):
        low = self._low
        high = self._high

        # Recall the promise:
        # cdf(y) := P[Y <= y]
        #         = 1, if y >= high,
        #         = 0, if y < low,
        #         = P[X <= y], otherwise.

        # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
        # between.
        j = tf.floor(y)

        # P[X <= j], used when low < X < high.
        result_so_far = self.distribution.cdf(j)

        # Re-define values at the cutoffs.
        if low is not None:
            result_so_far = tf.where(j < low, tf.zeros_like(result_so_far),
                                     result_so_far)
        if high is not None:
            result_so_far = tf.where(j >= high, tf.ones_like(result_so_far),
                                     result_so_far)

        return result_so_far
Example #6
    def _log_cdf(self, y):
        low = self._low
        high = self._high

        # Recall the promise:
        # cdf(y) := P[Y <= y]
        #         = 1, if y >= high,
        #         = 0, if y < low,
        #         = P[X <= y], otherwise.

        # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
        # between.
        j = tf.floor(y)

        result_so_far = self.distribution.log_cdf(j)

        # Re-define values at the cutoffs.
        if low is not None:
            result_so_far = tf.where(
                j < low,
                dtype_util.as_numpy_dtype(self.dtype)(-np.inf), result_so_far)
        if high is not None:
            result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                                     result_so_far)

        return result_so_far
Example #7
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = tf.shape(input=param)
  insert_ones = tf.ones(
      [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = tf.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims for
  # negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = tf.where(is_broadcast, 0, start)
      if stop is not None:
        stop = tf.where(is_broadcast, 1, stop)
      if step is not None:
        step = tf.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(tf.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(param_slices)
Example #8
def _ndtr(x):
    """Implements ndtr core logic."""
    half_sqrt_2 = tf.constant(0.5 * np.sqrt(2.),
                              dtype=x.dtype,
                              name="half_sqrt_2")
    w = x * half_sqrt_2
    z = tf.abs(w)
    y = tf.where(z < half_sqrt_2, 1. + tf.math.erf(w),
                 tf.where(w > 0., 2. - tf.math.erfc(z), tf.math.erfc(z)))
    return 0.5 * y
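`_ndtr` evaluates the standard normal CDF, `0.5 * (1 + erf(x / sqrt(2)))`, switching between `erf` and `erfc` so that neither branch suffers cancellation. A quick sanity check (a sketch that reuses the `_ndtr` defined above and assumes NumPy and TensorFlow are imported):

```python
import numpy as np
import tensorflow as tf

x = tf.constant([-8., -1., 0., 1., 8.], dtype=tf.float64)

phi = _ndtr(x)                                      # the function defined above

# Direct erf-based formula: fine near 0, but the far left tail (~6.2e-16 at
# x = -8) is computed as 0.5 * (1 + erf(...)), which cancels and loses
# relative precision.
phi_naive = 0.5 * (1. + tf.math.erf(x / np.sqrt(2.)))
```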
Example #9
 def _cdf(self, x):
     low = tf.convert_to_tensor(self.low)
     high = tf.convert_to_tensor(self.high)
     broadcast_shape = tf.broadcast_dynamic_shape(
         tf.shape(x), self._batch_shape_tensor(low=low, high=high))
     zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
     ones = tf.ones(broadcast_shape, dtype=self.dtype)
     result_if_not_big = tf.where(x < low, zeros, (x - low) /
                                  self._range(low=low, high=high))
     return tf.where(x >= high, ones, result_if_not_big)
Example #10
 def _prob(self, x):
     low = tf.convert_to_tensor(self.low)
     high = tf.convert_to_tensor(self.high)
     return tf.where(
         tf.math.is_nan(x),
         x,
         tf.where(
             # This > is only sound for continuous uniform
             (x < low) | (x > high),
             tf.zeros_like(x),
             tf.ones_like(x) / self._range(low=low, high=high)))
Example #11
 def _log_prob(self, x):
     scale = tf.convert_to_tensor(self.scale)
     concentration = tf.convert_to_tensor(self.concentration)
     z = self._z(x, scale, concentration)
     eq_zero = tf.equal(concentration,
                        0)  # Concentration = 0 ==> Exponential.
     nonzero_conc = tf.where(eq_zero, tf.constant(1, self.dtype),
                             concentration)
     where_nonzero = (1 / nonzero_conc + 1) * tf.math.log1p(
         nonzero_conc * z)
     return -tf.math.log(scale) - tf.where(eq_zero, z, where_nonzero)
Example #12
 def _log_cdf(self, x):
     scale = tf.convert_to_tensor(self.scale)
     concentration = tf.convert_to_tensor(self.concentration)
     z = self._z(x, scale, concentration)
     eq_zero = tf.equal(concentration,
                        0)  # Concentration = 0 ==> Exponential.
     nonzero_conc = tf.where(eq_zero, tf.constant(1, self.dtype),
                             concentration)
     where_nonzero = tf.math.log1p(-(1 + nonzero_conc * z)**(-1 /
                                                             nonzero_conc))
     where_zero = tf.math.log1p(-tf.exp(-z))
     return tf.where(eq_zero, where_zero, where_nonzero)
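As a sanity check on the two branches: as the concentration shrinks, `(1 + concentration * z) ** (-1 / concentration)` tends to `exp(-z)`, so the nonzero-concentration expression converges to the exponential (zero-concentration) branch. A small standalone NumPy sketch, independent of the class above:

```python
import numpy as np

z = 1.7
zero_branch = np.log1p(-np.exp(-z))
for c in [1e-1, 1e-3, 1e-6]:
  nonzero_branch = np.log1p(-(1. + c * z) ** (-1. / c))
  print(c, nonzero_branch, zero_branch)   # the two branches agree as c -> 0
```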
Example #13
def softplus_inverse(x, name=None):
    """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
    with tf.name_scope(name or 'softplus_inverse'):
        x = tf.convert_to_tensor(x, name='x')
        # We begin by deriving a more numerically stable softplus_inverse:
        # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
        # ==> exp{x} = 1 + exp{y}                                (1)
        # ==> y = Log[exp{x} - 1]                                (2)
        #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
        #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
        #       = Log[1 - exp{-x}] + x                           (3)
        # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
        # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
        # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
        #
        # In addition to the numerically stable derivation above, we clamp
        # small/large values to be congruent with the logic in:
        # tensorflow/core/kernels/softplus_op.h
        #
        # Finally, we set the input to one whenever the input is too large or too
        # small. This ensures that no unchosen codepath is +/- inf. This is
        # necessary to ensure the gradient doesn't get NaNs. Recall that the
        # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
        # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
        # to overwrite `x` with ones only when we will never actually use this
        # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
        threshold = np.log(np.finfo(dtype_util.as_numpy_dtype(
            x.dtype)).eps) + 2.
        is_too_small = x < np.exp(threshold)
        is_too_large = x > -threshold
        too_small_value = tf.math.log(x)
        too_large_value = x
        # This `where` will ultimately be a NOP because we won't select this
        # codepath whenever we used the surrogate `ones_like`.
        x = tf.where(is_too_small | is_too_large, tf.ones([], x.dtype), x)
        y = x + tf.math.log(-tf.math.expm1(-x))  # == log(expm1(x))
        return tf.where(is_too_small, too_small_value,
                        tf.where(is_too_large, too_large_value, y))
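A round-trip sketch of the identity in the docstring, using only TensorFlow: the naive `log(exp(x) - 1)` overflows for large inputs, while form (3) from the derivation stays finite (a standalone illustration, not the library function above):

```python
import tensorflow as tf

y = tf.constant([0.1, 1., 30., 100.], dtype=tf.float32)
x = tf.math.softplus(y)                        # forward map

naive = tf.math.log(tf.math.expm1(x))          # inf once exp(x) exceeds float32 max
stable = x + tf.math.log(-tf.math.expm1(-x))   # form (3): log(1 - exp(-x)) + x

# `stable` recovers y (up to float32 rounding) at every entry; `naive` is inf
# at y = 100.
```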
Example #14
 def _mean(self):
   # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
   event_dim = tf.compat.dimension_value(self.event_shape[0])
   if event_dim is None:
     raise ValueError('event shape must be statically known for _bessel_ive')
   safe_conc = tf.where(self.concentration > 0, self.concentration,
                        tf.ones_like(self.concentration))
   safe_mean = self.mean_direction * (
       _bessel_ive(event_dim / 2, safe_conc) /
       _bessel_ive(event_dim / 2 - 1, safe_conc))[..., tf.newaxis]
   return tf.where(
       self.concentration[..., tf.newaxis] > tf.zeros_like(safe_mean),
       safe_mean, tf.zeros_like(safe_mean))
Example #15
  def _log_prob(self, x, power=None):
    # The log probability at positive integer points x is log(x^(-power) / Z)
    # where Z is the normalization constant. For x < 1 and non-integer points,
    # the log-probability is -inf.
    #
    # However, if interpolate_nondiscrete is True, we return the natural
    # continuous relaxation for x >= 1 which agrees with the log probability at
    # positive integer points.
    #
    # If interpolate_nondiscrete is False and validate_args is True, we check
    # that the sample point x is in the support. That is, x is equivalent to a
    # positive integer.
    power = power if power is not None else tf.convert_to_tensor(self.power)
    x = tf.cast(x, power.dtype)
    if self.validate_args and not self.interpolate_nondiscrete:
      x = distribution_util.embed_check_integer_casting_closed(
          x, target_dtype=self.dtype, assert_positive=True)
    log_normalization = tf.math.log(tf.math.zeta(power, 1.))

    safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 1.)
    y = -power * tf.math.log(safe_x)
    log_unnormalized_prob = tf.where(
        tf.equal(x, safe_x), y, dtype_util.as_numpy_dtype(y.dtype)(-np.inf))

    return log_unnormalized_prob - log_normalization
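The normalization constant here is the Riemann zeta function, `zeta(power, 1)`. A quick check (TensorFlow only) that the resulting pmf sums to one over the positive integers, up to a tiny truncated tail:

```python
import tensorflow as tf

power = tf.constant(2.5, dtype=tf.float64)
k = tf.range(1., 20001., dtype=tf.float64)      # truncate the support at 20000

log_normalization = tf.math.log(tf.math.zeta(power, 1.))
log_probs = -power * tf.math.log(k) - log_normalization

total = tf.reduce_sum(tf.exp(log_probs))        # ~1.0 (tail beyond 20000 is ~2e-7)
```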
Example #16
 def _mean(self):
     concentration = tf.convert_to_tensor(self.concentration)
     lim = tf.ones([], dtype=self.dtype)
     valid = concentration < lim
     safe_conc = tf.where(valid, concentration, tf.constant(.5, self.dtype))
     result = lambda: self.loc + self.scale / (1 - safe_conc)
     if self.allow_nan_stats:
         return tf.where(valid, result(),
                         tf.constant(float('nan'), self.dtype))
     with tf.control_dependencies([
             assert_util.assert_less(
                 concentration,
                 lim,
                 message='`mean` is undefined when `concentration >= 1`')
     ]):
         return result()
Example #17
    def loop_body(should_continue, k):
      """Resample the non-accepted points."""
      # The range of U is chosen so that the resulting sample K lies in
      # [0, tf.int64.max). The final sample, if accepted, is K + 1.
      u = tf.random.uniform(
          shape,
          minval=minval_u,
          maxval=maxval_u,
          dtype=power.dtype,
          seed=seed())

      # Sample the point X from the continuous density h(x) \propto x^(-power).
      x = self._hat_integral_inverse(u, power=power)

      # Rejection-inversion requires a `hat` function, h(x) such that
      # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
      # support. A natural hat function for us is h(x) = x^(-power).
      #
      # After sampling X from h(x), suppose it lies in the interval
      # (K - .5, K + .5) for integer K. Then the corresponding K is accepted
      # if it lies to the left of x_K, where x_K is defined by:
      #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
      # where H(x) = \int_x^inf h(x) dx.

      # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
      # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
      # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

      # Update the non-accepted points.
      # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
      k = tf.where(should_continue, tf.floor(x + 0.5), k)
      accept = (u <= self._hat_integral(k + .5, power=power) + tf.exp(
          self._log_prob(k + 1, power=power)))

      return [should_continue & (~accept), k]
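The helpers `_hat_integral` and `_hat_integral_inverse` are not shown here. For the simple hat `h(x) = x**(-power)` described in the comments, one plausible closed form (an illustrative sketch under that assumption, with `power > 1`; the library's actual parameterization may differ) is `H(x) = x**(1 - power) / (power - 1)` with inverse `H_inverse(u) = ((power - 1) * u)**(1 / (1 - power))`:

```python
import numpy as np

def hat_integral(x, power):
  # H(x) = int_x^inf t**(-power) dt, valid for power > 1.
  return x ** (1. - power) / (power - 1.)

def hat_integral_inverse(u, power):
  # Solve u = H(x) for x.
  return ((power - 1.) * u) ** (1. / (1. - power))

power = 2.5
x = 3.7
print(hat_integral_inverse(hat_integral(x, power), power))   # ~3.7, round-trips
```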
Example #18
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
    batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(batch_shape_static):
        return np.int32(batch_shape_static), batch_shape_static, []
    with tf.name_scope(name or 'calculate_reshape'):
        original_size = tf.reduce_prod(original_shape)
        implicit_dim = tf.equal(new_shape, -1)
        size_implicit_dim = (original_size //
                             tf.maximum(1, -tf.reduce_prod(new_shape)))
        expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
            implicit_dim, size_implicit_dim, new_shape)
        validations = [] if not validate else [  # pylint: disable=g-long-ternary
            assert_util.assert_rank(
                original_shape, 1, message='Original shape must be a vector.'),
            assert_util.assert_rank(
                new_shape, 1, message='New shape must be a vector.'),
            assert_util.assert_less_equal(
                tf.math.count_nonzero(implicit_dim, dtype=tf.int32),
                1,
                message='At most one dimension can be unknown.'),
            assert_util.assert_positive(
                expanded_new_shape, message='Shape elements must be >=-1.'),
            assert_util.assert_equal(tf.reduce_prod(expanded_new_shape),
                                     original_size,
                                     message='Shape sizes do not match.'),
        ]
        return expanded_new_shape, batch_shape_static, validations
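The heart of the computation is resolving the single `-1`: the other entries of `new_shape` are positive, so `-tf.reduce_prod(new_shape)` is their product, and integer division of the original size by it gives the implied dimension. A standalone sketch of just that step, without the assertions (TensorFlow only; the shapes are illustrative):

```python
import tensorflow as tf

original_shape = tf.constant([2, 3, 4], dtype=tf.int32)   # 24 elements total
new_shape = tf.constant([6, -1], dtype=tf.int32)

original_size = tf.reduce_prod(original_shape)            # 24
implicit_dim = tf.equal(new_shape, -1)                    # [False, True]
size_implicit_dim = original_size // tf.maximum(1, -tf.reduce_prod(new_shape))  # 24 // 6 = 4
expanded_new_shape = tf.where(implicit_dim, size_implicit_dim, new_shape)       # [6, 4]
```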
Example #19
    def _cdf(self, k):
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # Since the lowest number in the support is 0, any k < 0 should be zero in
        # the output.
        should_be_zero = k < 0

        # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
        k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

        batch_shape = tf.shape(k)

        # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
        # to handle the case where the batch shape is dynamic, flatten the batch
        # dims (so we know batch_dims=1).
        k_flat_batch = tf.reshape(k, [-1])
        probs_flat_batch = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))

        cdf_flat = tf.gather(tf.cumsum(probs_flat_batch, axis=-1),
                             k_flat_batch[..., tf.newaxis],
                             batch_dims=1)

        cdf = tf.reshape(cdf_flat, shape=batch_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(should_be_zero, zero, cdf)
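The `tf.gather(..., batch_dims=1)` call picks, for each flattened batch element, the cumulative probability at that element's own index `k`. A toy sketch of that step (TensorFlow only; the shapes and values are illustrative):

```python
import tensorflow as tf

probs = tf.constant([[0.1, 0.2, 0.7],
                     [0.5, 0.25, 0.25]])   # [batch = 2, num_categories = 3]
k = tf.constant([1, 2])                    # one category index per batch element

cumulative = tf.cumsum(probs, axis=-1)     # [[0.1, 0.3, 1.0], [0.5, 0.75, 1.0]]
cdf = tf.gather(cumulative, k[..., tf.newaxis], batch_dims=1)   # [[0.3], [1.0]]
```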
Example #20
def _swap_m_with_i(vecs, m, i):
    """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly shaped
  per-vector indices `i`, this function swaps elements `m` and `i` in each
  vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
    vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
    m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
    i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
    trailing_elts = tf.broadcast_to(
        tf.range(m + 1,
                 prefer_static.shape(vecs, out_type=tf.int64)[-1]),
        prefer_static.shape(vecs[..., m + 1:]))
    trailing_elts = tf.where(tf.equal(trailing_elts, i),
                             tf.gather(vecs, [m], axis=-1), vecs[..., m + 1:])
    # TODO(bjp): Could we use tensor_scatter_nd_update?
    vecs_shape = vecs.shape
    vecs = tf.concat([
        vecs[..., :m],
        tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
        trailing_elts
    ],
                     axis=-1)
    tensorshape_util.set_shape(vecs, vecs_shape)
    return vecs
Example #21
 def _variance(self):
   concentration = tf.convert_to_tensor(self.concentration)
   valid_variance = (self.scale**2 * concentration /
                     ((concentration - 1.)**2 * (concentration - 2.)))
   return tf.where(concentration > 2.,
                   valid_variance,
                   dtype_util.as_numpy_dtype(self.dtype)(np.inf))
Example #22
    def _sample_n(self, n, seed=None):
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        stream = SeedStream(seed, salt='triangular')
        shape = tf.concat(
            [[n], self._batch_shape_tensor(low=low, high=high, peak=peak)],
            axis=0)
        samples = tf.random.uniform(shape=shape,
                                    dtype=self.dtype,
                                    seed=stream())
        # We use Inverse CDF sampling here. Because the CDF is a quadratic function,
        # we must use sqrts here.
        interval_length = high - low
        return tf.where(
            # Note the CDF on the left side of the peak is
            # (x - low) ** 2 / ((high - low) * (peak - low)).
            # If we plug in peak for x, we get that the CDF at the peak
            # is (peak - low) / (high - low). Because of this we decide
            # which part of the piecewise CDF we should use based on the cdf samples
            # we drew.
            samples < (peak - low) / interval_length,
            # Inverse of (x - low) ** 2 / ((high - low) * (peak - low)).
            low + tf.sqrt(samples * interval_length * (peak - low)),
            # Inverse of 1 - (high - x) ** 2 / ((high - low) * (high - peak))
            high - tf.sqrt((1. - samples) * interval_length * (high - peak)))
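The same inverse-CDF transform written out in NumPy for a single `(low, peak, high)` triple, as an illustrative sketch rather than the library code:

```python
import numpy as np

low, peak, high = 0., 2., 5.
u = np.random.uniform(size=100_000)
interval_length = high - low

samples = np.where(
    u < (peak - low) / interval_length,
    low + np.sqrt(u * interval_length * (peak - low)),            # left of the peak
    high - np.sqrt((1. - u) * interval_length * (high - peak)))   # right of the peak

# Sample mean should be close to (low + peak + high) / 3 ~= 2.33.
print(samples.mean())
```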
Example #23
def _kl_pareto_pareto(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Pareto.

  Args:
    a: instance of a Pareto distribution object.
    b: instance of a Pareto distribution object.
    name: (optional) Name to use for created operations.
      default is 'kl_pareto_pareto'.

  Returns:
    Batchwise KL(a || b)
  """
  with tf.name_scope(name or 'kl_pareto_pareto'):
    # Consistent with
    # http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 55
    # Terminology is different from source to source for Pareto distributions.
    # The 'concentration' parameter corresponds to 'a' in that source, and the
    # 'scale' parameter corresponds to 'm'.
    a_scale = tf.convert_to_tensor(a.scale)
    b_scale = tf.convert_to_tensor(b.scale)
    a_concentration = tf.convert_to_tensor(a.concentration)
    b_concentration = tf.convert_to_tensor(b.concentration)
    return tf.where(
        a_scale >= b_scale,
        (b_concentration * (tf.math.log(a_scale) - tf.math.log(b_scale)) +
         tf.math.log(a_concentration) - tf.math.log(b_concentration) +
         b_concentration / a_concentration - 1.),
        dtype_util.as_numpy_dtype(a.dtype)(np.inf))
Example #24
  def _inverse_log_det_jacobian(self, y, use_saved_statistics=False):
    if not self.batchnorm.built:
      # Create variables.
      self.batchnorm.build(y.shape)

    event_dims = self.batchnorm.axis
    reduction_axes = [i for i in range(len(y.shape)) if i not in event_dims]

    # At training-time, ildj is computed from the mean and log-variance across
    # the current minibatch.
    # We use multiplication instead of tf.where() to get easier broadcasting.
    log_variance = tf.math.log(
        tf.where(
            tf.logical_or(use_saved_statistics, tf.logical_not(self._training)),
            self.batchnorm.moving_variance,
            tf.nn.moments(x=y, axes=reduction_axes, keepdims=True)[1]) +
        self.batchnorm.epsilon)

    # TODO(b/137216713): determine whether it's unsafe for the reduce_sums below
    # to happen across all axes.
    # `gamma` and `log Var(y)` reductions over event_dims.
    # Log(total change in area from gamma term).
    log_total_gamma = tf.reduce_sum(tf.math.log(self.batchnorm.gamma))

    # Log(total change in area from log-variance term).
    log_total_variance = tf.reduce_sum(log_variance)
    # The ildj is scalar, as it does not depend on the values of x and is
    # constant across minibatch elements.
    return log_total_gamma - 0.5 * log_total_variance
Example #25
def _kl_uniform_uniform(a, b, name=None):
    """Calculate the batched KL divergence KL(a || b) with a and b Uniform.

  Note that the KL divergence is infinite if the support of `a` is not a subset
  of the support of `b`.

  Args:
    a: instance of a Uniform distribution object.
    b: instance of a Uniform distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_uniform_uniform".

  Returns:
    Batchwise KL(a || b)
  """
    with tf.name_scope(name or 'kl_uniform_uniform'):
        # Consistent with
        # http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 60
        # Watch out for the change in conventions--they use 'a' and 'b' to refer to
        # lower and upper bounds respectively there.
        dtype = dtype_util.common_dtype([a.low, a.high, b.low, b.high],
                                        tf.float32)
        a_low = tf.convert_to_tensor(a.low)
        b_low = tf.convert_to_tensor(b.low)
        a_high = tf.convert_to_tensor(a.high)
        b_high = tf.convert_to_tensor(b.high)
        return tf.where(
            (b_low <= a_low) & (a_high <= b_high),
            tf.math.log(b_high - b_low) - tf.math.log(a_high - a_low),
            dtype_util.as_numpy_dtype(dtype)(np.inf))
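Because both densities are constant on the support of `a`, the divergence reduces to the log ratio of interval lengths whenever `a`'s support is contained in `b`'s, and is infinite otherwise. A plain NumPy mirror of the same piecewise definition, for concreteness (not the library API):

```python
import numpy as np

def kl_uniform_uniform(a_low, a_high, b_low, b_high):
  # Finite only when [a_low, a_high] is contained in [b_low, b_high].
  if b_low <= a_low and a_high <= b_high:
    return np.log(b_high - b_low) - np.log(a_high - a_low)
  return np.inf

print(kl_uniform_uniform(1., 2., 0., 4.))   # log(4 / 1) ~= 1.386
print(kl_uniform_uniform(1., 5., 0., 4.))   # inf: a puts mass where b has none
```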
Example #26
 def _log_normalization(self):
   """Computes the log-normalizer of the distribution."""
   event_dim = tf.compat.dimension_value(self.event_shape[0])
   if event_dim is None:
     raise ValueError('vMF _log_normalizer currently only supports '
                      'statically known event shape')
   safe_conc = tf.where(self.concentration > 0, self.concentration,
                        tf.ones_like(self.concentration))
   safe_lognorm = ((event_dim / 2 - 1) * tf.math.log(safe_conc) -
                   (event_dim / 2) * np.log(2 * np.pi) -
                   tf.math.log(_bessel_ive(event_dim / 2 - 1, safe_conc)) -
                   tf.abs(safe_conc))
   log_nsphere_surface_area = (
       np.log(2.) + (event_dim / 2) * np.log(np.pi) -
       tf.math.lgamma(tf.cast(event_dim / 2, self.dtype)))
   return tf.where(self.concentration > 0, -safe_lognorm,
                   log_nsphere_surface_area)
Example #27
 def _log_unnormalized_prob(self, x, log_rate):
     # The log-probability at negative points is always -inf.
     # Catch such x's and set the output value accordingly.
     safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x),
                         0.)
     y = safe_x * log_rate - tf.math.lgamma(1. + safe_x)
     return tf.where(tf.equal(x, safe_x), y,
                     dtype_util.as_numpy_dtype(y.dtype)(-np.inf))
Example #28
 def _cdf(self, x):
     df = tf.convert_to_tensor(self.df)
     # Take Abs(scale) to make subsequent where work correctly.
     y = (x - self.loc) / tf.abs(self.scale)
     x_t = df / (y**2. + df)
     neg_cdf = 0.5 * tf.math.betainc(
         0.5 * tf.broadcast_to(df, prefer_static.shape(x_t)), 0.5, x_t)
     return tf.where(y < 0., neg_cdf, 1. - neg_cdf)
Example #29
def _pick_scalar_condition(pred, cond_true, cond_false):
    """Convenience function which chooses the condition based on the predicate."""
    # Note: This function is only valid if all of pred, cond_true, and cond_false
    # are scalars. This means its semantics are arguably more like tf.cond than
    # tf.where even though we use tf.where to implement it.
    pred_ = tf.get_static_value(tf.convert_to_tensor(pred))
    if pred_ is None:
        return tf.where(pred, cond_true, cond_false)
    return cond_true if pred_ else cond_false
Example #30
 def _log_prob(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         probs = self._probs_parameter_no_checks()
         if not self.validate_args:
             # For consistency with cdf, we take the floor.
             x = tf.floor(x)
         safe_domain = tf.where(tf.equal(x, 0.), tf.zeros_like(probs),
                                probs)
         return x * tf.math.log1p(-safe_domain) + tf.math.log(probs)
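The expression `x * log1p(-probs) + log(probs)` is `log((1 - p)**x * p)`, the pmf of a geometric distribution counting failures before the first success; the `safe_domain` masking guards the `x == 0` case (where `log1p(-probs)` could be `-inf` at `probs == 1`) so that `0 * -inf` never appears. A quick NumPy check that the pmf normalizes:

```python
import numpy as np

p = 0.3
x = np.arange(0, 200)

log_prob = x * np.log1p(-p) + np.log(p)   # log[(1 - p)**x * p]

print(np.exp(log_prob).sum())             # ~1.0: sums to one over x = 0, 1, 2, ...
```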