def _inverse(self, y):
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_near(tf.minimum(lower_cand_diff,
                                                    upper_cand_diff),
                                         0,
                                         message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32),
     ], axis=-1)
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=y.shape)
Example #2
def log1psquare(x, name=None):
    """Numerically stable calculation of `log(1 + x**2)` for small or large `|x|`.

  For sufficiently large `x` we use the following observation:

  ```none
  log(1 + x**2) =   2 log(|x|) + log(1 + 1 / x**2)
                --> 2 log(|x|)  as x --> inf
  ```

  Numerically, `log(1 + 1 / x**2)` is `0` when `1 / x**2` is small relative to
  machine epsilon.

  Args:
    x: Float `Tensor` input.
    name: Python string indicating the name of the TensorFlow operation.
      Default value: `'log1psquare'`.

  Returns:
    log1psq: Float `Tensor` representing `log(1. + x**2.)`.
  """
    with tf.name_scope(name or 'log1psquare'):
        x = tf.convert_to_tensor(x, dtype_hint=tf.float32, name='x')
        dtype = dtype_util.as_numpy_dtype(x.dtype)

        eps = np.finfo(dtype).eps.astype(np.float64)
        is_large = tf.abs(x) > (eps**-0.5).astype(dtype)

        # Mask out small x's so the gradient correctly propagates.
        abs_large_x = tf.where(is_large, tf.abs(x), tf.ones([], x.dtype))
        return tf.where(is_large, 2. * tf.math.log(abs_large_x),
                        tf.math.log1p(tf.square(x)))
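A minimal NumPy sketch (illustration only, not part of the function above) of why the large-`|x|` branch is needed: squaring overflows float32 long before the logarithm itself would.

```python
import numpy as np

x = np.float32(1e20)
naive = np.log1p(np.square(x))   # np.square overflows to inf in float32
stable = 2. * np.log(np.abs(x))  # ~92.1, the correct value of log(1 + x**2)
print(naive, stable)
```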
 def _log_variance(self):
     # Following calculation is based on law of total variance:
     #
     # Var[Z] = E[Var[Z | V]] + Var[E[Z | V]]
     #
     # where,
     #
     # Z|v ~ interpolate_affine[v](dist)
     # V ~ mixture_dist
     #
     # thus,
     #
     # E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
     # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }
     distributions = self.poisson_and_mixture_distributions()
     dist, mixture_dist = distributions
     v = tf.stack(
         [
             # log(dist.variance()) = log(Var[d]) = log(rate[d])
             dist.log_rate,
             # log((Mean[d] - Mean)**2)
             2. * tf.math.log(
                 tf.abs(dist.mean() - self._mean(
                     distributions=distributions)[..., tf.newaxis])),
         ],
         axis=-1)
     return tf.reduce_logsumexp(mixture_dist.logits[..., tf.newaxis] + v,
                                axis=[-2, -1])
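A small NumPy sketch of the law of total variance used above, for a hypothetical two-component Poisson mixture (the rates and mixture probabilities are made up for illustration):

```python
import numpy as np

probs = np.array([0.3, 0.7])   # mixture probabilities
rates = np.array([2.0, 7.0])   # Poisson rates; mean == variance == rate
mean = np.sum(probs * rates)
e_var = np.sum(probs * rates)                # E[Var[Z | V]]
var_e = np.sum(probs * (rates - mean) ** 2)  # Var[E[Z | V]]
print(np.log(e_var + var_e))  # log Var[Z], computed here in linear space
```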
Example #4
 def _cdf(self, x):
     df = tf.convert_to_tensor(self.df)
     # Take Abs(scale) to make subsequent where work correctly.
     y = (x - self.loc) / tf.abs(self.scale)
     x_t = df / (y**2. + df)
     neg_cdf = 0.5 * tf.math.betainc(
         0.5 * tf.broadcast_to(df, prefer_static.shape(x_t)), 0.5, x_t)
     return tf.where(y < 0., neg_cdf, 1. - neg_cdf)
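A SciPy cross-check sketch (not from the library; the sample point and parameters are arbitrary) of the incomplete-beta form used above, against `scipy.stats.t.cdf`:

```python
import numpy as np
from scipy import special, stats

x, df, loc, scale = -1.3, 5.0, 0.0, 2.0
y = (x - loc) / abs(scale)
x_t = df / (y**2 + df)
neg_cdf = 0.5 * special.betainc(0.5 * df, 0.5, x_t)
cdf = neg_cdf if y < 0. else 1. - neg_cdf
print(cdf, stats.t.cdf(x, df, loc=loc, scale=scale))  # the two agree
```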
Example #5
 def _sample_n(self, n, seed=None):
     scale = tf.convert_to_tensor(self.scale)
     shape = tf.concat([[n], tf.shape(scale)], 0)
     sampled = tf.random.normal(shape=shape,
                                mean=0.,
                                stddev=1.,
                                dtype=self.dtype,
                                seed=seed)
     return tf.abs(sampled * scale)
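A NumPy sanity-check sketch (illustration only, not library code): `abs(scale * N(0, 1))` is half-normal, so its sample mean should be close to `scale * sqrt(2 / pi)`.

```python
import numpy as np

rng = np.random.default_rng(0)
scale = 2.0
samples = np.abs(rng.normal(0., 1., size=100_000) * scale)
print(samples.mean(), scale * np.sqrt(2. / np.pi))  # both ~1.596
```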
 def _forward_log_det_jacobian(self, x):
     # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
     # sections leading up to it, for context) in
     # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
     with tf.control_dependencies(self._assertions(x)):
         matrix_dim = tf.cast(
             tf.shape(x)[-1], dtype_util.base_dtype(x.dtype))
         return -(matrix_dim + 1) * tf.reduce_sum(
             tf.math.log(tf.abs(tf.linalg.diag_part(x))), axis=-1)
 def _forward_log_det_jacobian(self, x):
     if self.log_scale is not None:
         return self.log_scale
     elif self.scale is not None:
         return tf.math.log(tf.abs(self.scale))
     else:
         # is_constant_jacobian = True for this bijector, hence the
         # `log_det_jacobian` need only be specified for a single input, as this
         # will be tiled to match `event_ndims`.
         return tf.zeros([], dtype=x.dtype)
Example #8
def _ndtr(x):
    """Implements ndtr core logic."""
    half_sqrt_2 = tf.constant(0.5 * np.sqrt(2.),
                              dtype=x.dtype,
                              name="half_sqrt_2")
    w = x * half_sqrt_2
    z = tf.abs(w)
    y = tf.where(z < half_sqrt_2, 1. + tf.math.erf(w),
                 tf.where(w > 0., 2. - tf.math.erfc(z), tf.math.erfc(z)))
    return 0.5 * y
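A SciPy cross-check sketch (illustration only): the erf/erfc branching above reproduces the standard normal CDF `scipy.special.ndtr`.

```python
import numpy as np
from scipy import special

x = np.array([-8.0, -0.3, 0.0, 0.3, 8.0])
half_sqrt_2 = 0.5 * np.sqrt(2.)
w = x * half_sqrt_2
z = np.abs(w)
y = np.where(z < half_sqrt_2, 1. + special.erf(w),
             np.where(w > 0., 2. - special.erfc(z), special.erfc(z)))
print(0.5 * y)
print(special.ndtr(x))  # reference values
```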
Example #9
 def _entropy(self):
     df = tf.convert_to_tensor(self.df)
     scale = tf.convert_to_tensor(self.scale)
     v = tf.ones(self._batch_shape_tensor(df=df, scale=scale),
                 dtype=self.dtype)[..., tf.newaxis]
     u = v * df[..., tf.newaxis]
     beta_arg = tf.concat([u, v], -1) / 2.
     return (tf.math.log(tf.abs(scale)) + 0.5 * tf.math.log(df) +
             tf.math.lbeta(beta_arg) + 0.5 * (df + 1.) *
             (tf.math.digamma(0.5 * (df + 1.)) - tf.math.digamma(0.5 * df)))
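A SciPy cross-check sketch (illustrative; the parameters are arbitrary) of the closed-form Student's t entropy above against `scipy.stats.t(...).entropy()`:

```python
import numpy as np
from scipy import special, stats

df, scale = 5.0, 2.0
entropy = (np.log(abs(scale)) + 0.5 * np.log(df)
           + special.betaln(0.5 * df, 0.5)
           + 0.5 * (df + 1.) * (special.digamma(0.5 * (df + 1.))
                                - special.digamma(0.5 * df)))
print(entropy, stats.t(df, scale=scale).entropy())  # the two agree
```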
Example #10
 def _stddev(self):
   if distribution_util.is_diagonal_scale(self.scale):
     return tf.abs(self.scale.diag_part())
   elif (isinstance(self.scale, tf.linalg.LinearOperatorLowRankUpdate) and
         self.scale.is_self_adjoint):
     return tf.sqrt(
         tf.linalg.diag_part(self.scale.matmul(self.scale.to_dense())))
   else:
     return tf.sqrt(
         tf.linalg.diag_part(
             self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)))
Example #11
 def _log_prob(self, x):
     df = tf.convert_to_tensor(self.df)
     scale = tf.convert_to_tensor(self.scale)
     loc = tf.convert_to_tensor(self.loc)
     y = (x - loc) / scale  # Abs(scale) superfluous.
     log_unnormalized_prob = -0.5 * (df + 1.) * tf.math.log1p(y**2. / df)
     log_normalization = (tf.math.log(tf.abs(scale)) +
                          0.5 * tf.math.log(df) + 0.5 * np.log(np.pi) +
                          tf.math.lgamma(0.5 * df) -
                          tf.math.lgamma(0.5 * (df + 1.)))
     return log_unnormalized_prob - log_normalization
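A SciPy cross-check sketch (illustration only) of the unnormalized-probability / normalization split above, against `scipy.stats.t.logpdf`:

```python
import numpy as np
from scipy import special, stats

x, df, loc, scale = 1.7, 4.0, 0.5, 2.0
y = (x - loc) / scale
log_unnormalized = -0.5 * (df + 1.) * np.log1p(y**2 / df)
log_normalization = (np.log(abs(scale)) + 0.5 * np.log(df) + 0.5 * np.log(np.pi)
                     + special.gammaln(0.5 * df) - special.gammaln(0.5 * (df + 1.)))
print(log_unnormalized - log_normalization,
      stats.t.logpdf(x, df, loc=loc, scale=scale))  # the two agree
```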
Example #12
 def _forward_log_det_jacobian(self, x):
   # is_constant_jacobian = True for this bijector, hence the
   # `log_det_jacobian` need only be specified for a single input, as this will
   # be tiled to match `event_ndims`.
   if self._is_only_identity_multiplier:
     # We don't pad in this case and instead let the fldj be applied
     # via broadcast.
     log_abs_diag = tf.math.log(tf.abs(self._scale))
     event_size = tf.shape(x)[-1]
     event_size = tf.cast(event_size, dtype=log_abs_diag.dtype)
     return log_abs_diag * event_size
   return self.scale.log_abs_determinant()
Example #13
def _sqrtx2p1(x):
    """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`."""
    sqrt_eps = np.sqrt(np.finfo(dtype_util.as_numpy_dtype(x.dtype)).eps)
    return tf.where(
        tf.abs(x) * sqrt_eps <= 1.,
        tf.sqrt(x**2. + 1.),
        # For large x, calculating x**2 can overflow. This can be alleviated by
        # considering:
        # sqrt(1 + x**2)
        # = exp(0.5 log(1 + x**2))
        # = exp(0.5 log(x**2 * (1 + x**-2)))
        # = exp(log(x) + 0.5 * log(1 + x**-2))
        # = |x| * exp(0.5 log(1 + x**-2))
        # = |x| * sqrt(1 + x**-2)
        # We omit the last factor in this approximation.
        # When |x| > 1 / sqrt(machine epsilon), the second factor will be 1,
        # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term,
        # and higher order gradients, since the first order derivative of
        # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x),
        # and all nth-order derivatives will be O(x**-(n + 2)). This makes any
        # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish.
        tf.abs(x))
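A minimal NumPy sketch (illustration only) of the large-`|x|` rationale: squaring a large float32 overflows, while `|x|` alone is already correct to machine precision.

```python
import numpy as np

x = np.float32(1e25)
naive = np.sqrt(np.square(x) + 1.)  # np.square overflows to inf in float32
stable = np.abs(x)                  # sqrt(1 + x**-2) == 1 to machine precision
print(naive, stable)
```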
def _bessel_ive(v, z, cache=None):
  """Computes I_v(z)*exp(-abs(z)) using a recurrence relation, where z > 0."""
  # TODO(b/67497980): Switch to a more numerically faithful implementation.
  z = tf.convert_to_tensor(z)

  wrap = lambda result: tf.debugging.check_numerics(
      result, 'besseli{}'.format(v))

  if float(v) >= 2:
    raise ValueError(
        'Evaluating bessel_i by recurrence becomes imprecise for large v')

  cache = cache or {}
  safe_z = tf.where(z > 0, z, tf.ones_like(z))
  if v in cache:
    return wrap(cache[v])
  if v == 0:
    cache[v] = tf.math.bessel_i0e(z)
  elif v == 1:
    cache[v] = tf.math.bessel_i1e(z)
  elif v == 0.5:
    # sinh(x)*exp(-abs(x)), sinh(x) = (e^x - e^{-x}) / 2
    sinhe = lambda x: (tf.exp(x - tf.abs(x)) - tf.exp(-x - tf.abs(x))) / 2
    cache[v] = (
        np.sqrt(2 / np.pi) * sinhe(z) *
        tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z)))
  elif v == -0.5:
    # cosh(x)*exp(-abs(x)), cosh(x) = (e^x + e^{-x}) / 2
    coshe = lambda x: (tf.exp(x - tf.abs(x)) + tf.exp(-x - tf.abs(x))) / 2
    cache[v] = (
        np.sqrt(2 / np.pi) * coshe(z) *
        tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z)))
  if v <= 1:
    return wrap(cache[v])
  # Recurrence relation:
  cache[v] = (_bessel_ive(v - 2, z, cache) -
              (2 * (v - 1)) * _bessel_ive(v - 1, z, cache) / z)
  return wrap(cache[v])
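A NumPy/SciPy sketch (not library code) of the same recurrence for `v = 1.5`, using the closed forms for `v = ±0.5` above and comparing against `scipy.special.ive`:

```python
import numpy as np
from scipy import special

z = np.array([0.5, 2.0, 10.0])
scale = np.sqrt(2. / np.pi) * np.exp(-np.abs(z)) / np.sqrt(z)
ive_m05 = scale * np.cosh(z)               # I_{-1/2}(z) * exp(-|z|)
ive_p05 = scale * np.sinh(z)               # I_{+1/2}(z) * exp(-|z|)
ive_15 = ive_m05 - 2. * 0.5 * ive_p05 / z  # recurrence with v = 1.5
print(ive_15)
print(special.ive(1.5, z))  # reference values
```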
    def _stddev(self):
        if distribution_util.is_diagonal_scale(self.scale):
            mvn_std = tf.abs(self.scale.diag_part())
        elif (isinstance(self.scale, tf.linalg.LinearOperatorLowRankUpdate)
              and self.scale.is_self_adjoint):
            mvn_std = tf.sqrt(
                tf.linalg.diag_part(self.scale.matmul(self.scale.to_dense())))
        else:
            mvn_std = tf.sqrt(
                tf.linalg.diag_part(
                    self.scale.matmul(self.scale.to_dense(),
                                      adjoint_arg=True)))

        mvn_std = tf.broadcast_to(mvn_std, self._sample_shape())
        return self._std_var_helper(mvn_std, 'standard deviation', 1, tf.sqrt)
Example #16
def log_cdf_laplace(x, name="log_cdf_laplace"):
    """Log Laplace distribution function.

  This function calculates `Log[L(x)]`, where `L(x)` is the cumulative
  distribution function of the Laplace distribution, i.e.

  ```L(x) := 0.5 * int_{-infty}^x e^{-|t|} dt```

  For numerical accuracy, `L(x)` is computed in different ways depending on `x`,

  ```
  x <= 0:
    Log[L(x)] = Log[0.5] + x, which is exact

  0 < x:
    Log[L(x)] = Log[1 - 0.5 * e^{-x}], which is exact
  ```

  Args:
    x: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="log_cdf_laplace").

  Returns:
    `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
  """

    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name="x")

        # For x < 0, L(x) = 0.5 * exp{x} exactly, so Log[L(x)] = log(0.5) + x.
        lower_solution = -np.log(2.) + x

        # safe_exp_neg_x = exp{-x} for x > 0, but is
        # bounded above by 1, which avoids
        #   log[1 - 1] = -inf for x = log(1/2), AND
        #   exp{-x} --> inf, for x << -1
        safe_exp_neg_x = tf.exp(-tf.abs(x))

        # log1p(z) = log(1 + z) approx z for |z| << 1. This approximation is
        # used internally by log1p, rather than being done explicitly here.
        upper_solution = tf.math.log1p(-0.5 * safe_exp_neg_x)

        return tf.where(x < 0., lower_solution, upper_solution)
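A SciPy cross-check sketch (illustration only): the two branches above agree with `scipy.stats.laplace.logcdf` on both sides of zero.

```python
import numpy as np
from scipy import stats

x = np.array([-30., -1., 0., 1., 30.])
lower = -np.log(2.) + x
upper = np.log1p(-0.5 * np.exp(-np.abs(x)))
print(np.where(x < 0., lower, upper))
print(stats.laplace.logcdf(x))  # reference values
```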
 def _log_normalization(self):
   """Computes the log-normalizer of the distribution."""
   event_dim = tf.compat.dimension_value(self.event_shape[0])
   if event_dim is None:
     raise ValueError('vMF _log_normalizer currently only supports '
                      'statically known event shape')
   safe_conc = tf.where(self.concentration > 0, self.concentration,
                        tf.ones_like(self.concentration))
   safe_lognorm = ((event_dim / 2 - 1) * tf.math.log(safe_conc) -
                   (event_dim / 2) * np.log(2 * np.pi) -
                   tf.math.log(_bessel_ive(event_dim / 2 - 1, safe_conc)) -
                   tf.abs(safe_conc))
   log_nsphere_surface_area = (
       np.log(2.) + (event_dim / 2) * np.log(np.pi) -
       tf.math.lgamma(tf.cast(event_dim / 2, self.dtype)))
   return tf.where(self.concentration > 0, -safe_lognorm,
                   log_nsphere_surface_area)
Example #18
 def _prob(self, x):
     if self.validate_args:
         is_vector_check = assert_util.assert_rank_at_least(x, 1)
         right_vec_space_check = assert_util.assert_equal(
             self.event_shape_tensor(),
             tf.gather(tf.shape(x), tf.rank(x) - 1),
             message="Argument 'x' not defined in the same space R^k as "
                     "this distribution")
         with tf.control_dependencies([is_vector_check]):
             with tf.control_dependencies([right_vec_space_check]):
                 x = tf.identity(x)
     loc = tf.convert_to_tensor(self.loc)
     return tf.cast(tf.reduce_all(tf.abs(x - loc) <= self._slack(loc),
                                  axis=-1),
                    dtype=self.dtype)
Example #19
 def _sample_n(self, n, seed=None):
   loc = tf.convert_to_tensor(self.loc)
   scale = tf.convert_to_tensor(self.scale)
   shape = tf.concat([[n], self._batch_shape_tensor(loc=loc, scale=scale)], 0)
   # Uniform variates must be sampled from the open-interval `(-1, 1)` rather
   # than `[-1, 1)`. In the case of `(0, 1)` we'd use
   # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` because it is the
   # smallest, positive, 'normal' number. However, the concept of subnormality
   # exists only at zero; here we need the smallest usable number larger than
   # -1, i.e., `-1 + eps/2`.
   dt = dtype_util.as_numpy_dtype(self.dtype)
   uniform_samples = tf.random.uniform(
       shape=shape,
       minval=np.nextafter(dt(-1.), dt(1.)),
       maxval=1.,
       dtype=self.dtype,
       seed=seed)
   return (loc - scale * tf.sign(uniform_samples) *
           tf.math.log1p(-tf.abs(uniform_samples)))
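A NumPy sketch (illustration only) of the same inverse-CDF trick: with `U ~ Uniform(-1, 1)`, `loc - scale * sign(U) * log1p(-|U|)` is Laplace-distributed with mean `loc` and variance `2 * scale**2`.

```python
import numpy as np

rng = np.random.default_rng(0)
loc, scale = 1.0, 2.0
u = rng.uniform(np.nextafter(-1., 1.), 1., size=200_000)
samples = loc - scale * np.sign(u) * np.log1p(-np.abs(u))
print(samples.mean(), samples.var())  # ~1.0 and ~8.0
```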
Example #20
def _kl_laplace_laplace(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Laplace.

  Args:
    a: instance of a Laplace distribution object.
    b: instance of a Laplace distribution object.
    name: Python `str` name to use for created operations.
      Default value: `None` (i.e., `'kl_laplace_laplace'`).

  Returns:
    kl_div: Batchwise KL(a || b)
  """
  with tf.name_scope(name or 'kl_laplace_laplace'):
    # Consistent with
    # http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 38
    distance = tf.abs(a.loc - b.loc)
    a_scale = tf.convert_to_tensor(a.scale)
    b_scale = tf.convert_to_tensor(b.scale)
    delta_log_scale = tf.math.log(a_scale) - tf.math.log(b_scale)
    return (-delta_log_scale +
            distance / b_scale - 1. +
            tf.exp(-distance / a_scale + delta_log_scale))
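A Monte-Carlo sanity-check sketch (illustration only; the parameters are arbitrary) of the closed-form KL divergence above:

```python
import numpy as np
from scipy import stats

a_loc, a_scale, b_loc, b_scale = 0.0, 1.0, 0.5, 2.0
distance = abs(a_loc - b_loc)
delta_log_scale = np.log(a_scale) - np.log(b_scale)
kl_closed = (-delta_log_scale + distance / b_scale - 1.
             + np.exp(-distance / a_scale + delta_log_scale))
x = stats.laplace(a_loc, a_scale).rvs(size=200_000, random_state=0)
kl_mc = np.mean(stats.laplace(a_loc, a_scale).logpdf(x)
                - stats.laplace(b_loc, b_scale).logpdf(x))
print(kl_closed, kl_mc)  # the Monte-Carlo estimate is close to the closed form
```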
 def _is_equal_or_close(self, a, b):
   if dtype_util.is_integer(self.outcomes.dtype):
     return tf.equal(a, b)
   return tf.abs(a - b) < self._atol + self._rtol * tf.abs(b)
Example #22
 def _log_prob(self, x):
   loc = tf.convert_to_tensor(self.loc)
   scale = tf.convert_to_tensor(self.scale)
   z = (x - loc) / scale
   return -tf.abs(z) - np.log(2.) - tf.math.log(scale)
Example #23
 def _prob(self, x):
     loc = tf.convert_to_tensor(self.loc)
     # Enforces dtype of probability to be float, when self.dtype is not.
     prob_dtype = self.dtype if self.dtype.is_floating else tf.float32
     return tf.cast(tf.abs(x - loc) <= self._slack(loc), dtype=prob_dtype)
Example #24
 def _slack(self, loc):
     # Avoid using the large broadcast with self.loc if possible.
     if self.parameters["rtol"] is None:
         return self.atol
     else:
         return self.atol + self.rtol * tf.abs(loc)
Example #25
 def _cdf(self, x):
   z = self._z(x)
   return 0.5 - 0.5 * tf.sign(z) * tf.math.expm1(-tf.abs(z))
Example #26
def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
    """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is
  more efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(w * exp(input))). It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
    with tf.name_scope(name or 'reduce_weighted_logsumexp'):
        logx = tf.convert_to_tensor(logx, name='logx')
        if w is None:
            lswe = tf.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
            if return_sign:
                sgn = tf.ones_like(lswe)
                return lswe, sgn
            return lswe
        w = tf.convert_to_tensor(w, dtype=logx.dtype, name='w')
        log_absw_x = logx + tf.math.log(tf.abs(w))
        max_log_absw_x = tf.reduce_max(log_absw_x, axis=axis, keepdims=True)
        # If the largest element is `-inf` or `inf` then we don't bother subtracting
        # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
        # this is ok follows from the fact that we're actually free to subtract any
        # value we like, so long as we add it back after taking the `log(sum(...))`.
        max_log_absw_x = tf.where(tf.math.is_inf(max_log_absw_x),
                                  tf.zeros([], max_log_absw_x.dtype),
                                  max_log_absw_x)
        wx_over_max_absw_x = (tf.sign(w) * tf.exp(log_absw_x - max_log_absw_x))
        sum_wx_over_max_absw_x = tf.reduce_sum(wx_over_max_absw_x,
                                               axis=axis,
                                               keepdims=keep_dims)
        if not keep_dims:
            max_log_absw_x = tf.squeeze(max_log_absw_x, axis)
        sgn = tf.sign(sum_wx_over_max_absw_x)
        lswe = max_log_absw_x + tf.math.log(sgn * sum_wx_over_max_absw_x)
        if return_sign:
            return lswe, sgn
        return lswe
Example #27
def soft_threshold(x, threshold, name=None):
    """Soft Thresholding operator.

  This operator is defined by the equations

  ```none
                                { x[i] - gamma,  x[i] >  gamma
  SoftThreshold(x, gamma)[i] =  { 0,             -gamma <= x[i] <= gamma
                                { x[i] + gamma,  x[i] < -gamma
  ```

  In the context of proximal gradient methods, we have

  ```none
  SoftThreshold(x, gamma) = prox_{gamma L1}(x)
  ```

  where `prox` is the proximity operator.  Thus the soft thresholding operator
  is used in proximal gradient descent for optimizing a smooth function with
  (non-smooth) L1 regularization, as outlined below.

  The proximity operator is defined as:

  ```none
  prox_r(x) = argmin{ r(z) + 0.5 ||x - z||_2**2 : z },
  ```

  where `r` is a (weakly) convex function, not necessarily differentiable.
  Because the L2 norm is strictly convex, the above argmin is unique.

  One important application of the proximity operator is as follows.  Let `L` be
  a convex and differentiable function with Lipschitz-continuous gradient.  Let
  `R` be a convex lower semicontinuous function which is possibly
  nondifferentiable.  Let `gamma` be an arbitrary positive real.  Then

  ```none
  x_star = argmin{ L(x) + R(x) : x }
  ```

  if and only if the fixed-point equation is satisfied:

  ```none
  x_star = prox_{gamma R}(x_star - gamma grad L(x_star))
  ```

  Proximal gradient descent thus typically consists of choosing an initial value
  `x^{(0)}` and repeatedly applying the update

  ```none
  x^{(k+1)} = prox_{gamma^{(k)} R}(x^{(k)} - gamma^{(k)} grad L(x^{(k)}))
  ```

  where `gamma` is allowed to vary from iteration to iteration.  Specializing to
  the case where `R(x) = ||x||_1`, we minimize `L(x) + ||x||_1` by repeatedly
  applying the update

  ```
  x^{(k+1)} = SoftThreshold(x - gamma grad L(x^{(k)}), gamma)
  ```

  (This idea can also be extended to second-order approximations, although the
  multivariate case does not have a known closed form like above.)

  Args:
    x: `float` `Tensor` representing the input to the SoftThreshold function.
    threshold: nonnegative scalar, `float` `Tensor` representing the radius of
      the interval on which each coordinate of SoftThreshold takes the value
      zero.  Denoted `gamma` above.
    name: Python string indicating the name of the TensorFlow operation.
      Default value: `'soft_threshold'`.

  Returns:
    softthreshold: `float` `Tensor` with the same shape and dtype as `x`,
      representing the value of the SoftThreshold function.

  #### References

  [1]: Yu, Yao-Liang. The Proximity Operator.
       https://www.cs.cmu.edu/~suvrit/teach/yaoliang_proximity.pdf

  [2]: Wikipedia Contributors. Proximal gradient methods for learning.
       _Wikipedia, The Free Encyclopedia_, 2018.
       https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning

  """
    # https://math.stackexchange.com/questions/471339/derivation-of-soft-thresholding-operator
    with tf.name_scope(name or 'soft_threshold'):
        x = tf.convert_to_tensor(x, name='x')
        threshold = tf.convert_to_tensor(threshold,
                                         dtype=x.dtype,
                                         name='threshold')
        return tf.sign(x) * tf.maximum(tf.abs(x) - threshold, 0.)
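A small usage sketch (illustration only): coordinates within `threshold` of zero collapse to zero, and the rest are shrunk toward zero by `threshold`.

```python
import numpy as np

x = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
threshold = 1.0
print(np.sign(x) * np.maximum(np.abs(x) - threshold, 0.))
# ==> [-2. -0.  0.  0.  2.]   (the -0. is a benign negative zero)
```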
  def _sample_n(self, n, seed=None):
    seed = SeedStream(seed, salt='vom_mises_fisher')
    # The sampling strategy relies on the fact that vMF variates are symmetric
    # about the mean direction. Accordingly, if we have a sampling strategy for
    # the away-from-mean angle, then we can uniformly sample the remaining
    # dimensions on the S^{dim-2} sphere, and rotate these samples from a
    # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
    #
    # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
    # von Mises distributed `x` value in [-1, 1], then uniformly select what
    # amounts to an "up" or "down" additional degree of freedom after unit
    # normalizing, followed by a final rotation to the desired mean direction
    # from a basis of (1, 0).
    #
    # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
    # unit sphere over which the distribution is uniform, in particular the
    # circle where x = \hat{x} intersects the unit sphere. We pick a point on
    # that circle, then rotate to the desired mean direction from a basis of
    # (1, 0, 0).
    event_dim = (
        tf.compat.dimension_value(self.event_shape[0]) or
        self._event_shape_tensor()[0])

    sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
    dim = tf.cast(event_dim - 1, self.dtype)
    if event_dim == 3:
      samples_dim0 = self._sample_3d(n, seed=seed)
    else:
      # Wood'94 provides a rejection algorithm to sample the x coordinate.
      # Wood'94 definition of b:
      # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
      # https://stats.stackexchange.com/questions/156729 suggests:
      b = dim / (2 * self.concentration +
                 tf.sqrt(4 * self.concentration**2 + dim**2))
      # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
      #     https://github.com/nicola-decao/s-vae-tf/
      x = (1 - b) / (1 + b)
      c = self.concentration * x + dim * tf.math.log1p(-x**2)
      beta = beta_lib.Beta(dim / 2, dim / 2)

      def cond_fn(w, should_continue):
        del w
        return tf.reduce_any(should_continue)

      def body_fn(w, should_continue):
        z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
        # set_shape needed here because of b/139013403
        z.set_shape(w.shape)
        w = tf.where(should_continue, (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
        w = tf.debugging.check_numerics(w, 'w')
        unif = tf.random.uniform(
            sample_batch_shape, seed=seed(), dtype=self.dtype)
        # set_shape needed here because of b/139013403
        unif.set_shape(w.shape)
        should_continue = tf.logical_and(
            should_continue,
            self.concentration * w + dim * tf.math.log1p(-x * w) - c <
            tf.math.log(unif))
        return w, should_continue

      w = tf.zeros(sample_batch_shape, dtype=self.dtype)
      should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
      samples_dim0 = tf.while_loop(
          cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
      samples_dim0 = samples_dim0[..., tf.newaxis]
    if not self._allow_nan_stats:
      # Verify samples are w/in -1, 1, with useful error output tensors (top
      # value rather than all values).
      with tf.control_dependencies([
          assert_util.assert_less_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(1.01),
              data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
          assert_util.assert_greater_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(-1.01),
              data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
      ]):
        samples_dim0 = tf.identity(samples_dim0)
    samples_otherdims_shape = tf.concat([sample_batch_shape, [event_dim - 1]],
                                        axis=0)
    unit_otherdims = tf.math.l2_normalize(
        tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
        axis=-1)
    samples = tf.concat([
        samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
        tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
    ], axis=-1)
    samples = tf.math.l2_normalize(samples, axis=-1)
    if not self._allow_nan_stats:
      samples = tf.debugging.check_numerics(samples, 'samples')

    # Runtime assert that samples are unit length.
    if not self._allow_nan_stats:
      worst, idx = tf.math.top_k(
          tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
      with tf.control_dependencies([
          assert_util.assert_near(
              dtype_util.as_numpy_dtype(self.dtype)(0),
              worst,
              data=[
                  worst, idx,
                  tf.gather(tf.reshape(samples, [-1, event_dim]), idx)
              ],
              atol=1e-4,
              summarize=100)
      ]):
        samples = tf.identity(samples)
    # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
    # Now, we move the mode to `self.mean_direction` using a rotation matrix.
    if not self._allow_nan_stats:
      # Assert that the basis vector rotates to the mean direction, as expected.
      basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                      self.dtype)
      with tf.control_dependencies([
          assert_util.assert_less(
              tf.linalg.norm(
                  self._rotate(basis) - self.mean_direction, axis=-1),
              dtype_util.as_numpy_dtype(self.dtype)(1e-5))
      ]):
        return self._rotate(samples)
    return self._rotate(samples)