Example #1
def sparse_or_dense_matmul(sparse_or_dense_a,
                           dense_b,
                           validate_args=False,
                           name=None,
                           **kwargs):
    """Returns (batched) matmul of a SparseTensor (or Tensor) with a Tensor.

  Args:
    sparse_or_dense_a: `SparseTensor` or `Tensor` representing a (batch of)
      matrices.
    dense_b: `Tensor` representing a (batch of) matrices, with the same batch
      shape as `sparse_or_dense_a`. The shape must be compatible with the shape
      of `sparse_or_dense_a` and kwargs.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'sparse_or_dense_matmul'.
    **kwargs: Keyword arguments to `tf.sparse_tensor_dense_matmul` or
      `tf.matmul`.

  Returns:
    product: A dense (batch of) matrix-shaped Tensor of the same batch shape and
    dtype as `sparse_or_dense_a` and `dense_b`. If `sparse_or_dense_a` or
    `dense_b` is adjointed through `kwargs` then the shape is adjusted
    accordingly.
  """
    with tf.name_scope(name or 'sparse_or_dense_matmul'):
        dense_b = tf.convert_to_tensor(dense_b,
                                       dtype_hint=tf.float32,
                                       name='dense_b')

        if validate_args:
            assert_a_rank_at_least_2 = assert_util.assert_rank_at_least(
                sparse_or_dense_a,
                rank=2,
                message=
                'Input `sparse_or_dense_a` must have at least 2 dimensions.')
            assert_b_rank_at_least_2 = assert_util.assert_rank_at_least(
                dense_b,
                rank=2,
                message='Input `dense_b` must have at least 2 dimensions.')
            with tf.control_dependencies(
                [assert_a_rank_at_least_2, assert_b_rank_at_least_2]):
                sparse_or_dense_a = tf.identity(sparse_or_dense_a)
                dense_b = tf.identity(dense_b)

        if isinstance(sparse_or_dense_a,
                      (tf.SparseTensor, tf1.SparseTensorValue)):
            return _sparse_tensor_dense_matmul(sparse_or_dense_a, dense_b,
                                               **kwargs)
        else:
            return tf.matmul(sparse_or_dense_a, dense_b, **kwargs)
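A minimal usage sketch (assuming the module's private `_sparse_tensor_dense_matmul` helper and `assert_util` are importable alongside the function above; the values are illustrative):

import tensorflow as tf

a = tf.sparse.from_dense([[1., 0.], [0., 2.]])  # 2x2 SparseTensor
b = tf.constant([[3., 4.], [5., 6.]])           # 2x2 dense Tensor
product = sparse_or_dense_matmul(a, b, validate_args=True)
# Same result as tf.matmul(tf.sparse.to_dense(a), b): [[3., 4.], [10., 12.]]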
Example #2
 def _inverse(self, y):
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_near(tf.minimum(lower_cand_diff,
                                                    upper_cand_diff),
                                         0,
                                         message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
     ],
                                   axis=-1)
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=y.shape)
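A self-contained sketch of the nearest-neighbor lookup used above, assuming a strictly increasing `map_values` (so the two candidates always bracket the closest index):

import tensorflow as tf

map_values = tf.constant([0., 1., 4., 9.])
y = tf.constant([3.9, 0.2])
upper = tf.minimum(tf.size(map_values) - 1,
                   tf.searchsorted(map_values, values=y, side='right'))
lower = tf.maximum(0, upper - 1)
# upper -> [2, 1], lower -> [1, 0]; comparing |y - map_values[index]| then
# selects the closer candidate, here indices [2, 0].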
Example #3
def matrix_rank(a, tol=None, validate_args=False, name=None):
    """Compute the matrix rank; the number of non-zero SVD singular values.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
    with tf.name_scope(name or 'matrix_rank'):
        a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
        assertions = _maybe_validate_matrix(a, validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                a = tf.identity(a)
        s = tf.linalg.svd(a, compute_uv=False)
        if tol is None:
            if tensorshape_util.is_fully_defined(a.shape[-2:]):
                m = np.max(a.shape[-2:].as_list())
            else:
                m = tf.reduce_max(tf.shape(a)[-2:])
            eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
            tol = (eps * tf.cast(m, a.dtype) *
                   tf.reduce_max(s, axis=-1, keepdims=True))
        return tf.reduce_sum(tf.cast(s > tol, tf.int32), axis=-1)
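Usage sketch (assuming the TFP-internal `_maybe_validate_matrix`, `tensorshape_util`, and `dtype_util` helpers are importable):

import tensorflow as tf

full = tf.constant([[1., 0.], [0., 1.]])
deficient = tf.constant([[1., 2.], [2., 4.]])  # second row = 2 * first row
matrix_rank(full)       # -> 2
matrix_rank(deficient)  # -> 1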
Example #4
def assert_finite(x, data=None, summarize=None, message=None, name=None):
    """Assert all elements of `x` are finite.

  Args:
    x:  Numeric `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_finite".

  Returns:
    Op raising `InvalidArgumentError` unless all elements of `x` are finite.
    If static checks determine all elements of `x` are finite, `x` is
    returned unchanged.

  Raises:
    ValueError:  If static checks determine `x` has non-finite elements.
  """
    with tf.name_scope(name or 'assert_finite'):
        x_ = tf.get_static_value(x)
        if x_ is not None:
            if not np.all(np.isfinite(x_)):
                raise ValueError(message)
            return x
        assertion = tf1.assert_equal(tf.math.is_finite(x),
                                     tf.ones_like(x, tf.bool),
                                     data=data,
                                     summarize=summarize,
                                     message=message)
        with tf.control_dependencies([assertion]):
            return tf.identity(x)
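Usage sketch: when the value is statically known the check runs eagerly and raises `ValueError`; otherwise the returned tensor is gated by a graph assertion.

import numpy as np
import tensorflow as tf

x = tf.constant([1., 2., 3.])
x = assert_finite(x, message='x must be finite')  # passes; returns x

bad = tf.constant([1., np.inf])
# assert_finite(bad, message='x must be finite')  # raises ValueError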
Example #5
    def _mode(self):
        scale_x_zeros = self.bijector.scale.matvec(
            tf.zeros(self._mode_mean_shape(), self.dtype))

        if self.loc is None:
            return scale_x_zeros

        return tf.identity(self.loc) + scale_x_zeros
 def _inverse(self, y):
     x = tf.identity(y)
     if self.shift is not None:
         x = x - self.shift
     if self.scale is not None:
         x = x / self.scale
     if self.log_scale is not None:
         x = x * tf.exp(-self.log_scale)
     return x
 def _forward(self, x):
     y = tf.identity(x)
     if self.scale is not None:
         y = y * self.scale
     if self.log_scale is not None:
         y = y * tf.exp(self.log_scale)
     if self.shift is not None:
         y = y + self.shift
     return y
Example #8
 def _inverse_event_shape_tensor(self, output_shape):
     if self.validate_args:
         # A shape dimension cannot be negative, so we need only check that
         # the last dimension is greater than 1.
         is_greater_one = assert_util.assert_greater(
             output_shape[-1],
             1,
             message="Need last dimension greater than 1.")
         with tf.control_dependencies([is_greater_one]):
             output_shape = tf.identity(output_shape)
     return tf.concat([output_shape[:-1], [output_shape[-1] - 1]], axis=0)
    def _mean(self):
        shape = tensorshape_util.concatenate(self.batch_shape,
                                             self.event_shape)
        has_static_shape = tensorshape_util.is_fully_defined(shape)
        if not has_static_shape:
            shape = tf.concat([
                self.batch_shape_tensor(),
                self.event_shape_tensor(),
            ], 0)

        if self.loc is None:
            return tf.zeros(shape, self.dtype)

        if has_static_shape and shape == self.loc.shape:
            return tf.identity(self.loc)

        # Add a dummy tensor of zeros to broadcast.  This is only necessary if
        # shape != self.loc.shape, but we cannot statically determine whether
        # that is the case.
        return tf.identity(self.loc) + tf.zeros(shape, self.dtype)
Example #10
    def _mean(self):
        # Let
        #   W = (w1,...,wk), with wj ~ iid Exponential(0, 1).
        # Then this distribution is
        #   X = loc + LW,
        # and then E[X] = loc + L1, where 1 is the vector of ones.
        scale_x_ones = self.bijector.scale.matvec(
            tf.ones(self._mode_mean_shape(), self.dtype))

        if self.loc is None:
            return scale_x_ones

        return tf.identity(self.loc) + scale_x_ones
 def _z(self, x, scale, concentration):
     loc = tf.convert_to_tensor(self.loc)
     if self.validate_args:
         valid = (x >= loc) & ((concentration >= 0) |
                               (x <= loc - scale / concentration))
         with tf.control_dependencies([
                 assert_util.assert_equal(
                     valid,
                     True,
                     message='`x` outside distribution\'s support.')
         ]):
             x = tf.identity(x)
     return (x - loc) / scale
Example #12
 def _forward(self, x):
     map_values = tf.convert_to_tensor(self.map_values)
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_equal(
                     (0 <= x) & (x < tf.size(map_values)),
                     True,
                     message='indices out of bound')
         ]):
             x = tf.identity(x)
     # If we want batch dims in self.map_values, we can (after broadcasting),
     # use:
     # tf.gather(self.map_values, x, batch_dims=-1, axis=-1)
     return tf.gather(map_values, indices=x)
 def _maybe_assert_valid_sample(self, samples):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return samples
     with tf.control_dependencies([
             assert_util.assert_near(1.,
                                     tf.linalg.norm(samples, axis=-1),
                                     message='samples must be unit length'),
             assert_util.assert_equal(
                 tf.shape(samples)[-1:],
                 self.event_shape_tensor(),
                 message=
                 ('samples must have innermost dimension matching that of '
                  '`self.mean_direction`')),
     ]):
         return tf.identity(samples)
Example #14
 def _prob(self, x):
     if self.validate_args:
         is_vector_check = assert_util.assert_rank_at_least(x, 1)
         right_vec_space_check = assert_util.assert_equal(
             self.event_shape_tensor(),
             tf.gather(tf.shape(x),
                       tf.rank(x) - 1),
             message=
             "Argument 'x' not defined in the same space R^k as this distribution"
         )
         with tf.control_dependencies([is_vector_check]):
             with tf.control_dependencies([right_vec_space_check]):
                 x = tf.identity(x)
     loc = tf.convert_to_tensor(self.loc)
     return tf.cast(tf.reduce_all(tf.abs(x - loc) <= self._slack(loc),
                                  axis=-1),
                    dtype=self.dtype)
    def _mean(self):
        concentration = tf.convert_to_tensor(self.concentration)
        mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
        mixing_rate = tf.convert_to_tensor(self.mixing_rate)

        mean = concentration * mixing_rate / (mixing_concentration - 1.)
        if self.allow_nan_stats:
            return tf.where(mixing_concentration > 1., mean,
                            dtype_util.as_numpy_dtype(self.dtype)(np.nan))
        else:
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.ones([], self.dtype),
                        mixing_concentration,
                        message=
                        'mean undefined when `mixing_concentration` <= 1'),
            ]):
                return tf.identity(mean)
 def _mode(self):
   concentration = tf.convert_to_tensor(self.concentration)
   k = tf.cast(tf.shape(concentration)[-1], self.dtype)
   total_concentration = tf.reduce_sum(concentration, axis=-1)
   mode = (concentration - 1.) / (total_concentration[..., tf.newaxis] - k)
   if self.allow_nan_stats:
     return tf.where(
         tf.reduce_all(concentration > 1., axis=-1, keepdims=True),
         mode,
         dtype_util.as_numpy_dtype(self.dtype)(np.nan))
   assertions = [
       assert_util.assert_less(
           tf.ones([], self.dtype),
           concentration,
           message='Mode undefined when any concentration <= 1')
   ]
   with tf.control_dependencies(assertions):
     return tf.identity(mode)
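The same mode formula via the public distribution, as a quick check (the mode is defined only when every concentration exceeds 1):

import tensorflow_probability as tfp
tfd = tfp.distributions

tfd.Dirichlet(concentration=[2., 3., 4.]).mode()
# (concentration - 1) / (total_concentration - k) = [1., 2., 3.] / 6.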
# `array_ops` below refers to tensorflow.python.ops.array_ops, which
# provides `prevent_gradient`.
@tf.custom_gradient
def _prevent_2nd_derivative(x):
    """Disables computation of the second derivatives for a tensor.

  NB: you need to apply a non-identity function to the output tensor for the
  exception to be raised.

  Args:
    x: A tensor.

  Returns:
    A tensor with the same value and the same derivative as x, but that raises
    LookupError when trying to compute the second derivatives.
  """
    def grad(dy):
        return array_ops.prevent_gradient(
            dy, message="Second derivative is not implemented.")

    return tf.identity(x), grad
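Usage sketch: the first derivative passes through unchanged, but differentiating it again reaches `prevent_gradient` and raises `LookupError` (note the non-identity `tf.square`, per the NB in the docstring; this assumes the `@tf.custom_gradient` decorator above):

import tensorflow as tf

x = tf.constant(2.)
with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
        inner.watch(x)
        y = tf.square(_prevent_2nd_derivative(x))
    dy = inner.gradient(y, x)  # fine: dy == 2 * x
d2y = outer.gradient(dy, x)    # raises LookupError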
def maybe_check_wont_broadcast(flat_xs, validate_args):
    """Verifies that `parts` don't broadcast."""
    flat_xs = tuple(flat_xs)  # So we can receive generators.
    if not validate_args:
        # Note: we don't try static validation because it is theoretically
        # possible that a user wants to take advantage of broadcasting.
        # Only when `validate_args` is `True` do we enforce the validation.
        return flat_xs
    msg = 'Broadcasting probably indicates an error in model specification.'
    s = tuple(prefer_static.shape(x) for x in flat_xs)
    if all(prefer_static.is_numpy(s_) for s_ in s):
        if not all(np.all(a == b) for a, b in zip(s[1:], s[:-1])):
            raise ValueError(msg)
        return flat_xs
    assertions = [
        assert_util.assert_equal(a, b, message=msg)
        for a, b in zip(s[1:], s[:-1])
    ]
    with tf.control_dependencies(assertions):
        return tuple(tf.identity(x) for x in flat_xs)
    def _variance(self):
        concentration = tf.convert_to_tensor(self.concentration)
        mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
        mixing_rate = tf.convert_to_tensor(self.mixing_rate)

        variance = (tf.square(concentration * mixing_rate /
                              (mixing_concentration - 1.)) /
                    (mixing_concentration - 2.))
        if self.allow_nan_stats:
            return tf.where(mixing_concentration > 2., variance,
                            dtype_util.as_numpy_dtype(self.dtype)(np.nan))
        else:
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.ones([], self.dtype) * 2.,
                        mixing_concentration,
                        message=
                        'variance undefined when `mixing_concentration` <= 2')
            ]):
                return tf.identity(variance)
def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Helper to validate block sizes."""
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes
  elif validate_args:
    message = ('`block_sizes` must be `None`, or a vector of the same length '
               'as `bijectors`.')
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.size(block_sizes), len(bijectors), message=message),
        assert_util.assert_equal(tf.rank(block_sizes), 1)
    ]):
      return tf.identity(block_sizes)
  else:
    return block_sizes
Example #21
 def _validate_correlationness(self, x):
     if not self.validate_args or self.input_output_cholesky:
         return x
     checks = [
         assert_util.assert_less_equal(
             dtype_util.as_numpy_dtype(x.dtype)(-1),
             x,
             message='Correlations must be >= -1.'),
         assert_util.assert_less_equal(
             x,
             dtype_util.as_numpy_dtype(x.dtype)(1),
             message='Correlations must be <= 1.'),
         assert_util.assert_near(tf.linalg.diag_part(x),
                                 dtype_util.as_numpy_dtype(x.dtype)(1),
                                 message='Self-correlations must be = 1.'),
         assert_util.assert_near(
             x,
             tf.linalg.matrix_transpose(x),
             message='Correlation matrices must be symmetric')
     ]
     with tf.control_dependencies(checks):
         return tf.identity(x)
  def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
    """Executes `model`, creating both samples and distributions."""
    ds = []
    values_out = []
    seed = SeedStream('JointDistributionCoroutine', seed)
    gen = self._model()
    index = 0
    d = next(gen)
    if not isinstance(d, self.Root):
      raise ValueError('First distribution yielded by coroutine must '
                       'be wrapped in `Root`.')
    try:
      while True:
        actual_distribution = d.distribution if isinstance(d, self.Root) else d
        ds.append(actual_distribution)
        if (value is not None and len(value) > index and
            value[index] is not None):
          # Advance the seed stream even when a value is provided, so later
          # sampling does not depend on which values were passed in.
          seed()
          next_value = value[index]
        else:
          next_value = actual_distribution.sample(
              sample_shape=sample_shape if isinstance(d, self.Root) else (),
              seed=seed())

        if self._validate_args:
          with tf.control_dependencies(
              self._assert_compatible_shape(
                  index, sample_shape, next_value)):
            values_out.append(tf.identity(next_value))
        else:
          values_out.append(next_value)

        index += 1
        d = gen.send(next_value)
    except StopIteration:
      pass
    return ds, values_out
    def _prob(self, x):
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        if self.validate_args:
            with tf.control_dependencies([
                    assert_util.assert_greater_equal(x, low),
                    assert_util.assert_less_equal(x, high)
            ]):
                x = tf.identity(x)

        interval_length = high - low
        # This is the pdf when low <= x <= high.  The density graph looks like
        # a triangle, so we have to treat each line segment separately.
        result_inside_interval = tf.where(
            (x >= low) & (x <= peak),
            # Line segment from (low, 0) to (peak, 2 / (high - low)).
            2. * (x - low) / (interval_length * (peak - low)),
            # Line segment from (peak, 2 / (high - low)) to (high, 0).
            2. * (high - x) / (interval_length * (high - peak)))

        return tf.where((x < low) | (x > high), tf.zeros_like(x),
                        result_inside_interval)
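The two line segments, checked against the public `tfp.distributions.Triangular`:

import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.Triangular(low=0., high=4., peak=1.)
dist.prob(0.5)  # rising segment:  2 * (0.5 - 0.) / (4. * 1.) = 0.25
dist.prob(2.0)  # falling segment: 2 * (4. - 2.) / (4. * 3.) ~= 0.3333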
Example #24
 def _validate_dimension(self, x):
     x = tf.convert_to_tensor(x, name='x')
     if tensorshape_util.is_fully_defined(x.shape[-2:]):
         if not (tensorshape_util.dims(x.shape)[-2] == tensorshape_util.dims(
                 x.shape)[-1] == self.dimension):
             raise ValueError(
                 'Input dimension mismatch: expected [..., {}, {}], got {}'.
                 format(self.dimension, self.dimension,
                        tensorshape_util.dims(x.shape)))
     elif self.validate_args:
         msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
             self.dimension, self.dimension, tf.shape(x))
         with tf.control_dependencies([
                 assert_util.assert_equal(tf.shape(x)[-2],
                                          self.dimension,
                                          message=msg),
                 assert_util.assert_equal(tf.shape(x)[-1],
                                          self.dimension,
                                          message=msg)
         ]):
             x = tf.identity(x)
     return x
Example #25
 def _probs_parameter_no_checks(self):
     if self._logits is None:
         return tf.identity(self._probs)
     return tf.math.softmax(self._logits)
def kl_divergence(distribution_a,
                  distribution_b,
                  allow_nan_stats=True,
                  name=None):
    """Get the KL-divergence KL(distribution_a || distribution_b).

  If there is no KL method registered specifically for `type(distribution_a)`
  and `type(distribution_b)`, then the class hierarchies of these types are
  searched.

  If one KL method is registered between any pairs of classes in these two
  parent hierarchies, it is used.

  If more than one such registered method exists, the method whose registered
  classes have the shortest sum MRO paths to the input types is used.

  If more than one such shortest path exists, the first method
  identified in the search is used (favoring a shorter MRO distance to
  `type(distribution_a)`).

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b`.
  """
    kl_fn = _registered_kl(type(distribution_a), type(distribution_b))
    if kl_fn is None:
        raise NotImplementedError(
            "No KL(distribution_a || distribution_b) registered for distribution_a "
            "type {} and distribution_b type {}".format(
                type(distribution_a).__name__,
                type(distribution_b).__name__))

    name = name or "KullbackLeibler"
    with tf.name_scope(name):
        # pylint: disable=protected-access
        with distribution_a._name_and_control_scope(name + "_a"):
            with distribution_b._name_and_control_scope(name + "_b"):
                kl_t = kl_fn(distribution_a, distribution_b, name=name)
                if allow_nan_stats:
                    return kl_t

        # Check KL for NaNs
        kl_t = tf.identity(kl_t, name="kl")

        with tf.control_dependencies([
                tf.Assert(
                    tf.logical_not(tf.reduce_any(tf.math.is_nan(kl_t))),
                    [("KL calculation between {} and {} returned NaN values "
                      "(and was called with allow_nan_stats=False). Values:".
                      format(distribution_a.name, distribution_b.name)), kl_t])
        ]):
            return tf.identity(kl_t, name="checked_kl")
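Usage sketch via the public API (a Normal-Normal KL is registered, so the lookup succeeds):

import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Normal(loc=0., scale=1.)
b = tfd.Normal(loc=1., scale=2.)
tfd.kl_divergence(a, b)  # log(2) + (1 + 1) / (2 * 4) - 0.5 ~= 0.4431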
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        logits = self._logits
        probs = self._probs
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must have floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))
                with tf.control_dependencies(assertions):
                    param = tf.identity(param)

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if probs is not None:
            probs = param  # reuse tensor conversion from above
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions
    def __init__(self,
                 perm=None,
                 rightmost_transposed_ndims=None,
                 validate_args=False,
                 name='transpose'):
        """Instantiates the `Transpose` bijector.

    Args:
      perm: Positive `int32` vector-shaped `Tensor` representing permutation of
        rightmost dims (for forward transformation).  Note that the `0`th index
        represents the first of the rightmost dims and the largest value must be
        `rightmost_transposed_ndims - 1` and corresponds to `tf.rank(x) - 1`.
        Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
        specified.
        Default value:
        `tf.range(start=rightmost_transposed_ndims, limit=-1, delta=-1)`.
      rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
        representing the number of rightmost dimensions to permute.
        Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
        specified.
        Default value: `tf.size(perm)`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if both or neither `perm` and `rightmost_transposed_ndims` are
        specified.
      NotImplementedError: if `rightmost_transposed_ndims` is not known prior to
        graph execution.
    """
        with tf.name_scope(name) as name:
            if (rightmost_transposed_ndims is None) == (perm is None):
                raise ValueError('Must specify exactly one of '
                                 '`rightmost_transposed_ndims` and `perm`.')
            if rightmost_transposed_ndims is not None:
                rightmost_transposed_ndims = tf.convert_to_tensor(
                    rightmost_transposed_ndims,
                    dtype_hint=np.int32,
                    name='rightmost_transposed_ndims')
                rightmost_transposed_ndims_ = tf.get_static_value(
                    rightmost_transposed_ndims)
                assertions = _maybe_validate_rightmost_transposed_ndims(
                    rightmost_transposed_ndims, validate_args)
                if assertions:
                    with tf.control_dependencies(assertions):
                        rightmost_transposed_ndims = tf.identity(
                            rightmost_transposed_ndims)
                perm_start = (distribution_util.prefer_static_value(
                    rightmost_transposed_ndims) - 1)
                perm = tf.range(start=perm_start,
                                limit=-1,
                                delta=-1,
                                name='perm')
            else:  # perm is not None:
                perm = tf.convert_to_tensor(perm,
                                            dtype_hint=np.int32,
                                            name='perm')
                rightmost_transposed_ndims = tf.size(
                    perm, name='rightmost_transposed_ndims')
                rightmost_transposed_ndims_ = tf.get_static_value(
                    rightmost_transposed_ndims)
                assertions = _maybe_validate_perm(perm, validate_args)
                if assertions:
                    with tf.control_dependencies(assertions):
                        perm = tf.identity(perm)

            # TODO(b/110828604): If bijector base class ever supports dynamic
            # `min_event_ndims`, then this class already works dynamically and the
            # following five lines can be removed.
            if rightmost_transposed_ndims_ is None:
                raise NotImplementedError(
                    '`rightmost_transposed_ndims` must be '
                    'known prior to graph execution.')
            else:
                rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

            self._perm = perm
            self._rightmost_transposed_ndims = rightmost_transposed_ndims
            super(Transpose, self).__init__(
                forward_min_event_ndims=rightmost_transposed_ndims_,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
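Usage sketch with the public bijector: permuting the rightmost two dimensions transposes each matrix.

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Transpose(perm=[1, 0])
bij.forward(tf.reshape(tf.range(6.), [2, 3])).shape  # -> [3, 2]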
    def _sample_n(self, n, seed=None):
        seed = SeedStream(seed, salt='vom_mises_fisher')
        # The sampling strategy relies on the fact that vMF variates are symmetric
        # about the mean direction. Accordingly, if we have a sampling strategy for
        # the away-from-mean angle, then we can uniformly sample the remaining
        # dimensions on the S^{dim-2} sphere for free, and rotate these samples from a
        # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
        #
        # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
        # von-Mises distributed `x` value in [-1, 1], then uniformly select what
        # amounts to a "up" or "down" additional degree of freedom after unit
        # normalizing, followed by a final rotation to the desired mean direction
        # from a basis of (1, 0).
        #
        # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
        # unit sphere over which the distribution is uniform, in particular the
        # circle where x = \hat{x} intersects the unit sphere. We pick a point on
        # that circle, then rotate to the desired mean direction from a basis of
        # (1, 0, 0).
        event_dim = (tf.compat.dimension_value(self.event_shape[0])
                     or self._event_shape_tensor()[0])

        sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()],
                                       axis=0)
        dim = tf.cast(event_dim - 1, self.dtype)
        if event_dim == 3:
            samples_dim0 = self._sample_3d(n, seed=seed)
        else:
            # Wood'94 provides a rejection algorithm to sample the x coordinate.
            # Wood'94 definition of b:
            # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
            # https://stats.stackexchange.com/questions/156729 suggests:
            b = dim / (2 * self.concentration +
                       tf.sqrt(4 * self.concentration**2 + dim**2))
            # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
            #     https://github.com/nicola-decao/s-vae-tf/
            x = (1 - b) / (1 + b)
            c = self.concentration * x + dim * tf.math.log1p(-x**2)
            beta = beta_lib.Beta(dim / 2, dim / 2)

            def cond_fn(w, should_continue):
                del w
                return tf.reduce_any(should_continue)

            def body_fn(w, should_continue):
                z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
                # set_shape needed here because of b/139013403
                z.set_shape(w.shape)
                w = tf.where(should_continue,
                             (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
                w = tf.debugging.check_numerics(w, 'w')
                unif = tf.random.uniform(sample_batch_shape,
                                         seed=seed(),
                                         dtype=self.dtype)
                # set_shape needed here because of b/139013403
                unif.set_shape(w.shape)
                should_continue = tf.logical_and(
                    should_continue,
                    self.concentration * w + dim * tf.math.log1p(-x * w) - c <
                    tf.math.log(unif))
                return w, should_continue

            w = tf.zeros(sample_batch_shape, dtype=self.dtype)
            should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
            samples_dim0 = tf.while_loop(cond=cond_fn,
                                         body=body_fn,
                                         loop_vars=(w, should_continue))[0]
            samples_dim0 = samples_dim0[..., tf.newaxis]
        if not self._allow_nan_stats:
            # Verify samples are w/in -1, 1, with useful error output tensors (top
            # value rather than all values).
            with tf.control_dependencies([
                    assert_util.assert_less_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(1.01),
                        data=[
                            tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]
                        ]),
                    assert_util.assert_greater_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(-1.01),
                        data=[
                            -tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]
                        ])
            ]):
                samples_dim0 = tf.identity(samples_dim0)
        samples_otherdims_shape = tf.concat(
            [sample_batch_shape, [event_dim - 1]], axis=0)
        unit_otherdims = tf.math.l2_normalize(tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
                                              axis=-1)
        samples = tf.concat(
            [
                samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
                tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
            ],
            axis=-1)
        samples = tf.math.l2_normalize(samples, axis=-1)
        if not self._allow_nan_stats:
            samples = tf.debugging.check_numerics(samples, 'samples')

        # Runtime assert that samples are unit length.
        if not self._allow_nan_stats:
            worst, idx = tf.math.top_k(
                tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
            with tf.control_dependencies([
                    assert_util.assert_near(
                        dtype_util.as_numpy_dtype(self.dtype)(0),
                        worst,
                        data=[
                            worst, idx,
                            tf.gather(tf.reshape(samples, [-1, event_dim]),
                                      idx)
                        ],
                        atol=1e-4,
                        summarize=100)
            ]):
                samples = tf.identity(samples)
        # The samples generated are symmetric around a mode at (1, 0, 0, ..., 0).
        # Now, we move the mode to `self.mean_direction` using a rotation matrix.
        if not self._allow_nan_stats:
            # Assert that the basis vector rotates to the mean direction, as expected.
            basis = tf.cast(
                tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                self.dtype)
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.linalg.norm(self._rotate(basis) -
                                       self.mean_direction,
                                       axis=-1),
                        dtype_util.as_numpy_dtype(self.dtype)(1e-5))
            ]):
                return self._rotate(samples)
        return self._rotate(samples)
    def __init__(self,
                 mean_direction,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VonMisesFisher'):
        """Creates a new `VonMisesFisher` instance.

    Args:
      mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
        A unit vector indicating the mode of the distribution, or the
        unit-normalized direction of the mean. (This is *not* in general the
        mean of the distribution; the mean is not generally in the support of
        the distribution.) NOTE: `D` is currently restricted to <= 5.
      concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
        broadcastable with `mean_direction`. The level of concentration of
        samples around the `mean_direction`. `concentration=0` indicates a
        uniform distribution over the unit hypersphere, and `concentration=+inf`
        indicates a `Deterministic` distribution (delta function) at
        `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([mean_direction, concentration],
                                            tf.float32)
            mean_direction = tf.convert_to_tensor(mean_direction,
                                                  name='mean_direction',
                                                  dtype=dtype)
            concentration = tf.convert_to_tensor(concentration,
                                                 name='concentration',
                                                 dtype=dtype)
            assertions = [
                assert_util.assert_non_negative(
                    concentration,
                    message='`concentration` must be non-negative'),
                assert_util.assert_greater(
                    tf.shape(mean_direction)[-1],
                    1,
                    message='`mean_direction` may not have scalar event shape'
                ),
                assert_util.assert_near(
                    1.,
                    tf.linalg.norm(mean_direction, axis=-1),
                    message='`mean_direction` must be unit-length')
            ] if validate_args else []
            static_event_dim = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(mean_direction.shape,
                                                    1)[-1])
            if static_event_dim is not None and static_event_dim > 5:
                raise ValueError('vMF ndims > 5 is not currently supported')
            elif validate_args:
                assertions += [
                    assert_util.assert_less_equal(
                        tf.shape(mean_direction)[-1],
                        5,
                        message='vMF ndims > 5 is not currently supported')
                ]
            with tf.control_dependencies(assertions):
                self._mean_direction = tf.identity(mean_direction)
                self._concentration = tf.identity(concentration)
            dtype_util.assert_same_float_dtype(
                [self._mean_direction, self._concentration])
            # mean_direction is always reparameterized.
            # concentration is only for event_dim==3, via an inversion sampler.
            reparameterization_type = (reparameterization.FULLY_REPARAMETERIZED
                                       if static_event_dim == 3 else
                                       reparameterization.NOT_REPARAMETERIZED)
            super(VonMisesFisher, self).__init__(
                dtype=self._concentration.dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization_type,
                parameters=parameters,
                name=name)
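Usage sketch: a vMF distribution on the 2-sphere (D = 3, within the supported D <= 5); samples are unit vectors clustered around the mean direction.

import tensorflow_probability as tfp
tfd = tfp.distributions

vmf = tfd.VonMisesFisher(mean_direction=[0., 0., 1.], concentration=10.)
samples = vmf.sample(5)  # five unit vectors near the +z pole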