def _inverse(self, y):
     """Map values in `y` back to their indices in `self.map_values`.

     Each element of `y` is matched against its two neighbouring entries of
     `map_values` (found by binary search) and the index of the closer one
     is returned.

     Args:
       y: Tensor of values to invert.

     Returns:
       Tensor of indices into `map_values`, reshaped to the shape of `y`.
     """
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         # The inverse only exists when `y` coincides with one of the
         # candidates, i.e. the smaller of the two distances is ~0.
         with tf.control_dependencies([
                 tf.debugging.assert_near(
                     tf.minimum(lower_cand_diff, upper_cand_diff),
                     0,
                     message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     # Row i selects candidates[i, 0] (lower) or candidates[i, 1] (upper),
     # whichever is closer to flat_y[i].
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
     ], axis=-1)
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=tf.shape(y))
 def _log_prob(self, x):
     """Compute the log-density of the mixture at `x`.

     `x` is pulled back through every affine in `self.interpolated_affine`,
     scored under `self.distribution`, corrected by each affine's
     log-det-Jacobian and combined with the mixture logits via logsumexp.

     Args:
       x: Tensor of points at which to evaluate the log-density.

     Returns:
       Tensor of log-densities, reduced over the grid-point axis.
     """
     # By convention, we always put the grid points right-most.
     y = tf.stack([aff.inverse(x) for aff in self.interpolated_affine],
                  axis=-1)
     log_prob = tf.reduce_sum(self.distribution.log_prob(y), axis=-2)
     # Because the affine transformation has a constant Jacobian, it is the case
     # that `affine.fldj(x) = -affine.ildj(x)`. This is not true in general.
     event_ndims = tf.rank(self.event_shape_tensor())
     fldj = tf.stack([
         aff.forward_log_det_jacobian(x, event_ndims=event_ndims)
         for aff in self.interpolated_affine
     ], axis=-1)
     return tf.reduce_logsumexp(
         self.mixture_distribution.logits - fldj + log_prob, axis=-1)
Exemple #3
    def _mode(self, samples=None):
        """Return the most frequent sample for every batch member.

        Args:
          samples: Optional tensor of samples; defaults to `self._samples`.

        Returns:
          Tensor of modes, reshaped to the batch shape (scalar events) or
          to batch shape + event shape (non-scalar events).
        """
        # Samples count can vary by batch member. Use map_fn to compute mode for
        # each batch separately.
        def _get_mode(samples):
            # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
            unique = gen_array_ops.unique_with_counts_v2(samples, axis=[0])
            # `count` is indexed by unique value, not by sample position, so
            # translate the winning unique slot back into a position inside
            # `samples` (any position whose `idx` equals that slot will do).
            top_unique = tf.cast(tf.argmax(unique.count), unique.idx.dtype)
            return tf.argmax(
                tf.cast(tf.equal(unique.idx, top_unique), tf.int32))

        if samples is None:
            samples = tf.convert_to_tensor(self._samples)
        num_samples = self._compute_num_samples(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            flattened_samples = tf.reshape(samples, [-1, num_samples])
            mode_shape = self._batch_shape_tensor(samples)
        else:
            event_size = tf.reduce_prod(self._event_shape_tensor(samples))
            mode_shape = tf.concat([
                self._batch_shape_tensor(samples),
                self._event_shape_tensor(samples),
            ], axis=0)
            flattened_samples = tf.reshape(samples,
                                           [-1, num_samples, event_size])

        indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
        # Pair every flattened batch row with the position of its mode.
        full_indices = tf.stack(
            [tf.range(tf.shape(indices)[0]),
             tf.cast(indices, tf.int32)],
            axis=1)

        mode = tf.gather_nd(flattened_samples, full_indices)
        return tf.reshape(mode, mode_shape)
 def _log_variance(self):
     """Compute the log of the variance via the law of total variance.

     Returns:
       Tensor holding `log(Var[Z])`, obtained by a logsumexp over the
       quadrature axis and the two variance terms.
     """
     # Following calculation is based on law of total variance:
     # Var[Z] = E[Var[Z | V]] + Var[E[Z | V]]
     # where,
     # Z|v ~ interpolate_affine[v](dist)
     # V ~ mixture_dist
     # thus,
     # E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
     # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }
     distributions = self.poisson_and_mixture_distributions()
     dist, mixture_dist = distributions
     overall_mean = self._mean(distributions=distributions)[..., tf.newaxis]
     v = tf.stack(
         [
             # log(dist.variance()) = log(Var[d]) = log(rate[d])
             tf.math.log(dist.variance()),
             # log((Mean[d] - Mean)**2)
             2. * tf.math.log(tf.abs(dist.mean() - overall_mean)),
         ],
         axis=-1)
     return tf.reduce_logsumexp(mixture_dist.logits[..., tf.newaxis] + v,
                                axis=[-2, -1])
Exemple #5
  def _stddev(self):
    """Return the standard deviation of the mixture.

    Per-component means, standard deviations and categorical probabilities
    are stacked along a trailing component axis, flattened to two
    dimensions for `distribution_util.mixture_stddev`, and the result is
    reshaped back to the broadcast shape without the component axis.
    """
    with tf.control_dependencies(self._assertions):
      component_means = [comp.mean() for comp in self.components]
      component_devs = [comp.stddev() for comp in self.components]
      probs = self._cat_probs(log_probs=False)

      means = tf.stack(component_means, axis=-1)
      devs = tf.stack(component_devs, axis=-1)
      expanded_probs = [self._expand_to_event_rank(p) for p in probs]
      # Multiplying by ones broadcasts the probabilities to the means' shape.
      weights = tf.stack(expanded_probs, axis=-1) * tf.ones_like(means)

      flat_shape = [-1, len(self.components)]
      flat_dev = distribution_util.mixture_stddev(
          tf.reshape(weights, flat_shape),
          tf.reshape(means, flat_shape),
          tf.reshape(devs, flat_shape))

      # I.e. re-shape to list(batch_shape) + list(event_shape).
      output_shape = tf.shape(weights)[:-1]
      return tf.reshape(flat_dev, output_shape)
Exemple #6
 def _log_cdf(self, x):
   """Compute the log-CDF of the mixture at `x`.

   Each component's log-CDF is weighted by the matching categorical
   log-probability and the terms are combined with logsumexp.

   Args:
     x: Value(s) at which to evaluate the log-CDF.

   Returns:
     Tensor holding the mixture log-CDF.
   """
   with tf.control_dependencies(self._assertions):
     x = tf.convert_to_tensor(x, name="x")
     distribution_log_cdfs = [d.log_cdf(x) for d in self.components]
     cat_log_probs = self._cat_probs(log_probs=True)
     final_log_cdfs = [
         cat_lp + d_lcdf
         for (cat_lp, d_lcdf) in zip(cat_log_probs, distribution_log_cdfs)
     ]
     concatted_log_cdfs = tf.stack(final_log_cdfs, axis=0)
     mixture_log_cdf = tf.reduce_logsumexp(concatted_log_cdfs, axis=[0])
     return mixture_log_cdf
    def _log_prob(self, y, **kwargs):
        """Compute the log-density of the transformed distribution at `y`.

        Args:
          y: Tensor of points in the bijector's output space.
          **kwargs: Keyword arguments, split by `self._kwargs_split_fn` into
            those for the base distribution and those for the bijector.

        Returns:
          Tensor of log-densities. For a non-injective bijector the
          per-fiber results are combined with logsumexp.
        """
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

        # For caching to work, it is imperative that the bijector is the first to
        # modify the input.
        x = self.bijector.inverse(y, **bijector_kwargs)
        event_ndims = self._maybe_get_static_event_ndims()

        ildj = self.bijector.inverse_log_det_jacobian(
            y, event_ndims=event_ndims, **bijector_kwargs)
        if self.bijector._is_injective:  # pylint: disable=protected-access
            return self._finish_log_prob_for_one_fiber(
                y, x, ildj, event_ndims, **distribution_kwargs)

        # Non-injective case: `x` and `ildj` hold one entry per fiber.
        lp_on_fibers = [
            self._finish_log_prob_for_one_fiber(
                y, x_i, ildj_i, event_ndims, **distribution_kwargs)
            for x_i, ildj_i in zip(x, ildj)
        ]
        return tf.reduce_logsumexp(tf.stack(lp_on_fibers), axis=0)
Exemple #8
 def _sample_n(self, n, seed=None):
   """Draw `n` samples by delegating to the two-class multinomial sampler.

   Args:
     n: Number of samples to draw.
     seed: Optional random seed.

   Returns:
     The count of the first class from `multinomial.draw_sample`, i.e. the
     number of successes for each draw.
   """
   # Need to create logits corresponding to [p, 1 - p].
   # Note that for this distributions, logits corresponds to
   # inverse sigmoid(p) while in multivariate distributions,
   # such as multinomial this corresponds to log(p).
   # Because of this, when we construct the logits for the multinomial
   # sampler, we'll have to be careful.
   # log(p) = log(sigmoid(logits)) = logits - softplus(logits)
   # log(1 - p) = log(1 - sigmoid(logits)) = -softplus(logits)
   # Because softmax is invariant to a constant shift in all inputs,
   # we can offset the logits by softplus(logits) so that we can use
   # [logits, 0.] as our input.
   orig_logits = self._logits_parameter_no_checks()
   logits = tf.stack([orig_logits, tf.zeros_like(orig_logits)], axis=-1)
   return multinomial.draw_sample(
       num_samples=n,
       num_classes=2,
       logits=logits,
       num_trials=tf.cast(self.total_count, dtype=tf.int32),
       dtype=self.dtype,
       seed=seed)[..., 0]
Exemple #9
  def _sample_n(self, n, seed=None):
    """Draw `n` samples from the mixture.

    With `self._use_static_graph` every component is sampled `n` times and
    a one-hot mask built from the categorical draws picks the result.
    Otherwise only as many samples as needed are drawn from each component
    and the pieces are stitched back together.

    Args:
      n: Number of samples to draw.
      seed: Optional random seed.

    Returns:
      Tensor of samples with sample, batch and event dimensions.
    """
    if self._use_static_graph:
      with tf.control_dependencies(self._assertions):
        # This sampling approach is almost the same as the approach used by
        # `MixtureSameFamily`. The differences are due to having a list of
        # `Distribution` objects rather than a single object, and maintaining
        # random seed management that is consistent with the non-static code
        # path.
        samples = []
        cat_samples = self._cat.sample(n, seed=seed)
        stream = SeedStream(seed, salt="Mixture")

        for c in range(self.num_components):
          samples.append(self.components[c].sample(n, seed=stream()))
        stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
        x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=cat_samples,  # [n, B]
            depth=self._num_components,  # == k
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k]
        mask = distribution_util.pad_mixture_dimensions(
            mask, self, self._cat,
            tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
        return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

    with tf.control_dependencies(self._assertions):
      n = tf.convert_to_tensor(n, name="n")
      static_n = tf.get_static_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self._cat.sample(n, seed=seed)

      # Prefer statically known shapes; fall back to dynamic ones otherwise.
      static_samples_shape = cat_samples.shape
      if tensorshape_util.is_fully_defined(static_samples_shape):
        samples_shape = tensorshape_util.as_list(static_samples_shape)
        samples_size = tensorshape_util.num_elements(static_samples_shape)
      else:
        samples_shape = tf.shape(cat_samples)
        samples_size = tf.size(cat_samples)
      static_batch_shape = self.batch_shape
      if tensorshape_util.is_fully_defined(static_batch_shape):
        batch_shape = tensorshape_util.as_list(static_batch_shape)
        batch_size = tensorshape_util.num_elements(static_batch_shape)
      else:
        batch_shape = self.batch_shape_tensor()
        batch_size = tf.reduce_prod(batch_shape)
      static_event_shape = self.event_shape
      if tensorshape_util.is_fully_defined(static_event_shape):
        event_shape = np.array(
            tensorshape_util.as_list(static_event_shape), dtype=np.int32)
      else:
        event_shape = self.event_shape_tensor()

      # Get indices into the raw cat sampling tensor. We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = tf.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = tf.reshape(
          tf.tile(tf.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = tf.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      stream = SeedStream(seed, salt="Mixture")

      for c in range(self.num_components):
        n_class = tf.size(partitioned_samples_indices[c])
        samples_class_c = self.components[c].sample(
            n_class, seed=stream())

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * tf.range(n_class) + partitioned_batch_indices[c])
        samples_class_c = tf.reshape(
            samples_class_c, tf.concat([[n_class * batch_size], event_shape],
                                       0))
        samples_class_c = tf.gather(
            samples_class_c, lookup_partitioned_batch_indices)
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = tf.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = tf.reshape(
          lhs_flat_ret, tf.concat(
              [samples_shape, self.event_shape_tensor()], 0))
      # Re-attach whatever static shape information is available.
      tensorshape_util.set_shape(
          ret,
          tensorshape_util.concatenate(static_samples_shape, self.event_shape))
      return ret
Exemple #10
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_reconstruct').

    Returns:
      x: The original input to `tf.linalg.lu`, i.e., `x` as in,
        `lu_reconstruct(*tf.linalg.lu(x))`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import jax as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

    x = [[[3., 4], [1, 2]],
         [[7., 8], [3, 4]]]
    x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
    tf.assert_near(x, x_reconstructed)
    # ==> True
    ```
    """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        # `lower_upper` packs L (unit diagonal, below) and U (on/above).
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if (tensorshape_util.rank(lower_upper.shape) is None
                or tensorshape_util.rank(lower_upper.shape) != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            # Single matrix: undo the row permutation directly.
            x = tf.gather(x, tf.math.invert_permutation(perm))

        return x
Exemple #11
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

    Note: this function does not verify the implied matrix is actually
    invertible nor is this condition checked even when `validate_args=True`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
      rhs: Matrix-shaped float `Tensor` representing targets for which to
        solve; `A X = RHS`. To handle vector cases, use:
        `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness. Note: this function does not verify the
        implied matrix is actually invertible, even when
        `validate_args=True`.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_solve').

    Returns:
      x: The `X` in `A @ X = RHS`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import jax as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

    x = [[[1., 2],
          [3, 4]],
         [[7, 8],
          [3, 4]]]
    inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
    tf.assert_near(tf.matrix_inverse(x), inv_x)
    # ==> True
    ```
    """
    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't determine
            # this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2], tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        # Solve L Y = P^T RHS first, then U X = Y.
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)
Exemple #12
    def _sample_n(self, n, seed=None):
        """Draw `n` reparameterized samples from the truncated normal.

        Standardized samples are drawn between the standardized bounds with
        a custom gradient for those bounds, then shifted and scaled by
        `self._loc` and `self._scale`.

        Args:
          n: Number of samples to draw.
          seed: Optional random seed.

        Returns:
          Tensor of samples of shape `[n] + batch_shape`.
        """
        sample_and_batch_shape = tf.concat([[n], self.batch_shape_tensor()], 0)
        flat_batch_and_sample_shape = tf.stack(
            [tf.reduce_prod(self.batch_shape_tensor()), n])

        # In order to be reparameterizable we sample on the truncated_normal of
        # unit variance and mean and scale (but with the standardized
        # truncation bounds).

        # The function returns `(value, grad_fn)`, so it must be wrapped with
        # `tf.custom_gradient` for the call below to yield a tensor.
        @tf.custom_gradient
        def _std_samples_with_gradients(lower, upper):
            """Standard truncated Normal with gradient support for low, high."""
            # Note: Unlike the convention in tf_probability,
            # parameterized_truncated_normal returns a tensor with the final dimension
            # being the sample dimension.
            std_samples = random_ops.parameterized_truncated_normal(
                shape=flat_batch_and_sample_shape,
                means=0.0,
                stddevs=1.0,
                minvals=lower,
                maxvals=upper,
                dtype=self.dtype,
                seed=seed)

            def grad(dy):
                """Computes a derivative for the min and max parameters.

                This function implements the derivative wrt the truncation
                bounds, which get blocked by the sampler. We use a custom
                expression for numerical stability instead of automatic
                differentiation on CDF for implicit gradients.

                Args:
                  dy: output gradients

                Returns:
                  The gradients wrt the lower bound and the upper bound.
                """
                # std_samples has an extra dimension (the sample dimension), expand
                # lower and upper so they broadcast along this dimension.
                # See note above regarding parameterized_truncated_normal, the sample
                # dimension is the final dimension.
                lower_broadcast = lower[..., tf.newaxis]
                upper_broadcast = upper[..., tf.newaxis]

                cdf_samples = ((special_math.ndtr(std_samples) -
                                special_math.ndtr(lower_broadcast)) /
                               (special_math.ndtr(upper_broadcast) -
                                special_math.ndtr(lower_broadcast)))

                # tiny, eps are tolerance parameters to ensure we stay away from giving
                # a zero arg to the log CDF expression.

                tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
                eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
                cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

                du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                            tf.math.log(cdf_samples))
                dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                            tf.math.log1p(-cdf_samples))

                # Reduce the gradient across the samples
                grad_u = tf.reduce_sum(dy * du, axis=-1)
                grad_l = tf.reduce_sum(dy * dl, axis=-1)
                return [grad_l, grad_u]

            return std_samples, grad

        std_samples = _std_samples_with_gradients(
            tf.reshape(self._standardized_low, [-1]),
            tf.reshape(self._standardized_high, [-1]))

        # The returned shape is [flat_batch x n]
        std_samples = tf.transpose(a=std_samples, perm=[1, 0])

        std_samples = tf.reshape(std_samples, sample_and_batch_shape)
        samples = (std_samples * tf.expand_dims(self._scale, axis=0) +
                   tf.expand_dims(self._loc, axis=0))

        return samples
Exemple #13
 def _event_shape_tensor(self):
     """Return the event shape `[d, d]`, where `d` is the domain dimension
     tensor reported by `self.scale_operator`."""
     dim = self.scale_operator.domain_dimension_tensor()
     square_shape = [dim, dim]
     return tf.stack(square_shape)