def quadrature_scheme_softmaxnormal_gauss_hermite(normal_loc,
                                                  normal_scale,
                                                  quadrature_size,
                                                  validate_args=False,
                                                  name=None):
    """Build a Gauss-Hermite quadrature scheme on the `K - 1` simplex.

    A `SoftmaxNormal` random variable `Y` may be generated via

    ```
    Y = SoftmaxCentered(X),
    X = Normal(normal_loc, normal_scale)
    ```

    The Gauss-Hermite nodes are mapped through
    `normal_loc + sqrt(2) * normal_scale * node`, padded along axis `-2`, negated
    and pushed through a softmax over that same axis.

    Note: for a given `quadrature_size`, this method is generally less accurate
    than `quadrature_scheme_softmaxnormal_quantiles`.

    Args:
      normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
        The location parameter of the Normal used to construct the
        SoftmaxNormal.
      normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
        The scale parameter of the Normal used to construct the SoftmaxNormal.
        It is converted using the dtype of `normal_loc`.
      quadrature_size: Python `int` scalar representing the number of
        quadrature points.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this function.
        Defaults to "quadrature_scheme_softmaxnormal_gauss_hermite".

    Returns:
      grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
        convex combination of affine parameters for `K` components.
        `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1`
        simplex.
      probs: Shape `[quadrature_size]` `Tensor` holding the Gauss-Hermite
        weights, rescaled to sum to one, associated with each grid point.
    """
    scope = name or "quadrature_scheme_softmaxnormal_gauss_hermite"
    with tf.name_scope(scope):
        normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
        npdt = dtype_util.as_numpy_dtype(normal_loc.dtype)
        normal_scale = maybe_check_quadrature_param(
            tf.convert_to_tensor(normal_scale, dtype=npdt,
                                 name="normal_scale"),
            "normal_scale",
            validate_args)

        # Quadrature nodes/weights are computed eagerly in NumPy; only the
        # weights are turned into a Tensor here.
        nodes, weights = np.polynomial.hermite.hermgauss(deg=quadrature_size)
        nodes = nodes.astype(npdt)
        weights = weights.astype(npdt)
        # Rescale the weights by their L1 norm so they form a probability
        # vector.
        weights /= np.linalg.norm(weights, ord=1, keepdims=True)
        probs = tf.convert_to_tensor(weights, name="probs", dtype=npdt)

        # Trailing axis indexes the quadrature points.
        unconstrained = (normal_loc[..., tf.newaxis] +
                         np.sqrt(2.) * normal_scale[..., tf.newaxis] * nodes)
        padded = distribution_util.pad(unconstrained, axis=-2, front=True)
        grid = softmax(-padded, axis=-2)  # shape: [B, components, deg]

        return grid, probs
Example #2
0
 def __init__(self, validate_args=False, name="ordered"):
     """Initialize the bijector with `forward_min_event_ndims=1`.

     Args:
       validate_args: Python `bool` forwarded to the base-class constructor.
       name: Python `str` used to open the name scope; the resulting scope
         name is forwarded to the base-class constructor.
     """
     with tf.name_scope(name) as scope_name:
         base_kwargs = dict(forward_min_event_ndims=1,
                            validate_args=validate_args,
                            name=scope_name)
         super(Ordered, self).__init__(**base_kwargs)
Example #3
0
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Construct a Transformed Distribution.

    Besides storing `distribution` and `bijector`, the constructor records the
    optional batch/event shape overrides and precomputes the bookkeeping
    (`_is_*_override`, `_needs_rotation`, `_rotate_ndims`,
    `_reduce_event_indices`) derived from them.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
        a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
        parameters respectively. Passing `None` also selects the default.
        Default value: `_default_kwargs_split_fn` (i.e.,
            `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                             kwargs.get('bijector_kwargs', {}))`)
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operations. When `None`, the locals of
        this constructor are captured instead.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
        # Snapshot the constructor arguments before any other local is bound,
        # unless a subclass already supplied its own `parameters`.
        parameters = dict(locals()) if parameters is None else parameters
        # Default name: the bijector's name (empty when `bijector` is None)
        # followed by the distribution's name (empty when that is falsy).
        name = name or (("" if bijector is None else bijector.name) +
                        (distribution.name or ""))
        with tf.name_scope(name) as name:
            # An explicit `kwargs_split_fn=None` falls back to the default.
            self._kwargs_split_fn = (_default_kwargs_split_fn
                                     if kwargs_split_fn is None else
                                     kwargs_split_fn)
            # For convenience we define some handy constants.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # We will keep track of a static and dynamic version of
            # self._is_{batch,event}_override. This way we can do more prior to graph
            # execution, including possibly raising Python exceptions.

            # The helper receives the requested override together with whether
            # the base distribution has a scalar batch.
            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            # True unless `rank_from_shape` of the override equals zero.
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_batch_shape),
                    self._zero))
            # Python-level flag: True when the override's static value is
            # unknown or non-empty, i.e. an override cannot be ruled out
            # statically.
            self._is_maybe_batch_override = bool(
                tf.get_static_value(self._override_batch_shape) is None
                or tf.get_static_value(self._override_batch_shape).size != 0)

            # Same bookkeeping as above, for the event shape override.
            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_event_shape),
                    self._zero))
            self._is_maybe_event_override = bool(
                tf.get_static_value(self._override_event_shape) is None
                or tf.get_static_value(self._override_event_shape).size != 0)

            # To convert a scalar distribution into a multivariate distribution we
            # will draw dims from the sample dims, which are otherwise iid. This is
            # easy to do except in the case that the base distribution has batch dims
            # and we're overriding event shape. When that case happens the event dims
            # will incorrectly be to the left of the batch dims. In this case we'll
            # cyclically permute left the new dims.
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                prefer_static.logical_not(self._is_batch_override),
                prefer_static.logical_not(distribution.is_scalar_batch())
            ])
            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            # Rotate by the number of overridden event dims when rotation is
            # needed, otherwise by 0.
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be []
            # if we don't need to reduce.
            self._reduce_event_indices = prefer_static.range(
                self._rotate_ndims - override_event_ndims, self._rotate_ndims)

        # Stored outside the name scope; the base-class constructor below takes
        # dtype, reparameterization type and allow_nan_stats from the base
        # distribution, and `name` is the scope name captured by the `with`.
        self._distribution = distribution
        self._bijector = bijector
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            name=name)
Example #4
0
    def __init__(self,
                 loc=None,
                 scale_diag=None,
                 scale_identity_multiplier=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorExponentialDiag"):
        """Construct Vector Exponential distribution supported on a subset of `R^k`.

        The `batch_shape` is the broadcast shape between `loc` and `scale`
        arguments.

        The `event_shape` is given by last dimension of the matrix implied by
        `scale`. The last dimension of `loc` (if provided) must broadcast with
        this.

        Recall that `covariance = scale @ scale.T`.

        ```none
        scale = diag(scale_diag + scale_identity_multiplier * ones(k))
        ```

        where:

        * `scale_diag.shape = [k]`, and,
        * `scale_identity_multiplier.shape = []`.

        Additional leading dimensions (if any) will index batches.

        If both `scale_diag` and `scale_identity_multiplier` are `None`, then
        `scale` is the Identity matrix.

        Args:
          loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
            implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]`
            where `b >= 0` and `k` is the event size.
          scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
            matrix added to `scale`. May have shape `[B1, ..., Bb, k]`,
            `b >= 0`, and characterizes `b`-batches of `k x k` diagonal
            matrices added to `scale`. When both `scale_identity_multiplier`
            and `scale_diag` are `None` then `scale` is the `Identity`.
          scale_identity_multiplier: Non-zero, floating-point `Tensor`
            representing a scaled-identity-matrix added to `scale`. May have
            shape `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of
            scaled `k x k` identity matrices added to `scale`. When both
            `scale_identity_multiplier` and `scale_diag` are `None` then
            `scale` is the `Identity`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: if at most `scale_identity_multiplier` is specified
            (not checked directly here; presumably raised by
            `distribution_util.make_diag_scale`).
        """
        # Must run first so that only the constructor arguments are captured.
        parameters = dict(locals())
        with tf.name_scope(name) as scope_name, tf.name_scope("init"):
            # No need to validate_args while making diag_scale.  The returned
            # LinearOperatorDiag has an assert_non_singular method that is
            # called by the Bijector.
            diag_scale_kwargs = dict(
                loc=loc,
                scale_diag=scale_diag,
                scale_identity_multiplier=scale_identity_multiplier,
                validate_args=False,
                assert_positive=False)
            scale = distribution_util.make_diag_scale(**diag_scale_kwargs)
        super(VectorExponentialDiag, self).__init__(
            loc=loc,
            scale=scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=scope_name)
        # Replace whatever the base class recorded with this constructor's
        # own arguments.
        self._parameters = parameters
    def __init__(self,
                 mean_direction,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VonMisesFisher'):
        """Creates a new `VonMisesFisher` instance.

        Args:
          mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
            A unit vector indicating the mode of the distribution, or the
            unit-normalized direction of the mean. (This is *not* in general
            the mean of the distribution; the mean is not generally in the
            support of the distribution.) NOTE: `D` is currently restricted to
            <= 5.
          concentration: Floating-point `Tensor` having batch shape
            [B1, ... Bn] broadcastable with `mean_direction`. The level of
            concentration of samples around the `mean_direction`.
            `concentration=0` indicates a uniform distribution over the unit
            hypersphere, and `concentration=+inf` indicates a `Deterministic`
            distribution (delta function) at `mean_direction`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: For known-bad arguments, i.e. a statically known event
            dimension greater than 5.
        """
        # Must run first so that only the constructor arguments are captured.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([mean_direction, concentration],
                                            tf.float32)
            mean_direction = tf.convert_to_tensor(
                mean_direction, name='mean_direction', dtype=dtype)
            concentration = tf.convert_to_tensor(
                concentration, name='concentration', dtype=dtype)

            # Runtime checks are only built when validation is requested.
            assertions = []
            if validate_args:
                assertions.append(
                    assert_util.assert_non_negative(
                        concentration,
                        message='`concentration` must be non-negative'))
                assertions.append(
                    assert_util.assert_greater(
                        tf.shape(mean_direction)[-1],
                        1,
                        message=(
                            '`mean_direction` may not have scalar event shape'
                        )))
                assertions.append(
                    assert_util.assert_near(
                        1.,
                        tf.linalg.norm(mean_direction, axis=-1),
                        message='`mean_direction` must be unit-length'))

            event_shape = tensorshape_util.with_rank_at_least(
                mean_direction.shape, 1)
            static_event_dim = tf.compat.dimension_value(event_shape[-1])
            # Fail at construction time when the event size is statically
            # known to be too large; otherwise defer to a runtime assertion
            # (only when validating).
            if static_event_dim is not None and static_event_dim > 5:
                raise ValueError('vMF ndims > 5 is not currently supported')
            if validate_args:
                assertions.append(
                    assert_util.assert_less_equal(
                        tf.shape(mean_direction)[-1],
                        5,
                        message='vMF ndims > 5 is not currently supported'))

            with tf.control_dependencies(assertions):
                self._mean_direction = tf.identity(mean_direction)
                self._concentration = tf.identity(concentration)
            dtype_util.assert_same_float_dtype(
                [self._mean_direction, self._concentration])

            # mean_direction is always reparameterized.
            # concentration is only for event_dim==3, via an inversion sampler.
            if static_event_dim == 3:
                reparameterization_type = (
                    reparameterization.FULLY_REPARAMETERIZED)
            else:
                reparameterization_type = (
                    reparameterization.NOT_REPARAMETERIZED)

            super(VonMisesFisher, self).__init__(
                dtype=self._concentration.dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization_type,
                parameters=parameters,
                name=name)
Example #6
0
def fill_triangular(x, upper=False, name=None):
    """Creates a (batch of) triangular matrix from a vector of inputs.

    Created matrix can be lower- or upper-triangular. (It is more efficient to
    create the matrix as upper or lower, rather than transpose.)

    Triangular matrix elements are filled in a clockwise spiral. See example,
    below.

    If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
    `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
    `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

    Example:

    ```python
    fill_triangular([1, 2, 3, 4, 5, 6])
    # ==> [[4, 0, 0],
    #      [6, 5, 0],
    #      [3, 2, 1]]

    fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
    # ==> [[1, 2, 3],
    #      [0, 5, 6],
    #      [0, 0, 4]]
    ```

    The key trick is to create an upper triangular matrix by concatenating `x`
    and a tail of itself, then reshaping.

    Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
    from a vector `x`. The matrix `M` contains n**2 entries total. The vector
    `x` contains `n * (n+1) / 2` entries. For concreteness, we'll consider
    `n = 5` (so `x` has `15` entries and `M` has `25`). We'll concatenate `x`
    and `x` with the first (`n = 5`) elements removed and reversed:

    ```python
    x = np.arange(15) + 1
    xc = np.concatenate([x, x[5:][::-1]])
    # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
    #            12, 11, 10, 9, 8, 7, 6])

    # (We add one to the arange result to disambiguate the zeros below the
    # diagonal of our upper-triangular matrix from the first entry in `x`.)

    # Now, when reshaped, we lay this out as a matrix:
    y = np.reshape(xc, [5, 5])
    # ==> array([[ 1,  2,  3,  4,  5],
    #            [ 6,  7,  8,  9, 10],
    #            [11, 12, 13, 14, 15],
    #            [15, 14, 13, 12, 11],
    #            [10,  9,  8,  7,  6]])

    # Finally, zero the elements below the diagonal:
    y = np.triu(y, k=0)
    # ==> array([[ 1,  2,  3,  4,  5],
    #            [ 0,  7,  8,  9, 10],
    #            [ 0,  0, 13, 14, 15],
    #            [ 0,  0,  0, 12, 11],
    #            [ 0,  0,  0,  0,  6]])
    ```

    From this example we see that the resulting matrix is upper-triangular, and
    contains all the entries of x, as desired. The rest is details:
    - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
      `n / 2` rows and half of an additional row), but the whole scheme still
      works.
    - If we want a lower triangular matrix instead of an upper triangular,
      we remove the first `n` elements from `x` rather than from the reversed
      `x`.

    For additional comparisons, a pure numpy version of this function can be
    found in `distribution_util_test.py`, function `_fill_triangular`.

    Args:
      x: `Tensor` representing lower (or upper) triangular elements.
      upper: Python `bool` representing whether output matrix should be upper
        triangular (`True`) or lower triangular (`False`, default).
      name: Python `str`. The name to give this op.

    Returns:
      tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

    Raises:
      ValueError: if the right-most dimension of `x` is statically known and
        cannot be mapped to a triangular matrix.
    """

    with tf.name_scope(name or 'fill_triangular'):
        x = tf.convert_to_tensor(x, name='x')
        m = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
        if m is None:
            # Right-most size only known at run time.
            m = tf.shape(x)[-1]
            # Same formula as the static branch. Casting automatically lops off
            # the 0.5, so we omit it.  We don't validate n is an integer because
            # this has graph-execution cost; an error will be thrown from the
            # reshape, below.
            n = tf.cast(tf.sqrt(0.25 + tf.cast(2 * m, dtype=tf.float32)),
                        dtype=tf.int32)
            batch_dims = tensorshape_util.with_rank_at_least(x.shape, 1)[:-1]
            static_final_shape = tensorshape_util.concatenate(
                batch_dims, [None, None])
        else:
            # Formula derived by solving for n: m = n(n+1)/2.
            m = np.int32(m)
            n = np.sqrt(0.25 + 2. * m) - 0.5
            if n != np.floor(n):
                raise ValueError(
                    'Input right-most shape ({}) does not '
                    'correspond to a triangular matrix.'.format(m))
            n = np.int32(n)
            static_final_shape = tensorshape_util.concatenate(
                x.shape[:-1], [n, n])

        # Try it out in numpy:
        #  n = 3
        #  x = np.arange(n * (n + 1) / 2)
        #  m = x.shape[0]
        #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
        #  x_tail = x[(m - (n**2 - m)):]
        #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
        #  # ==> array([[3, 4, 5],
        #               [5, 4, 3],
        #               [2, 1, 0]])
        #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
        #  # ==> array([[0, 1, 2],
        #               [3, 4, 5],
        #               [5, 4, 3]])
        #
        # Note that we can't simply do `x[..., -(n**2 - m):]` because this
        # doesn't correctly handle `m == n == 1`. Hence, we do nonnegative
        # indexing. Furthermore observe that:
        #   m - (n**2 - m)
        #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
        #   = 2 (n**2 / 2 + n / 2) - n**2
        #   = n**2 + n - n**2
        #   = n
        ndims = prefer_static.rank(x)
        tail = x[..., n:]
        last_axis = [ndims - 1]
        if upper:
            pieces = [x, tf.reverse(tail, axis=last_axis)]
        else:
            pieces = [tail, tf.reverse(x, axis=last_axis)]

        # Use a plain Python list for the target shape whenever it is fully
        # known; otherwise assemble it from run-time values.
        if tensorshape_util.is_fully_defined(static_final_shape):
            new_shape = tensorshape_util.as_list(static_final_shape)
        else:
            new_shape = tf.concat([tf.shape(x)[:-1], [n, n]], axis=0)

        x = tf.reshape(tf.concat(pieces, axis=-1), new_shape)
        # Keep only the requested triangle (diagonal included).
        num_lower, num_upper = (0, -1) if upper else (-1, 0)
        x = tf.linalg.band_part(x, num_lower=num_lower, num_upper=num_upper)
        tensorshape_util.set_shape(x, static_final_shape)
        return x
Example #7
0
 def _z(self, x):
   """Standardize input `x` to a unit logistic.

   Computes `(x - self.loc) / self.scale` inside a 'standardize' name scope.
   """
   with tf.name_scope('standardize'):
     centered = x - self.loc
     return centered / self.scale
Example #8
0
    def __init__(self,
                 df,
                 scale_operator,
                 input_output_cholesky=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Wishart distributions.

        Args:
          df: `float` or `double` tensor, the degrees of freedom of the
            distribution(s). `df` must be greater than or equal to `k`. It is
            converted to a tensor using `scale_operator.dtype`.
          scale_operator: `float` or `double` instance of `LinearOperator`.
            Must be square.
          input_output_cholesky: Python `bool`. If `True`, functions whose input
            or output have the semantics of samples assume inputs are in
            Cholesky form and return outputs in Cholesky form. In particular, if
            this flag is `True`, input to `log_prob` is presumed of Cholesky
            form and output from `sample`, `mean`, and `mode` are of Cholesky
            form.  Setting this argument to `True` is purely a computational
            optimization and does not change the underlying distribution; for
            instance, `mean` returns the Cholesky of the mean, not the mean of
            Cholesky factors. The `variance` and `stddev` methods are unaffected
            by this flag.
            Default value: `False` (i.e., input/output does not have Cholesky
            semantics).
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          TypeError: if `scale_operator.dtype` is not a floating-point type.
          TypeError: if scale.dtype != df.dtype (checked through
            `dtype_util.assert_same_float_dtype`).
          ValueError: if `scale_operator` is not square.
          ValueError: if df < k, where scale operator event shape is
            `(k, k)`, and both values are statically known.
        """
        # Must run first so that only the constructor arguments are captured.
        parameters = dict(locals())
        self._input_output_cholesky = input_output_cholesky
        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                if not dtype_util.is_floating(scale_operator.dtype):
                    raise TypeError(
                        "scale_operator.dtype=%s is not a floating-point type"
                        % scale_operator.dtype)
                if not scale_operator.is_square:
                    # Raise directly: materialising and evaluating the operator
                    # for debugging output needs a default session and could
                    # fail before this error is ever reported.
                    raise ValueError("scale_operator must be square.")

                self._scale_operator = scale_operator
                self._df = tf.convert_to_tensor(df,
                                                dtype=scale_operator.dtype,
                                                name="df")
                dtype_util.assert_same_float_dtype(
                    [self._df, self._scale_operator])
                # Prefer the statically known dimension; fall back to the
                # dynamic one only when the static shape is unknown.
                if tf.compat.dimension_value(
                        self._scale_operator.shape[-1]) is None:
                    self._dimension = tf.cast(
                        self._scale_operator.domain_dimension_tensor(),
                        dtype=self._scale_operator.dtype,
                        name="dimension")
                else:
                    self._dimension = tf.convert_to_tensor(
                        tf.compat.dimension_value(
                            self._scale_operator.shape[-1]),
                        dtype=self._scale_operator.dtype,
                        name="dimension")
                df_val = tf.get_static_value(self._df)
                dim_val = tf.get_static_value(self._dimension)
                if df_val is not None and dim_val is not None:
                    # Both values are known now, so validate eagerly in Python.
                    df_val = np.asarray(df_val)
                    if not df_val.shape:
                        df_val = [df_val]
                    if np.any(df_val < dim_val):
                        raise ValueError(
                            "Degrees of freedom (df = %s) cannot be less than "
                            "dimension of scale matrix (scale.dimension = %s)"
                            % (df_val, dim_val))
                elif validate_args:
                    # Defer the check to run time. The message arguments follow
                    # the order of the placeholders: df first, then dimension.
                    assertions = assert_util.assert_less_equal(
                        self._dimension,
                        self._df,
                        message=("Degrees of freedom (df = %s) cannot be "
                                 "less than dimension of scale matrix "
                                 "(scale.dimension = %s)" %
                                 (self._df, self._dimension)))
                    self._df = distribution_util.with_dependencies(
                        [assertions], self._df)
        super(_WishartLinearOperator, self).__init__(
            dtype=self._scale_operator.dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            parameters=parameters,
            name=name)
Example #9
0
    def __init__(self,
                 event_shape_out,
                 event_shape_in=(-1, ),
                 validate_args=False,
                 name=None):
        """Creates a `Reshape` bijector.

        Args:
          event_shape_out: An `int`-like vector-shaped `Tensor`
            representing the event shape of the transformed output.
          event_shape_in: An optional `int`-like vector-shape `Tensor`
            representing the event shape of the input. This is required in
            order to define inverse operations; the default of (-1,)
            assumes a vector-shaped input.
          validate_args: Python `bool` indicating whether arguments should
            be checked for correctness.
          name: Python `str`, name given to ops managed by this object.
            Defaults to 'reshape'.

        Raises:
          TypeError: if either `event_shape_in` or `event_shape_out` has
            non-integer `dtype`.
          ValueError: if either of `event_shape_in` or `event_shape_out`
            has non-vector shape (`rank > 1`), or if their sizes do not
            match.
          NotImplementedError: if the number of elements of either shape
            tensor is not statically known.
        """
        with tf.name_scope(name or 'reshape') as scope_name:
            event_shape_out = tf.convert_to_tensor(
                event_shape_out, name='event_shape_out', dtype_hint=tf.int32)
            event_shape_in = tf.convert_to_tensor(
                event_shape_in, name='event_shape_in', dtype_hint=tf.int32)

            # The number of entries in each shape vector gives the matching
            # minimum event ndims; both must be known statically.
            fwd_min_ndims = tensorshape_util.num_elements(
                event_shape_in.shape)
            if fwd_min_ndims is None:
                raise NotImplementedError(
                    '`event_shape_in` `size` must be statically known. For dynamic '
                    'support, please contact `[email protected]`.')

            inv_min_ndims = tensorshape_util.num_elements(
                event_shape_out.shape)
            if inv_min_ndims is None:
                raise NotImplementedError(
                    '`event_shape_out` `size` must be statically known. For dynamic '
                    'support, please contact `[email protected]`.')

            shape_checks = []
            shape_checks += _maybe_check_valid_shape(event_shape_out,
                                                     validate_args)
            shape_checks += _maybe_check_valid_shape(event_shape_in,
                                                     validate_args)

            # Only re-wrap the shapes when there is something to depend on.
            if shape_checks:
                with tf.control_dependencies(shape_checks):
                    event_shape_in = tf.identity(
                        event_shape_in, name='validated_event_shape_in')
                    event_shape_out = tf.identity(
                        event_shape_out, name='validated_event_shape_out')

            self._event_shape_in = event_shape_in
            self._event_shape_out = event_shape_out

            super(Reshape, self).__init__(
                forward_min_event_ndims=fwd_min_ndims,
                inverse_min_event_ndims=inv_min_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=scope_name or 'reshape')
Example #10
0
    def __init__(self,
                 cat,
                 components,
                 validate_args=False,
                 allow_nan_stats=True,
                 use_static_graph=False,
                 name="Mixture"):
        """Initialize a Mixture distribution.

        A `Mixture` is defined by a `Categorical` (`cat`, representing the
        mixture probabilities) and a list of `Distribution` objects (the
        components) all having matching dtype, batch shape, event shape, and
        continuity properties.

        The number of categories of `cat` must be possible to infer at graph
        construction time and match `len(components)`.

        Args:
          cat: A `Categorical` distribution instance, representing the
            probabilities of `components`.
          components: A list or tuple of `Distribution` instances.
            Each instance must have the same dtype, and have `event_shape` and
            `batch_shape` compatible with each other and with `cat`.
          validate_args: Python `bool`, default `False`. If `True`, assertions
            comparing the batch rank and batch shape of `cat` against every
            component are created and stored for later use.
          allow_nan_stats: Boolean, default `True`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          use_static_graph: Calls to `sample` will not rely on dynamic tensor
            indexing, allowing for some static graph compilation optimizations,
            but at the expense of sampling all underlying distributions in the
            mixture. (Possibly useful when running on TPUs).
            Default value: `False` (i.e., use dynamic indexing).
          name: A name for this distribution (optional).

        Raises:
          TypeError: If `cat` is not a `Categorical`, or `components` is not
            a list or tuple, or the elements of `components` are not
            instances of `Distribution`, or do not have matching `dtype`.
          ValueError: If `components` is empty, or a component batch shape is
            incompatible with `cat` and the other components, or the event
            rank is not statically known, or the number of categories of `cat`
            cannot be inferred statically or differs from `len(components)`.
        """
        # Must remain the first statement so that only the constructor
        # arguments are captured.
        parameters = dict(locals())
        # NOTE: the offending value is always wrapped in a 1-tuple before
        # `%`-formatting; a bare tuple operand would be unpacked as the
        # format arguments and produce a misleading (or failing) message.
        if not isinstance(cat, categorical.Categorical):
            raise TypeError(
                "cat must be a Categorical distribution, but saw: %s" %
                (cat,))
        if not components:
            raise ValueError("components must be a non-empty list or tuple")
        if not isinstance(components, (list, tuple)):
            raise TypeError("components must be a list or tuple, but saw: %s" %
                            (components,))
        if not all(
                isinstance(c, distribution.Distribution) for c in components):
            raise TypeError(
                "all entries in components must be Distribution instances"
                " but saw: %s" % (components,))

        dtype = components[0].dtype
        if not all(d.dtype == dtype for d in components):
            raise TypeError("All components must have the same dtype, but saw "
                            "dtypes: %s" % [(d.name, d.dtype)
                                            for d in components])

        # Fold every component's static shapes into a single merged event
        # shape and a single merged batch shape (seeded with `cat`'s).
        static_event_shape = components[0].event_shape
        static_batch_shape = cat.batch_shape
        for di, d in enumerate(components):
            if not tensorshape_util.is_compatible_with(static_batch_shape,
                                                       d.batch_shape):
                raise ValueError(
                    "components[{}] batch shape must be compatible with cat "
                    "shape and other component batch shapes".format(di))
            static_event_shape = tensorshape_util.merge_with(
                static_event_shape, d.event_shape)
            static_batch_shape = tensorshape_util.merge_with(
                static_batch_shape, d.batch_shape)
        if tensorshape_util.rank(static_event_shape) is None:
            raise ValueError(
                "Expected to know rank(event_shape) from components, but "
                "none of the components provide a static number of ndims")

        # Ensure that all batch and event ndims are consistent.
        with tf.name_scope(name) as name:
            num_components = cat._num_categories()
            static_num_components = tf.get_static_value(num_components)
            if static_num_components is None:
                raise ValueError(
                    "Could not infer number of classes from cat and unable "
                    "to compare this value to the number of components passed in."
                )
            # Possibly convert from numpy 0-D array.
            static_num_components = int(static_num_components)
            if static_num_components != len(components):
                raise ValueError(
                    "cat.num_classes != len(components): %d vs. %d" %
                    (static_num_components, len(components)))

            cat_batch_shape = cat.batch_shape_tensor()
            cat_batch_rank = tf.size(cat_batch_shape)
            if validate_args:
                batch_shapes = [d.batch_shape_tensor() for d in components]
                batch_ranks = [tf.size(bs) for bs in batch_shapes]
                check_message = ("components[%d] batch shape must match cat "
                                 "batch shape")
                # Rank assertions come first, then the shape assertions.
                self._assertions = [
                    assert_util.assert_equal(cat_batch_rank,
                                             batch_rank,
                                             message=check_message % di)
                    for di, batch_rank in enumerate(batch_ranks)
                ]
                self._assertions += [
                    assert_util.assert_equal(cat_batch_shape,
                                             batch_shape,
                                             message=check_message % di)
                    for di, batch_shape in enumerate(batch_shapes)
                ]
            else:
                self._assertions = []

            self._cat = cat
            self._components = list(components)
            self._num_components = static_num_components
            self._static_event_shape = static_event_shape
            self._static_batch_shape = static_batch_shape

            # `static_num_components` is guaranteed to be a Python `int` at
            # this point, so `use_static_graph` needs no further validation.
            self._use_static_graph = use_static_graph

        super(Mixture, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name)
Example #11
0
    def __init__(self,
                 df,
                 scale=None,
                 scale_tril=None,
                 input_output_cholesky=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Wishart"):
        """Construct Wishart distributions.

        The distribution is always parameterized internally by a lower
        triangular factor: when `scale` is given its Cholesky factor is
        computed, otherwise `scale_tril` is used as supplied.

        Args:
          df: `float` or `double` `Tensor`. Degrees of freedom, must be greater
            than or equal to dimension of the scale matrix.
          scale: `float` or `double` `Tensor`. The symmetric positive definite
            scale matrix of the distribution. Exactly one of `scale` and
            `scale_tril` must be passed.
          scale_tril: `float` or `double` `Tensor`. The Cholesky factorization
            of the symmetric positive definite scale matrix of the
            distribution. Exactly one of `scale` and `scale_tril` must be
            passed.
          input_output_cholesky: Python `bool`. If `True`, functions whose
            input or output have the semantics of samples assume inputs are in
            Cholesky form and return outputs in Cholesky form. Forwarded
            unchanged to the parent constructor.
            Default value: `False` (i.e., input/output does not have Cholesky
            semantics).
          validate_args: Python `bool`, default `False`. When `True`, `scale`
            is asserted to be symmetric, or `scale_tril` is asserted to be
            square with a positive diagonal, possibly degrading runtime
            performance. When `False` invalid inputs may silently render
            incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: if zero or both of `scale` and `scale_tril` are passed
            in.
        """
        # Must remain the first statement so that only the constructor
        # arguments are captured.
        parameters = dict(locals())

        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                exactly_one_given = (scale is None) != (scale_tril is None)
                if not exactly_one_given:
                    raise ValueError(
                        "Must pass scale or scale_tril, but not both.")

                dtype = dtype_util.common_dtype([df, scale, scale_tril],
                                                tf.float32)
                df = tf.convert_to_tensor(df, name="df", dtype=dtype)

                if scale is None:
                    # Only `scale_tril` was supplied; use it directly.
                    scale_tril = tf.convert_to_tensor(
                        scale_tril, name="scale_tril", dtype=dtype)
                    if validate_args:
                        positive_diagonal = assert_util.assert_positive(
                            tf.linalg.diag_part(scale_tril),
                            message="scale_tril must be positive definite")
                        square_matrix = assert_util.assert_equal(
                            tf.shape(scale_tril)[-1],
                            tf.shape(scale_tril)[-2],
                            message="scale_tril must be square")
                        scale_tril = distribution_util.with_dependencies(
                            [positive_diagonal, square_matrix], scale_tril)
                else:
                    # Only `scale` was supplied; factor it.
                    scale = tf.convert_to_tensor(
                        scale, name="scale", dtype=dtype)
                    if validate_args:
                        scale = distribution_util.assert_symmetric(scale)
                    scale_tril = tf.linalg.cholesky(scale)

            cholesky_operator = tf.linalg.LinearOperatorLowerTriangular(
                tril=scale_tril,
                is_non_singular=True,
                is_positive_definite=True,
                is_square=True)
            super(Wishart, self).__init__(
                df=df,
                scale_operator=cholesky_operator,
                input_output_cholesky=input_output_cholesky,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=name)
        # Overrides whatever the parent constructor recorded with the
        # arguments passed to this constructor.
        self._parameters = parameters
  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiagPlusLowRank"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
        scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
    ```

    where:

    * `scale_diag.shape = [k]`,
    * `scale_identity_multiplier.shape = []`,
    * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
    * `scale_perturb_diag.shape = [r]`.

    Additional leading dimensions (if any) will index batches.

    If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    `scale` is the Identity matrix.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scaled-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
        `k x k` identity matrices added to `scale`.
      scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
        perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`,
        `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`.
        When `None` (together with `scale_perturb_diag`), no rank-`r` update is
        added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix
        inside the rank-`r` perturbation added to `scale`. May have shape
        `[B1, ..., Bb, r]`, `b >= 0`. When `None`, an identity matrix is used
        inside the perturbation. Can only be specified if
        `scale_perturb_factor` is also specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified. (Not
        raised directly in this constructor; presumably raised while the
        diagonal part of `scale` is built -- TODO confirm.)
    """
    # Must remain the first statement so that only the constructor arguments
    # are captured.
    parameters = dict(locals())

    with tf.name_scope(name) as name:
      with tf.name_scope("init"):
        dtype = dtype_util.common_dtype([
            loc, scale_diag, scale_identity_multiplier, scale_perturb_factor,
            scale_perturb_diag
        ], tf.float32)
        # A low-rank update is requested as soon as either perturbation
        # argument is given.
        has_low_rank = not (scale_perturb_factor is None and
                            scale_perturb_diag is None)
        scale = distribution_util.make_diag_scale(
            loc=loc,
            scale_diag=scale_diag,
            scale_identity_multiplier=scale_identity_multiplier,
            validate_args=validate_args,
            assert_positive=has_low_rank,
            dtype=dtype)

        # Convert the optional perturbation arguments, leaving `None` as is.
        if scale_perturb_factor is not None:
          scale_perturb_factor = tf.convert_to_tensor(
              scale_perturb_factor, name="scale_perturb_factor", dtype=dtype)
        if scale_perturb_diag is not None:
          scale_perturb_diag = tf.convert_to_tensor(
              scale_perturb_diag, name="scale_perturb_diag", dtype=dtype)

        if has_low_rank:
          # With no explicit inner diagonal the update is flagged as having
          # a positive diagonal.
          inner_diag_is_implicit = scale_perturb_diag is None
          scale = tf.linalg.LinearOperatorLowRankUpdate(
              scale,
              u=scale_perturb_factor,
              diag_update=scale_perturb_diag,
              is_diag_update_positive=inner_diag_is_implicit,
              is_non_singular=True,  # Implied by is_positive_definite=True.
              is_self_adjoint=True,
              is_positive_definite=True,
              is_square=True)
    super(MultivariateNormalDiagPlusLowRank, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Overrides whatever the parent constructor recorded with the arguments
    # passed to this constructor.
    self._parameters = parameters
Example #13
0
def _kl_dirichlet_dirichlet(d1, d2, name=None):
  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.

  Args:
    d1: instance of a Dirichlet distribution object.
    d2: instance of a Dirichlet distribution object.
    name: Python `str` name to use for created operations.
      Default value: `None` (i.e., `'kl_dirichlet_dirichlet'`).

  Returns:
    kl_div: Batchwise KL(d1 || d2)
  """
  with tf.name_scope(name or 'kl_dirichlet_dirichlet'):
    # Derivation. Write the Dirichlet density as
    #
    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
    #   B(a)      = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
    #
    # so that
    #
    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
    #
    # The KL is the expectation under Dir(x; a) of the log density ratio, and
    # the only x-dependent term in that ratio is log(x[i]). In exponential
    # family form,
    #
    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
    #   eta[i](a) = a[i] - 1,  T[i](x) = log(x[i]),  A(a) = log B(a)
    #
    # log(x[i]) is a sufficient statistic, and expected sufficient statistics
    # equal derivatives of the log normalizer with respect to the natural
    # parameters, E{T[i](x)} = dA/d(eta[i]). Hence
    #
    #   E_Dir(x; a)[log(x[i])]
    #       = d/da[i] log B(a)
    #       = d/da[i] (sum_j lgamma(a[j]) - lgamma(sum_j a[j]))
    #       = digamma(a[i]) - digamma(sum_j a[j])
    #
    # and therefore
    #
    #   KL[Dir(x; a) || Dir(x; b)]
    #     = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b)
    #     = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
    #          - lbeta(a) + lbeta(b)
    alpha = tf.convert_to_tensor(d1.concentration)
    beta = tf.convert_to_tensor(d2.concentration)

    # E_{d1}[log x[i]]; the total is kept as a trailing size-1 axis so it
    # broadcasts against the per-component digamma values.
    digamma_of_total = tf.math.digamma(
        tf.reduce_sum(alpha, axis=-1, keepdims=True))
    expected_log_x = tf.math.digamma(alpha) - digamma_of_total

    weights = alpha - beta
    expectation_term = tf.reduce_sum(weights * expected_log_x, axis=-1)

    return expectation_term - tf.math.lbeta(alpha) + tf.math.lbeta(beta)
Example #14
0
    def __init__(self,
                 outcomes,
                 logits=None,
                 probs=None,
                 rtol=None,
                 atol=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='FiniteDiscrete'):
        """Construct a finite discrete distribution.

        Args:
          outcomes: A 1-D floating or integer `Tensor`, representing a list of
            possible outcomes in strictly ascending order.
          logits: A floating N-D `Tensor`, `N >= 1`, representing the log
            probabilities of a set of FiniteDiscrete distributions. The first
            `N - 1` dimensions index into a batch of independent distributions
            and the last dimension represents a vector of logits for each
            discrete value. Only one of `logits` or `probs` should be passed
            in.
          probs: A floating N-D `Tensor`, `N >= 1`, representing the
            probabilities of a set of FiniteDiscrete distributions. The first
            `N - 1` dimensions index into a batch of independent distributions
            and the last dimension represents a vector of probabilities for
            each discrete value. Only one of `logits` or `probs` should be
            passed in.
          rtol: `Tensor` with same `dtype` as `outcomes`. The relative
            tolerance for floating number comparison. Only stored when
            `outcomes` is a floating `Tensor`. Default is `10 * eps`.
          atol: `Tensor` with same `dtype` as `outcomes`. The absolute
            tolerance for floating number comparison. Only stored when
            `outcomes` is a floating `Tensor`. Default is `10 * eps`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value '`NaN`' to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """
        # Must remain the first statement so that only the constructor
        # arguments are captured.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            preferred_dtype = dtype_util.common_dtype([outcomes],
                                                      dtype_hint=tf.float32)
            self._outcomes = tensor_util.convert_nonref_to_tensor(
                outcomes, dtype_hint=preferred_dtype, name='outcomes')

            if dtype_util.is_floating(self._outcomes.dtype):
                # Ten machine epsilons serve as the fallback for whichever
                # tolerance the caller left unspecified.
                machine_eps = np.finfo(
                    dtype_util.as_numpy_dtype(preferred_dtype)).eps
                fallback = 10 * machine_eps
                tolerances = (fallback if rtol is None else rtol,
                              fallback if atol is None else atol)
            else:
                # Tolerances are not kept for non-floating outcomes.
                tolerances = (None, None)
            self._rtol, self._atol = tolerances

            # The underlying categorical is built over int32 values.
            self._categorical = categorical.Categorical(
                logits=logits,
                probs=probs,
                dtype=tf.int32,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)
        super(FiniteDiscrete, self).__init__(
            dtype=self._outcomes.dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name)
Example #15
0
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
    """Computes a matrix inverse given the matrix's LU decomposition.

    The inverse is obtained by solving the LU system against an identity
    right-hand side. This op is conceptually identical to,

    ```python
    inv_X = tf.lu_matrix_inverse(*tf.linalg.lu(X))
    tf.assert_near(tf.matrix_inverse(X), inv_X)
    # ==> True
    ```

    Note: this function does not verify the implied matrix is actually
    invertible nor is this condition checked even when `validate_args=True`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness. Note: this function does not verify the
        implied matrix is actually invertible, even when `validate_args=True`.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_matrix_inverse').

    Returns:
      inv_x: The matrix_inv, i.e.,
        `tf.matrix_inverse(tfp.math.lu_reconstruct(lu, perm))`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import numpy as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

    x = [[[3., 4], [1, 2]],
         [[7., 8], [3, 4]]]
    inv_x = tfp.math.lu_matrix_inverse(*tf.linalg.lu(x))
    tf.assert_near(tf.matrix_inverse(x), inv_x)
    # ==> True
    ```

    """
    with tf.name_scope(name or 'lu_matrix_inverse'):
        lower_upper = tf.convert_to_tensor(
            lower_upper, dtype_hint=tf.float32, name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        checks = lu_reconstruct_assertions(lower_upper, perm, validate_args)
        if checks:
            # Route both inputs through the assertions so that downstream ops
            # depend on them.
            with tf.control_dependencies(checks):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        lu_shape = tf.shape(lower_upper)
        # Identity right-hand side with the same batch shape and dtype as the
        # factorization.
        identity_rhs = tf.eye(lu_shape[-1],
                              batch_shape=lu_shape[:-2],
                              dtype=lower_upper.dtype)
        # The arguments were already checked above, so the solve itself runs
        # with validation disabled.
        return lu_solve(lower_upper,
                        perm,
                        rhs=identity_rhs,
                        validate_args=False)
Example #16
0
    def __init__(self,
                 shift=None,
                 scale_identity_multiplier=None,
                 scale_diag=None,
                 scale_tril=None,
                 scale_perturb_factor=None,
                 scale_perturb_diag=None,
                 adjoint=False,
                 validate_args=False,
                 name="affine",
                 dtype=None):
        """Instantiates the `Affine` bijector.

        This `Bijector` is initialized with `shift` `Tensor` and `scale`
        arguments, giving the forward operation:

        ```none
        Y = g(X) = scale @ X + shift
        ```

        where the `scale` term is logically equivalent to:

        ```python
        scale = (
          scale_identity_multiplier * tf.diag(tf.ones(d)) +
          tf.diag(scale_diag) +
          scale_tril +
          scale_perturb_factor @ diag(scale_perturb_diag) @
            tf.transpose([scale_perturb_factor])
        )
        ```

        If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril`
        are specified then `scale += IdentityMatrix`. Otherwise specifying a
        `scale` argument has the semantics of `scale += Expand(arg)`, i.e.,
        `scale_diag != None` means `scale += tf.diag(scale_diag)`.

        Args:
          shift: Floating-point `Tensor`. If this is set to `None`, no shift is
            applied.
          scale_identity_multiplier: floating point rank 0 `Tensor` representing
            a scaling done to the identity matrix. When none of `scale_diag`,
            `scale_tril` and `scale_perturb_factor` is given and this is also
            `None`, it is replaced by the constant `1.`.
          scale_diag: Floating-point `Tensor` representing the diagonal matrix.
            `scale_diag` has shape `[N1, N2, ...  k]`, which represents a k x k
            diagonal matrix. When `None` no diagonal term is added to `scale`.
          scale_tril: Floating-point `Tensor` representing the lower triangular
            matrix. `scale_tril` has shape `[N1, N2, ...  k, k]`, which
            represents a k x k lower triangular matrix. When `None` no
            `scale_tril` term is added to `scale`. The upper triangular
            elements above the diagonal are ignored.
          scale_perturb_factor: Floating-point `Tensor` representing factor
            matrix with last two dimensions of shape `(k, r)`. When `None`, no
            rank-r update is added to `scale`.
          scale_perturb_diag: Floating-point `Tensor` representing the diagonal
            matrix. `scale_perturb_diag` has shape `[N1, N2, ...  r]`, which
            represents an `r x r` diagonal matrix. When `None` low rank updates
            will take the form `scale_perturb_factor * scale_perturb_factor.T`.
          adjoint: Python `bool` indicating whether to use the `scale` matrix as
            specified or its adjoint.
            Default value: `False`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
          name: Python `str` name given to ops managed by this object.
          dtype: `tf.DType` to prefer when converting args to `Tensor`s. Else,
            we fall back to a common dtype inferred from the args, finally
            falling back to float32.

        Raises:
          ValueError: if `scale_perturb_diag` is specified but not
            `scale_perturb_factor`.
          TypeError: if `shift` has different `dtype` from `scale` arguments.
        """
        # A perturbation diagonal without its factor is an ambiguous
        # definition of the low rank update.
        perturb_diag_without_factor = (scale_perturb_diag is not None
                                       and scale_perturb_factor is None)
        if perturb_diag_without_factor:
            raise ValueError("When scale_perturb_diag is specified, "
                             "scale_perturb_factor must be specified.")

        # Special case: only a scaled identity matrix is being handled. Its
        # dimensions are unknown, hence the dedicated flag. The identity
        # multiplier itself is not inspected because it is defaulted to 1.
        # below whenever every other scale argument is None.
        self._is_only_identity_multiplier = all(
            arg is None
            for arg in (scale_tril, scale_diag, scale_perturb_factor))

        with tf.name_scope(name) as name:
            self._name = name
            self._validate_args = validate_args

            if dtype is None:
                dtype = dtype_util.common_dtype([
                    shift, scale_identity_multiplier, scale_diag, scale_tril,
                    scale_perturb_diag, scale_perturb_factor
                ], tf.float32)

            shift = (None if shift is None else tf.convert_to_tensor(
                shift, name="shift", dtype=dtype))
            self._shift = shift

            # With no scale argument at all, behave as if the scale matrix
            # were the identity matrix.
            nothing_specified = (self._is_only_identity_multiplier
                                 and scale_identity_multiplier is None)
            if nothing_specified:
                scale_identity_multiplier = tf.convert_to_tensor(1.,
                                                                 dtype=dtype)

            # self._create_scale_operator returns a LinearOperator in all
            # cases except if self._is_only_identity_multiplier; in which case
            # it returns a scalar Tensor.
            scale = self._create_scale_operator(
                identity_multiplier=scale_identity_multiplier,
                diag=scale_diag,
                tril=scale_tril,
                perturb_diag=scale_perturb_diag,
                perturb_factor=scale_perturb_factor,
                shift=shift,
                validate_args=validate_args,
                dtype=dtype)

            # The shift/scale dtype comparison only applies to a real
            # operator, when dtype checks are enabled and a shift exists.
            dtype_check_applies = (scale is not None
                                   and not self._is_only_identity_multiplier
                                   and not dtype_util.SKIP_DTYPE_CHECKS
                                   and shift is not None)
            if dtype_check_applies and not dtype_util.base_equal(
                    shift.dtype, scale.dtype):
                raise TypeError(
                    "shift.dtype ({}) is incompatible with scale.dtype ({})."
                    .format(shift.dtype, scale.dtype))

            self._scale = scale
            self._adjoint = adjoint
            super(Affine, self).__init__(forward_min_event_ndims=1,
                                         is_constant_jacobian=True,
                                         dtype=dtype,
                                         validate_args=validate_args,
                                         name=name)
Example #17
0
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import numpy as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            # Tie the inputs to the assertions so the checks run before use.
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        # `lower_upper` packs both factors: `L` is the lower triangle with its
        # diagonal replaced by ones, `U` is the upper triangle (diagonal kept).
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        static_rank = tensorshape_util.rank(lower_upper.shape)
        if static_rank is None or static_rank != 2:
            # We either don't know the batch rank or there are >0 batch dims.
            # Flatten the batch dims, reorder the rows of every batch member by
            # its inverse permutation, then restore the original shape.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, tf.math.invert_permutation(perm))

        # Use the backend-agnostic helper (as the rest of this file does)
        # rather than calling `x.set_shape` directly, which only exists on
        # TensorFlow tensors.
        tensorshape_util.set_shape(x, lower_upper.shape)
        return x
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='MultivariateNormalLinearOperator'):
        """Build a Multivariate Normal on `R^k` from a `LinearOperator` scale.

    The distribution is assembled as a scalar standard Normal pushed through
    an `AffineLinearOperator` bijector with `shift=loc` and `scale=scale`.

    The `batch_shape` is the broadcast shape between the `loc` and `scale`
    arguments. The `event_shape` is given by the last dimension of the matrix
    implied by `scale`; the last dimension of `loc` (if provided) must
    broadcast with this. Any additional leading dimensions index batches.

    Recall that `covariance = scale @ scale.T`.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
        # Must stay the first statement so that only the constructor arguments
        # (and `self`) are captured.
        parameters = dict(locals())
        if scale is None:
            raise ValueError('Missing required `scale` parameter.')
        if not dtype_util.is_floating(scale.dtype):
            raise TypeError(
                '`scale` parameter must have floating-point dtype.')

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            dtype_hint=tf.float32)
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            loc = tensor_util.convert_nonref_to_tensor(
                loc, dtype=dtype, name='loc')
            shapes = distribution_util.shapes_from_loc_and_scale(loc, scale)
            batch_shape, event_shape = shapes

        # Scalar standard Normal that the affine bijector shifts and scales.
        standard_normal = normal.Normal(loc=tf.zeros([], dtype=dtype),
                                        scale=tf.ones([], dtype=dtype))
        affine = affine_linear_operator_bijector.AffineLinearOperator(
            shift=loc, scale=scale, validate_args=validate_args)

        # NOTE(review): `allow_nan_stats` is recorded in `parameters` but is
        # not forwarded to the parent constructor here -- confirm intended.
        super(MultivariateNormalLinearOperator, self).__init__(
            distribution=standard_normal,
            bijector=affine,
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
            name=name)
        self._parameters = parameters
Example #19
0
    def __init__(self,
                 loc,
                 scale,
                 low,
                 high,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="TruncatedNormal"):
        """Build a Normal distribution truncated between `low` and `high`.

    Every parameter is broadcast to a common shape, which becomes the
    `batch_shape` of the resulting distribution.

    Args:
      loc: Floating point tensor; the mean of the underlying (untruncated)
        normal distribution(s). The mean of the truncated distribution will
        generally differ because it is modified by the bounds.
      scale: Floating point tensor; the std deviation of the underlying normal
        distribution(s).
      low: `float` `Tensor` representing lower bound of the distribution's
        support. Must be such that `low < high`.
      high: `float` `Tensor` representing upper bound of the distribution's
        support. Must be such that `low < high`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked at run-time.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        # Must stay the first statement so only the constructor arguments
        # (and `self`) are captured.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, low, high],
                                            tf.float32)
            loc = tf.convert_to_tensor(loc, dtype=dtype, name="loc")
            scale = tf.convert_to_tensor(scale, dtype=dtype, name="scale")
            low = tf.convert_to_tensor(low, dtype=dtype, name="low")
            high = tf.convert_to_tensor(high, dtype=dtype, name="high")
            dtype_util.assert_same_float_dtype([loc, scale, low, high])

            batch_shape = distribution_util.get_broadcast_shape(
                loc, scale, low, high)
            self._broadcast_batch_shape = batch_shape

            # Multiplying by a tensor of ones expands each parameter to the
            # common batch shape.
            ones = tf.ones(shape=batch_shape, dtype=scale.dtype)
            self._scale = scale * ones
            self._loc = loc * ones
            self._low = low * ones
            self._high = high * ones

            # When validating, attach the checks to `loc` via a control
            # dependency; otherwise the dependency list is empty.
            validation_ops = [self._validate()] if validate_args else []
            with tf.control_dependencies(validation_ops):
                self._loc = tf.identity(self._loc)

        super(TruncatedNormal, self).__init__(
            dtype=dtype,
            # This distribution is fully reparameterized. loc, scale have
            # straight through gradients. The gradients for the bounds are
            # implemented using custom derived expressions based on implicit
            # gradients. For the special case of lower bound zero and a
            # positive upper bound an equivalent expression can also be found
            # in Sec 9.1.1. of https://arxiv.org/pdf/1806.01851.pdf. The
            # implementation here handles arbitrary bounds.
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name)
def _kl_brute_force(a, b, name=None):
    """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
  covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B`
  is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
  and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """
    diagonal_operator_types = (
        tf.linalg.LinearOperatorIdentity,
        tf.linalg.LinearOperatorScaledIdentity,
        tf.linalg.LinearOperatorDiag,
    )

    # TODO(b/35041439): See also b/35040945. Remove this helper once LinOp
    # supports something like:
    #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
    def _sum_of_squares(mat):
        """Squared Frobenius norm over the two rightmost axes."""
        # http://mathworld.wolfram.com/FrobeniusNorm.html
        # The gradient of KL[p,q] is not defined when p==q. The culprit is
        # tf.norm, so `tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))` cannot
        # be used here.
        return tf.reduce_sum(tf.square(mat), axis=[-2, -1])

    def _has_only_diagonal(linop):
        """True if the `LinearOperator` has only a diagonal component."""
        return isinstance(linop, diagonal_operator_types)

    with tf.name_scope(name or 'kl_mvn'):
        # Calculation is based on:
        # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
        # and,
        # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
        # i.e.,
        #   If Ca = AA', Cb = BB', then
        #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
        #                  = tr[inv(B) A A' inv(B)']
        #                  = tr[(inv(B) A) (inv(B) A)']
        #                  = sum_{ij} (inv(B) A)_{ij}**2
        #                  = ||inv(B) A||_F**2
        # where ||.||_F is the Frobenius norm and the second equality follows
        # from the cyclic permutation property.
        if _has_only_diagonal(a.scale) and _has_only_diagonal(b.scale):
            # Using `stddev` because it handles expansion of Identity cases.
            stddev_ratio = a.stddev() / b.stddev()
            b_inv_a = stddev_ratio[..., tf.newaxis]
        else:
            b_inv_a = b.scale.solve(a.scale.to_dense())

        # log|det(scale_b)| - log|det(scale_a)| equals half of
        # Log[Det(C_b)] - Log[Det(C_a)], hence it sits outside the 0.5 factor.
        log_det_term = (b.scale.log_abs_determinant() -
                        a.scale.log_abs_determinant())
        neg_event_size = -tf.cast(a.scale.domain_dimension_tensor(), a.dtype)
        trace_term = _sum_of_squares(b_inv_a)
        mean_gap = (b.mean() - a.mean())[..., tf.newaxis]
        quadratic_term = _sum_of_squares(b.scale.solve(mean_gap))

        kl_div = log_det_term + 0.5 * (
            neg_event_size + trace_term + quadratic_term)
        tensorshape_util.set_shape(
            kl_div, tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
        return kl_div
Example #21
0
    def __init__(self,
                 distributions,
                 dtype_override=None,
                 validate_args=False,
                 allow_nan_stats=False,
                 name='Blockwise'):
        """Construct the `Blockwise` distribution.

    Args:
      distributions: Python `list` of `tfp.distributions.Distribution`
        instances. All distribution instances must have the same `batch_shape`
        and all must have `event_ndims==1`, i.e., be vector-variate
        distributions. A non-iterable value is used directly as the underlying
        joint distribution.
      dtype_override: samples of `distributions` will be cast to this `dtype`.
        If unspecified, all `distributions` must have the same `dtype`.
        Default value: `None` (i.e., do not cast).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `dtype_override` is `None` and the component distributions
        do not share a single base dtype.
    """
        # Must stay the first statement so only the constructor arguments
        # (and `self`) are captured.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            # Keep the caller's distributions as given, before any casting.
            self._distributions = distributions
            if dtype_override is not None:
                distributions = tf.nest.map_structure(
                    lambda d: _Cast(d, dtype_override), distributions)
            self._distribution = (
                joint_distribution_sequential.JointDistributionSequential(
                    list(distributions))
                if _is_iterable(distributions) else distributions)

            # Need to cache these for JointDistributions as the batch shape of
            # that distribution can change after `_sample` calls.
            joint = self._distribution
            self._cached_batch_shape_tensor = joint.batch_shape_tensor()
            self._cached_batch_shape = joint.batch_shape

            if dtype_override is None:
                # Collect the distinct base dtypes across all components.
                base_dtypes = {
                    dtype_util.base_dtype(dt)
                    for dt in tf.nest.flatten(self._distribution.dtype)
                    if dt is not None
                }
                if not base_dtypes:
                    dtype = tf.float32
                elif len(base_dtypes) == 1:
                    dtype = base_dtypes.pop()
                else:
                    raise TypeError(
                        'Distributions must have same dtype; found: {}.'.
                        format(self._distribution.dtype))
            else:
                dtype = dtype_override

            # Only advertise a reparameterization type when every component
            # agrees on it.
            reparam_kinds = set(
                tf.nest.flatten(self._distribution.reparameterization_type))
            if len(reparam_kinds) == 1:
                reparameterization_type = reparam_kinds.pop()
            else:
                reparameterization_type = (
                    reparameterization.NOT_REPARAMETERIZED)

            super(Blockwise, self).__init__(
                dtype=dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization_type,
                parameters=parameters,
                name=name)
Example #22
0
def fill_triangular_inverse(x, upper=False, name=None):
    """Flattens the triangular part of a (batch of) matrix into a vector.

  The vector is read from the lower-triangular or upper-triangular portion of
  `x`, depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether the input matrix `x` is upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

    with tf.name_scope(name or 'fill_triangular_inverse'):
        x = tf.convert_to_tensor(x, name='x')
        matrix_shape = tensorshape_util.with_rank_at_least(x.shape, 2)
        static_n = tf.compat.dimension_value(matrix_shape[-1])
        if static_n is None:
            # Matrix size only known at run time: the output's last dimension
            # is statically unknown.
            n = tf.shape(x)[-1]
            m = (n * (n + 1)) // 2
            static_final_shape = tensorshape_util.concatenate(
                matrix_shape[:-2], [None])
        else:
            n = np.int32(static_n)
            m = np.int32((n * (n + 1)) // 2)
            static_final_shape = tensorshape_util.concatenate(
                x.shape[:-2], [m])

        ndims = prefer_static.rank(x)
        if upper:
            # First row goes in as-is; the remaining rows are folded below.
            initial_elements = x[..., 0, :]
            remaining_rows = x[..., 1:, :]
        else:
            # Last row (reversed) goes first; the remaining rows are folded.
            initial_elements = tf.reverse(x[..., -1, :], axis=[ndims - 2])
            remaining_rows = x[..., :-1, :]

        # Rotate the remaining rows by 180 degrees (reverse columns, then
        # rows) and add them to the unrotated rows.
        columns_reversed = tf.reverse(remaining_rows, axis=[ndims - 1])
        rotated = tf.reverse(columns_reversed, axis=[ndims - 2])
        folded = remaining_rows + rotated

        flat_shape = tf.concat([tf.shape(x)[:-2], [n * (n - 1)]], axis=0)
        flattened = tf.reshape(folded, flat_shape)
        # Only the first `m - n` folded entries are needed to complete the
        # `m`-vector after the `n` initial elements.
        y = tf.concat([initial_elements, flattened[..., :m - n]], axis=-1)
        tensorshape_util.set_shape(y, static_final_shape)
        return y
def kl_divergence(distribution_a,
                  distribution_b,
                  allow_nan_stats=True,
                  name=None):
    """Compute the KL-divergence KL(distribution_a || distribution_b).

  The KL implementation is looked up from the registry by the types of the two
  distributions:

  * If there is no KL method registered specifically for
    `type(distribution_a)` and `type(distribution_b)`, then the class
    hierarchies of these types are searched.
  * If one KL method is registered between any pairs of classes in these two
    parent hierarchies, it is used.
  * If more than one such registered method exists, the method whose
    registered classes have the shortest sum MRO paths to the input types is
    used.
  * If more than one such shortest path exists, the first method identified in
    the search is used (favoring a shorter MRO distance to
    `type(distribution_a)`).

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b`.
  """
    type_a = type(distribution_a)
    type_b = type(distribution_b)
    kl_fn = _registered_kl(type_a, type_b)
    if kl_fn is None:
        raise NotImplementedError(
            "No KL(distribution_a || distribution_b) registered for distribution_a "
            "type {} and distribution_b type {}".format(
                type_a.__name__, type_b.__name__))

    name = name or "KullbackLeibler"
    with tf.name_scope(name):
        # pylint: disable=protected-access
        scope_a = distribution_a._name_and_control_scope(name + "_a")
        with scope_a:
            scope_b = distribution_b._name_and_control_scope(name + "_b")
            with scope_b:
                kl_t = kl_fn(distribution_a, distribution_b, name=name)
                if allow_nan_stats:
                    # NaNs are acceptable: hand the raw result back.
                    return kl_t

        # Check KL for NaNs
        kl_t = tf.identity(kl_t, name="kl")

        has_no_nan = tf.logical_not(tf.reduce_any(tf.math.is_nan(kl_t)))
        nan_message = (
            "KL calculation between {} and {} returned NaN values "
            "(and was called with allow_nan_stats=False). Values:".format(
                distribution_a.name, distribution_b.name))
        nan_check = tf.Assert(has_no_nan, [nan_message, kl_t])

        # The returned tensor only evaluates once the assertion has passed.
        with tf.control_dependencies([nan_check]):
            return tf.identity(kl_t, name="checked_kl")
Example #24
0
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
    """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices. Its rank must be known statically.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation. It is clamped to the size of the last dimension of
      `matrix`.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.
      Default value: `None` (i.e., 'pivoted_cholesky').

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  Raises:
    NotImplementedError: if the rank of `matrix` is not known statically.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
    with tf.name_scope(name or 'pivoted_cholesky'):
        dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                        dtype_hint=tf.float32)
        matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
        if tensorshape_util.rank(matrix.shape) is None:
            raise NotImplementedError(
                'Rank of `matrix` must be known statically')

        max_rank = tf.convert_to_tensor(max_rank,
                                        name='max_rank',
                                        dtype=tf.int64)
        # Never ask for more columns than the matrix has.
        max_rank = tf.minimum(
            max_rank,
            prefer_static.shape(matrix, out_type=tf.int64)[-1])
        diag_rtol = tf.convert_to_tensor(diag_rtol,
                                         dtype=dtype,
                                         name='diag_rtol')
        matrix_diag = tf.linalg.diag_part(matrix)
        # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
        orig_error = tf.reduce_max(matrix_diag, axis=-1)

        def cond(m, pchol, perm, matrix_diag):
            """Condition for `tf.while_loop` continuation.

            Continues while fewer than `max_rank` rows have been produced and
            either no iteration has run yet (`m == 0`) or the largest relative
            residual across the batch still exceeds `diag_rtol`.
            """
            del pchol
            del perm
            # Residual per batch member: 1-norm of the remaining diagonal,
            # relative to the largest original diagonal entry.
            # NOTE(review): this is a summed residual, whereas the docstring of
            # `diag_rtol` describes a per-element criterion -- confirm intent.
            error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
            max_err = tf.reduce_max(error / orig_error)
            return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

        # Number of leading batch dimensions (everything but the last two).
        batch_dims = tensorshape_util.rank(matrix.shape) - 2

        def batch_gather(params, indices, axis=-1):
            """`tf.gather` that treats the leading `batch_dims` axes as batch."""
            return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            # `maxi` was found relative to position `m`; make it absolute.
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            # Left-pad the last axis with `m` zeros so the row lines up with
            # the full (permuted) index range; all other axes are unpadded.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            # Map from permuted positions back to original positions.
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            # Replace row `m` of `pchol` by concatenation, then restore the
            # static shape that was recorded before the concat.
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag

        m = np.int64(0)
        # `pchol` is accumulated row-wise (one row per iteration) and
        # transposed after the loop.
        pchol = tf.zeros_like(matrix[..., :max_rank, :])
        matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
        # Start from the identity permutation for every batch member.
        perm = tf.broadcast_to(prefer_static.range(matrix_shape[-1]),
                               matrix_shape[:-1])
        _, pchol, _, _ = tf.while_loop(cond=cond,
                                       body=body,
                                       loop_vars=(m, pchol, perm, matrix_diag))
        pchol = tf.linalg.matrix_transpose(pchol)
        # The trailing (rank) dimension is left statically unknown.
        tensorshape_util.set_shape(
            pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
        return pchol
Example #25
0
    def __init__(self,
                 perm=None,
                 rightmost_transposed_ndims=None,
                 validate_args=False,
                 name='transpose'):
        """Instantiates the `Transpose` bijector.

    Exactly one of `perm` and `rightmost_transposed_ndims` must be given; the
    other one is derived from it.

    Args:
      perm: Positive `int32` vector-shaped `Tensor` representing permutation of
        rightmost dims (for forward transformation).  Note that the `0`th index
        represents the first of the rightmost dims and the largest value must be
        `rightmost_transposed_ndims - 1` and corresponds to `tf.rank(x) - 1`.
        Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
        specified.
        Default value:
        `tf.range(start=rightmost_transposed_ndims - 1, limit=-1, delta=-1)`.
      rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
        representing the number of rightmost dimensions to permute.
        Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
        specified.
        Default value: `tf.size(perm)`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if both or neither `perm` and `rightmost_transposed_ndims` are
        specified.
      NotImplementedError: if `rightmost_transposed_ndims` is not known prior to
        graph execution.
    """
        with tf.name_scope(name) as name:
            if (perm is None) == (rightmost_transposed_ndims is None):
                raise ValueError('Must specify exactly one of '
                                 '`rightmost_transposed_ndims` and `perm`.')

            if perm is None:
                # Only the number of dims was given: build the full reversal
                # permutation `[ndims - 1, ..., 1, 0]`.
                rightmost_transposed_ndims = tf.convert_to_tensor(
                    rightmost_transposed_ndims,
                    dtype_hint=np.int32,
                    name='rightmost_transposed_ndims')
                static_ndims = tf.get_static_value(rightmost_transposed_ndims)
                checks = _maybe_validate_rightmost_transposed_ndims(
                    rightmost_transposed_ndims, validate_args)
                if checks:
                    with tf.control_dependencies(checks):
                        rightmost_transposed_ndims = tf.identity(
                            rightmost_transposed_ndims)
                last_axis = distribution_util.prefer_static_value(
                    rightmost_transposed_ndims) - 1
                perm = tf.range(start=last_axis,
                                limit=-1,
                                delta=-1,
                                name='perm')
            else:
                # An explicit permutation was given: its size is the number of
                # transposed dims.
                perm = tf.convert_to_tensor(perm,
                                            dtype_hint=np.int32,
                                            name='perm')
                rightmost_transposed_ndims = tf.size(
                    perm, name='rightmost_transposed_ndims')
                static_ndims = tf.get_static_value(rightmost_transposed_ndims)
                checks = _maybe_validate_perm(perm, validate_args)
                if checks:
                    with tf.control_dependencies(checks):
                        perm = tf.identity(perm)

            # TODO(b/110828604): If bijector base class ever supports dynamic
            # `min_event_ndims`, then this class already works dynamically and
            # the static-value check below can be removed.
            if static_ndims is None:
                raise NotImplementedError(
                    '`rightmost_transposed_ndims` must be '
                    'known prior to graph execution.')
            static_ndims = int(static_ndims)

            self._perm = perm
            self._rightmost_transposed_ndims = rightmost_transposed_ndims
            super(Transpose, self).__init__(
                forward_min_event_ndims=static_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
Example #26
0
def pinv(a, rcond=None, validate_args=False, name=None):
    """Compute the Moore-Penrose pseudo-inverse of a matrix.

  The [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) is computed
  from the singular-value decomposition (SVD) of `a`, keeping only the
  singular values that are large relative to the largest one.

  The pseudo-inverse of a matrix `A` is 'the matrix that solves the
  least-squares problem `A @ x = b`', i.e., if `x_hat` is a solution, then
  `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. If
  `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then
  `A_pinv = V @ inv(Sigma) @ U^T`. [(Strang, 1980)][1]

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in the default value of `rcond`. In `numpy.linalg.pinv`, the
  default `rcond` is `1e-15`. Here the default is
  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs. Singular values that are
      not greater than `rcond * largest_singular_value` are treated as zero.
      Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'pinv'.

  Returns:
    a_pinv: The pseudo-inverse of input `a`. Has same shape as `a` except
      rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  from tensorflow_probability.python.internal.backend import numpy as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[1., 0., 0.],
  #            [0., 1., 0.],
  #            [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
  #            [ 0.37,  0.43, -0.33,  0.02],
  #            [ 0.21, -0.33,  0.81,  0.01],
  #            [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press,
       Inc., 1980, pp. 139-142.
  """
    with tf.name_scope(name or 'pinv'):
        a = tf.convert_to_tensor(a, name='a')

        checks = _maybe_validate_matrix(a, validate_args)
        if checks:
            # Route `a` through an identity op that depends on the checks so
            # they run before `a` is consumed.
            with tf.control_dependencies(checks):
                a = tf.identity(a)

        np_dtype = dtype_util.as_numpy_dtype(a.dtype)

        if rcond is None:

            def _dim(axis):
                """Static size of `axis` when known, else a dynamic `Tensor`."""
                static_size = tf.compat.dimension_value(a.shape[axis])
                if static_size is None:
                    return tf.shape(a)[axis]
                return static_size

            n_rows = _dim(-2)
            n_cols = _dim(-1)
            if all(isinstance(n, int) for n in (n_rows, n_cols)):
                # Both sizes are static: do the max in Python.
                largest_dim = float(max(n_rows, n_cols))
            else:
                largest_dim = tf.cast(tf.maximum(n_rows, n_cols), np_dtype)
            rcond = 10. * largest_dim * np.finfo(np_dtype).eps

        rcond = tf.convert_to_tensor(rcond, dtype=np_dtype, name='rcond')

        # Pseudo-inverse via SVD: sigma holds the singular values, u / v the
        # left / right singular vectors.
        # Note: if a is symmetric then u == v. (We might observe additional
        # performance by explicitly setting `v = u` in such cases.)
        sigma, u, v = tf.linalg.svd(a, full_matrices=False, compute_uv=True)

        # Replace small singular values with inf. This makes `1. / s = 0.`
        # for them while not resulting in `NaN` gradients.
        cutoff = rcond * tf.reduce_max(sigma, axis=-1)
        is_significant = sigma > cutoff[..., tf.newaxis]
        sigma = tf.where(is_significant, sigma, np.array(np.inf, np_dtype))

        # Although `a == tf.matmul(u, s * v, transpose_b=True)` we swap
        # `u` and `v` here so that `tf.matmul(pinv(A), A) = tf.eye()`, i.e.,
        # a matrix inverse has 'transposed' semantics.
        scaled_v = v / sigma[..., tf.newaxis, :]
        a_pinv = tf.matmul(scaled_v, u, adjoint_b=True)

        if tensorshape_util.rank(a.shape) is not None:
            # Static shape of the result: batch dims of `a`, then the last two
            # dims of `a` swapped.
            static_batch = a.shape[:-2]
            a_pinv.set_shape(
                static_batch.concatenate([a.shape[-1], a.shape[-2]]))

        return a_pinv
Example #27
0
 def __init__(self, validate_args=False, name="sigmoid"):
     """Initialize the bijector inside a name scope.

     Args:
       validate_args: Python `bool` forwarded unchanged to the parent
         constructor.
       name: Python `str` used to open the name scope; the value yielded by
         the scope is what gets passed to the parent constructor as `name`.
     """
     with tf.name_scope(name) as scope_name:
         super(Sigmoid, self).__init__(
             forward_min_event_ndims=0,
             validate_args=validate_args,
             name=scope_name)
Example #28
0
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  The rows of `rhs` are first reordered according to `perm`; the result is
  then passed through a unit-lower-triangular solve followed by an
  upper-triangular solve, both taken from `lower_upper`.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import numpy as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(
            lower_upper, dtype_hint=tf.float32, name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(
            rhs, dtype_hint=lower_upper.dtype, name='rhs')

        checks = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
        if checks:
            # Make every input depend on the checks so they run before use.
            with tf.control_dependencies(checks):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        rhs_is_single_matrix = tensorshape_util.rank(rhs.shape) == 2
        if rhs_is_single_matrix and tensorshape_util.rank(perm.shape) == 1:
            # Both rhs and perm have scalar batch_shape: a plain row gather
            # applies the permutation.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't
            # determine this information statically.
            full_rhs_shape = tf.shape(rhs)
            batch_shape = tf.broadcast_dynamic_shape(
                full_rhs_shape[:-2],
                tf.shape(perm)[:-1])
            num_rows = full_rhs_shape[-2]
            num_cols = full_rhs_shape[-1]
            target_shape = tf.concat(
                [batch_shape, [num_rows, num_cols]], axis=0)

            # Broadcast rhs to the common batch shape, then flatten the batch
            # dimensions into one leading axis.
            flat_rhs = tf.broadcast_to(rhs, target_shape)
            flat_rhs = tf.reshape(flat_rhs, [-1, num_rows, num_cols])

            # Do the same for perm, then pair every row index with the index
            # of the flattened batch member it belongs to.
            flat_perm = tf.broadcast_to(perm, target_shape[:-1])
            flat_perm = tf.reshape(flat_perm, [-1, num_rows])
            flat_batch_size = tf.reduce_prod(batch_shape)
            batch_ids = tf.broadcast_to(
                tf.range(flat_batch_size)[:, tf.newaxis],
                [flat_batch_size, num_rows])
            gather_indices = tf.stack([batch_ids, flat_perm], axis=-1)

            permuted_rhs = tf.gather_nd(flat_rhs, gather_indices)
            permuted_rhs = tf.reshape(permuted_rhs, target_shape)

        # Lower factor: lower-triangular band of `lower_upper` with its
        # diagonal overwritten by ones.
        lower_band = tf.linalg.band_part(
            lower_upper, num_lower=-1, num_upper=0)
        unit_diag = tf.ones(
            tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype)
        lower = tf.linalg.set_diag(lower_band, unit_diag)

        triangular_solve = (
            linear_operator_util.matrix_triangular_solve_with_broadcast)
        return triangular_solve(
            lower_upper,  # Only upper is accessed.
            triangular_solve(lower, permuted_rhs),
            lower=False)
Example #29
0
  def __init__(self,
               forward_fn=None,
               inverse_fn=None,
               inverse_log_det_jacobian_fn=None,
               forward_log_det_jacobian_fn=None,
               forward_event_shape_fn=None,
               forward_event_shape_tensor_fn=None,
               inverse_event_shape_fn=None,
               inverse_event_shape_tensor_fn=None,
               is_constant_jacobian=False,
               validate_args=False,
               forward_min_event_ndims=None,
               inverse_min_event_ndims=None,
               name='inline'):
    """Builds a `Bijector` out of user-supplied callables.

    At the minimum, you must supply one of `forward_min_event_ndims` or
    `inverse_min_event_ndims`. To be fully functional, a typical bijector will
    also require `forward_fn`, `inverse_fn` and at least one of
    `inverse_log_det_jacobian_fn` or `forward_log_det_jacobian_fn`.

    Args:
      forward_fn: Python callable implementing the forward transformation.
      inverse_fn: Python callable implementing the inverse transformation.
      inverse_log_det_jacobian_fn: Python callable implementing the
        `log o det o jacobian` of the inverse transformation.
      forward_log_det_jacobian_fn: Python callable implementing the
        `log o det o jacobian` of the forward transformation.
      forward_event_shape_fn: Python callable implementing non-identical
        static event shape changes. Default: shape is assumed unchanged.
      forward_event_shape_tensor_fn: Python callable implementing non-identical
        event shape changes. Default: shape is assumed unchanged.
      inverse_event_shape_fn: Python callable implementing non-identical
        static event shape changes. Default: shape is assumed unchanged.
      inverse_event_shape_tensor_fn: Python callable implementing non-identical
        event shape changes. Default: shape is assumed unchanged.
      is_constant_jacobian: Python `bool` indicating that the Jacobian is
        constant for all input arguments.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      forward_min_event_ndims: Python `int` indicating the minimal
        dimensionality this bijector acts on.
      inverse_min_event_ndims: Python `int` indicating the minimal
        dimensionality this bijector acts on.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: If any of the non-`None` `*_fn` arguments are not callable.
    """
    with tf.name_scope(name) as scope_name:
      # Each transformation callable is handed to `_maybe_implement` together
      # with the private method name and the argument name it arrived under.
      transform_hooks = (
          (forward_fn, '_forward', 'forward_fn'),
          (inverse_fn, '_inverse', 'inverse_fn'),
          (inverse_log_det_jacobian_fn,
           '_inverse_log_det_jacobian',
           'inverse_log_det_jacobian_fn'),
          (forward_log_det_jacobian_fn,
           '_forward_log_det_jacobian',
           'forward_log_det_jacobian_fn'),
      )
      for fn, method_name, arg_name in transform_hooks:
        self._maybe_implement(fn, method_name, arg_name)

      # Event-shape callables go through `_maybe_impute_as_identity`; by
      # default the shape is assumed not to change.
      shape_hooks = (
          ('_forward_event_shape',
           forward_event_shape_fn, 'forward_event_shape_fn'),
          ('_forward_event_shape_tensor',
           forward_event_shape_tensor_fn, 'forward_event_shape_tensor_fn'),
          ('_inverse_event_shape',
           inverse_event_shape_fn, 'inverse_event_shape_fn'),
          ('_inverse_event_shape_tensor',
           inverse_event_shape_tensor_fn, 'inverse_event_shape_tensor_fn'),
      )
      for attr_name, fn, arg_name in shape_hooks:
        setattr(self, attr_name, _maybe_impute_as_identity(fn, arg_name))

      super(Inline, self).__init__(
          forward_min_event_ndims=forward_min_event_ndims,
          inverse_min_event_ndims=inverse_min_event_ndims,
          is_constant_jacobian=is_constant_jacobian,
          validate_args=validate_args,
          name=scope_name)
    def __init__(self,
                 mix_loc,
                 temperature,
                 distribution,
                 loc=None,
                 scale=None,
                 quadrature_size=8,
                 quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorDiffeomixture"):
        """Constructs the VectorDiffeomixture on `R^d`.

    The vector diffeomixture (VDM) approximates the compound distribution

    ```none
    p(x) = int p(x | z) p(z) dz,
    where z is in the K-simplex, and
    p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
    ```

    Args:
      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
        In terms of samples, larger `mix_loc[..., k]` ==>
        `Z` is more likely to put more weight on its `kth` component.
      temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
        In terms of samples, smaller `temperature` means one component is more
        likely to dominate.  I.e., smaller `temperature` makes the VDM look more
        like a standard mixture of `K` components.
      distribution: `tfp.distributions.Distribution`-like instance. Distribution
        from which `d` iid samples are used as input to the selected affine
        transformation. Must be a scalar-batch, scalar-event distribution.
        Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED`
        or it is a function of non-trainable parameters. WARNING: If you
        backprop through a VectorDiffeomixture sample and the `distribution`
        is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
        then the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation.  If
        the `k`-th item is `None`, `loc` is implicitly `0`.  When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
        size. If `loc` itself is `None`, every component's shift is `None`.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
        `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices.
      quadrature_size: Python `int` scalar representing number of
        quadrature points.  Larger `quadrature_size` means `q_N(x)` better
        approximates `p(x)`.
      quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the SoftmaxNormal grid and corresponding normalized weight.
        It is called with `normal_loc=mix_loc / temperature` and
        `normal_scale=1. / temperature`.
        Default value: `quadrature_scheme_softmaxnormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`
      ValueError: if `validate_args` and any not scale.is_positive_definite.
      TypeError: if any scale.dtype != scale[0].dtype.
      TypeError: if any loc.dtype != scale[0].dtype.
      NotImplementedError: if `len(scale) != 2`.
      ValueError: if `not distribution.is_scalar_batch`.
      ValueError: if `not distribution.is_scalar_event`.
    """
        # Snapshot the constructor arguments; this must stay the first
        # statement so no other local has been bound yet.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            if not scale or len(scale) < 2:
                raise ValueError(
                    "Must specify list (or list-like object) of scale "
                    "LinearOperators, one for each component with "
                    "num_component >= 2.")

            # A missing `loc` means "no shift" for every component.
            if loc is None:
                loc = [None] * len(scale)

            if len(loc) != len(scale):
                raise ValueError("loc/scale must be same-length lists "
                                 "(or same-length list-like objects).")

            # The first scale operator fixes the dtype used for everything else.
            dtype = dtype_util.base_dtype(scale[0].dtype)

            # Convert every provided `loc` entry to a Tensor of that dtype;
            # `None` entries are kept as `None`.
            loc = [
                tf.convert_to_tensor(loc_, dtype=dtype, name="loc{}".format(k))
                if loc_ is not None else None for k, loc_ in enumerate(loc)
            ]

            # Static checks on each scale operator: positive-definiteness (only
            # under `validate_args`) and dtype agreement with `scale[0]`.
            for k, scale_ in enumerate(scale):
                if validate_args and not scale_.is_positive_definite:
                    raise ValueError(
                        "scale[{}].is_positive_definite = {} != True".format(
                            k, scale_.is_positive_definite))
                if dtype_util.base_dtype(scale_.dtype) != dtype:
                    raise TypeError(
                        "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\""
                        .format(k, dtype_util.name(scale_.dtype),
                                dtype_util.name(dtype)))

            # One affine bijector per (loc, scale) pair as supplied by the
            # caller.
            self._endpoint_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="endpoint_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(zip(loc, scale))
            ]

            # TODO(jvdillon): Remove once we support k-mixtures.
            # We make this assertion here because otherwise `grid` would need to be a
            # vector not a scalar.
            if len(scale) != 2:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "len(scale)={} is not 2.".format(len(scale)))

            mix_loc = tf.convert_to_tensor(mix_loc,
                                           dtype=dtype,
                                           name="mix_loc")
            temperature = tf.convert_to_tensor(temperature,
                                               dtype=dtype,
                                               name="temperature")
            # Build the quadrature grid and weights; `temperature` divides the
            # location and its reciprocal is passed as the scale.
            self._grid, probs = tuple(
                quadrature_fn(mix_loc / temperature, 1. / temperature,
                              quadrature_size, validate_args))

            # Note: by creating the logits as `log(prob)` we ensure that
            # `self.mixture_distribution.logits` is equivalent to
            # `tf.math.log(self.mixture_distribution.probs)`.
            self._mixture_distribution = categorical.Categorical(
                logits=tf.math.log(probs),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            # Any checks returned for `distribution` are attached to
            # `self._grid` as dependencies.
            asserts = distribution_util.maybe_check_scalar_distribution(
                distribution, dtype, validate_args)
            if asserts:
                self._grid = distribution_util.with_dependencies(
                    asserts, self._grid)
            self._distribution = distribution

            # One affine bijector per pair produced by `interpolate_loc` /
            # `interpolate_scale` from the grid and the endpoint parameters.
            self._interpolated_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="interpolated_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(
                    zip(interpolate_loc(self._grid, loc),
                        interpolate_scale(self._grid, scale)))
            ]

            # Batch and event shapes (the `_tensor_` names presumably hold the
            # dynamic counterparts -- confirm against
            # `determine_batch_event_shapes`).
            [
                self._batch_shape_,
                self._batch_shape_tensor_,
                self._event_shape_,
                self._event_shape_tensor_,
            ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

            super(VectorDiffeomixture, self).__init__(
                dtype=dtype,
                # We hard-code `FULLY_REPARAMETERIZED` because when
                # `validate_args=True` we verify that indeed
                # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
                # distribution which is a function of only non-trainable parameters
                # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
                # easily test for that possibility thus we use `validate_args=False`
                # as a "back-door" to allow users a way to use non
                # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
                # RESPONSIBILITY to verify that the base distribution is a function of
                # non-trainable parameters.
                reparameterization_type=reparameterization.
                FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)