def quadrature_scheme_softmaxnormal_quantiles(normal_loc,
                                              normal_scale,
                                              quadrature_size,
                                              validate_args=False,
                                              name=None):
    """Build a quadrature scheme on the `K - 1` simplex from quantiles.

    A `SoftmaxNormal` random variable `Y` may be generated via

    ```
    Y = SoftmaxCentered(X),
    X = Normal(normal_loc, normal_scale)
    ```

    Quantiles of `X` are taken at `quadrature_size + 1` evenly spaced
    probabilities strictly inside `(0, 1)`, pushed through
    `SoftmaxCentered`, and each pair of neighbouring results is averaged to
    give one grid point.

    Args:
      normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`,
        B>=0. The location parameter of the Normal used to construct the
        SoftmaxNormal.
      normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
        The scale parameter of the Normal used to construct the
        SoftmaxNormal. It is converted to the base dtype of `normal_loc`.
      quadrature_size: Python `int` scalar representing the number of
        quadrature points.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      name: Python `str` name prefixed to Ops created by this function.

    Returns:
      grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing
        the convex combination of affine parameters for `K` components.
        `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1`
        simplex.
      probs: Shape `[quadrature_size]` `Tensor` holding the weight
        associated with each grid point; every entry is
        `1 / quadrature_size`.
    """
    with tf.name_scope(name or "softmax_normal_grid_and_probs"):
        normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
        normal_scale = tf.convert_to_tensor(
            normal_scale,
            dtype=dtype_util.base_dtype(normal_loc.dtype),
            name="normal_scale")
        normal_scale = maybe_check_quadrature_param(
            normal_scale, "normal_scale", validate_args)

        base = normal.Normal(loc=normal_loc, scale=normal_scale)

        # Rank of the batch shape: use the static value when it is known and
        # fall back to a dynamic tensor otherwise.
        rank = tensorshape_util.rank(base.batch_shape)
        if rank is None:
            rank = tf.shape(base.batch_shape_tensor())[0]

        def _static_shape(last_dim):
            """Static `[..., K, last_dim]` shape; `K` is last batch dim + 1."""
            batch = tensorshape_util.with_rank_at_least(base.batch_shape, 1)
            k = tf.compat.dimension_value(batch[-1])
            if k is not None:
                k += 1
            return batch[:-1].concatenate(tf.TensorShape([k, last_dim]))

        # Evenly spaced probabilities with the endpoints {0, 1} dropped, since
        # those might lead to Inf/NaN quantiles. `quadrature_size + 1` remain.
        start = tf.zeros([], dtype=base.dtype)
        cut_points = tf.linspace(start, 1., quadrature_size + 3)[1:-1]
        # Append one singleton axis per batch dim so the probabilities
        # broadcast against the distribution's batch shape.
        column_shape = tf.concat(
            [[-1], tf.ones([rank], dtype=tf.int32)], axis=0)
        cut_points = tf.reshape(cut_points, shape=column_shape)

        simplex_edges = softmax_centered_bijector.SoftmaxCentered().forward(
            base.quantile(cut_points))
        # Rotate the axes left by one so the quantile axis ends up last.
        rotate_left = tf.concat([tf.range(1, 1 + rank), [0]], axis=0)
        simplex_edges = tf.transpose(a=simplex_edges, perm=rotate_left)
        tensorshape_util.set_shape(simplex_edges,
                                   _static_shape(quadrature_size + 1))

        # Each grid point is the midpoint of two neighbouring quantiles.
        grid = (simplex_edges[..., :-1] + simplex_edges[..., 1:]) / 2.
        tensorshape_util.set_shape(grid, _static_shape(quadrature_size))

        # By construction probs is constant, i.e., `1 / quadrature_size`.
        # This is important, because non-constant probs leads to
        # non-reparameterizable samples.
        weight = 1. / tf.cast(quadrature_size, base.dtype)
        probs = tf.fill(dims=[quadrature_size], value=weight)

        return grid, probs
# Example no. 2
# 0
 def _variance(self):
     if self.allow_nan_stats:
         return tf.fill(self.batch_shape_tensor(),
                        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
     raise ValueError(
         '`variance` is undefined for the half-Cauchy distribution.')
def quadrature_scheme_lognormal_quantiles(loc,
                                          scale,
                                          quadrature_size,
                                          validate_args=False,
                                          name=None):
    """Use LogNormal quantiles to form quadrature on positive-reals.

    The LogNormal is built as `Exp(Normal(loc, scale))`. Its quantiles are
    taken at `quadrature_size + 1` evenly spaced probabilities strictly
    inside `(0, 1)`, and each pair of neighbouring quantiles is averaged to
    give one grid point.

    Args:
      loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
        the LogNormal prior.
      scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
        the LogNormal prior.
      quadrature_size: Python `int` scalar representing the number of
        quadrature points.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs. The flag is forwarded to both the
        base `Normal` and the `TransformedDistribution` wrapping it.
      name: Python `str` name prefixed to Ops created by this function.

    Returns:
      grid: (Batch of) length-`quadrature_size` vectors holding the midpoints
        of consecutive LogNormal quantiles.
      probs: Shape `[quadrature_size]` `Tensor` holding the weight associated
        with each `grid` value; every entry is `1 / quadrature_size`.
    """
    with tf.name_scope(name or 'quadrature_scheme_lognormal_quantiles'):
        # Create a LogNormal distribution. `validate_args` must reach the
        # base Normal too: it is the distribution that owns `loc`/`scale`,
        # so passing the flag only to the wrapper would leave them unchecked.
        dist = transformed_distribution.TransformedDistribution(
            distribution=normal.Normal(
                loc=loc, scale=scale, validate_args=validate_args),
            bijector=exp_bijector.Exp(),
            validate_args=validate_args)
        # Rank of the batch shape, statically if possible.
        batch_ndims = tensorshape_util.rank(dist.batch_shape)
        if batch_ndims is None:
            batch_ndims = tf.shape(dist.batch_shape_tensor())[0]

        def _compute_quantiles():
            """Helper to build quantiles with the quantile axis last."""
            # Omit {0, 1} since they might lead to Inf/NaN.
            zero = tf.zeros([], dtype=dist.dtype)
            edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
            # Expand edges so they broadcast across the batch dims.
            edges = tf.reshape(
                edges,
                shape=tf.concat(
                    [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
            quantiles = dist.quantile(edges)
            # Cyclically permute left by one.
            perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
            quantiles = tf.transpose(a=quantiles, perm=perm)
            return quantiles

        quantiles = _compute_quantiles()

        # Compute grid as quantile midpoints.
        grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
        # Set shape hints.
        new_shape = tensorshape_util.concatenate(dist.batch_shape,
                                                 [quadrature_size])
        tensorshape_util.set_shape(grid, new_shape)

        # By construction probs is constant, i.e., `1 / quadrature_size`.
        # This is important, because non-constant probs leads to
        # non-reparameterizable samples.
        probs = tf.fill(dims=[quadrature_size],
                        value=1. / tf.cast(quadrature_size, dist.dtype))

        return grid, probs
# Example no. 4
# 0
 def _stddev(self):
     if self.allow_nan_stats:
         return tf.fill(self.batch_shape_tensor(),
                        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
     else:
         raise ValueError('`stddev` is undefined for Cauchy distribution.')