Example #1
def normal_conjugates_known_scale_posterior(prior, scale, s, n):
    """Posterior Normal distribution with conjugate prior on the mean.

  This model assumes that `n` observations (with sum `s`) come from a
  Normal with unknown mean `loc` (described by the Normal `prior`)
  and known variance `scale**2`. The "known scale posterior" is
  the distribution of the unknown `loc`.

  Accepts a prior Normal distribution object, having parameters
  `loc0` and `scale0`, as well as known `scale` values of the predictive
  distribution(s) (also assumed Normal),
  and statistical estimates `s` (the sum(s) of the observations) and
  `n` (the number(s) of observations).

  Returns a posterior (also Normal) distribution object, with parameters
  `(loc', scale'**2)`, where:

  ```
  loc ~ N(loc', scale'**2)
  scale'**2 = 1 / (1/scale0**2 + n/scale**2),
  loc' = (loc0/scale0**2 + s/scale**2) * scale'**2.
  ```

  Distribution parameters from `prior`, as well as `scale`, `s`, and `n`,
  will broadcast in the case of multidimensional sets of parameters.

  Args:
    prior: `Normal` object of type `dtype`:
      the prior distribution having parameters `(loc0, scale0)`.
    scale: Tensor of type `dtype`, taking values `scale > 0`.
      The known stddev parameter(s).
    s: Tensor of type `dtype`. The sum(s) of observations.
    n: Tensor of type `int`. The number(s) of observations.

  Returns:
    A new Normal posterior distribution object for the unknown observation
    mean `loc`.

  Raises:
    TypeError: if the dtype of `s` does not match the dtype of `prior`, or if
      `prior` is not a Normal object.
  """
    if not isinstance(prior, normal.Normal):
        raise TypeError("Expected prior to be an instance of type Normal")

    if s.dtype != prior.dtype:
        raise TypeError(
            "Observation sum s.dtype does not match prior dtype: %s vs. %s" %
            (s.dtype, prior.dtype))

    n = tf.cast(n, prior.dtype)
    scale0_2 = tf.square(prior.scale)
    scale_2 = tf.square(scale)
    scalep_2 = 1.0 / (1 / scale0_2 + n / scale_2)
    return normal.Normal(loc=(prior.loc / scale0_2 + s / scale_2) * scalep_2,
                         scale=tf.sqrt(scalep_2))
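A minimal usage sketch for the conjugate update above, assuming the function is exported as `tfp.distributions.normal_conjugates_known_scale_posterior` (as in TensorFlow Probability); the observation data here is illustrative:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Prior belief about the unknown mean: N(loc0=0, scale0=1).
prior = tfd.Normal(loc=0., scale=1.)
observations = tf.constant([3.1, 2.7, 3.4, 2.9])  # illustrative data

posterior = tfd.normal_conjugates_known_scale_posterior(
    prior=prior,
    scale=2.,                       # known observation stddev
    s=tf.reduce_sum(observations),  # s: sum of the observations
    n=tf.size(observations))        # n: number of observations (int dtype)
# posterior.loc and posterior.scale follow the closed-form update above.
```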
Example #2
def _uniform_unit_norm(dimension, shape, dtype, seed):
    """Returns a batch of points chosen uniformly from the unit hypersphere."""
    # This works because the Gaussian distribution is spherically symmetric.
    # raw shape: shape + [dimension]
    raw = normal.Normal(loc=dtype_util.as_numpy_dtype(dtype)(0),
                        scale=dtype_util.as_numpy_dtype(dtype)(1)).sample(
                            tf.concat([shape, [dimension]], axis=0),
                            seed=seed())
    unit_norm = raw / tf.norm(raw, ord=2, axis=-1)[..., tf.newaxis]
    return unit_norm
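`_uniform_unit_norm` is private; here is a self-contained sketch of the same technique using only public TensorFlow ops (the function name is illustrative, not library API):

```python
import tensorflow as tf

def uniform_unit_norm(dimension, shape, dtype=tf.float32, seed=None):
    """Samples points uniformly from the unit hypersphere in `dimension` dims."""
    # Gaussian samples are spherically symmetric, so normalizing them to
    # unit length yields a uniform distribution on the sphere.
    raw = tf.random.normal(tf.concat([shape, [dimension]], axis=0),
                           dtype=dtype, seed=seed)
    return raw / tf.norm(raw, ord=2, axis=-1)[..., tf.newaxis]

points = uniform_unit_norm(3, [1000])  # 1000 points on the unit 2-sphere
```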
Example #3
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='LogNormal'):
        """Construct a log-normal distribution.

    The LogNormal distribution models positive-valued random variables
    whose logarithm is normally distributed with mean `loc` and
    standard deviation `scale`. It is constructed as the exponential
    transformation of a Normal distribution.

    Args:
      loc: Floating-point `Tensor`; the means of the underlying
        Normal distribution(s).
      scale: Floating-point `Tensor`; the stddevs of the underlying
        Normal distribution(s).
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            super(LogNormal,
                  self).__init__(distribution=normal.Normal(loc=loc,
                                                            scale=scale),
                                 bijector=exp_bijector.Exp(),
                                 validate_args=validate_args,
                                 parameters=parameters,
                                 name=name)
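A minimal usage sketch via the public alias `tfd.LogNormal`:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# log(X) ~ Normal(loc=0, scale=0.5), so X takes positive values only.
dist = tfd.LogNormal(loc=0., scale=0.5)
samples = dist.sample(3)
log_probs = dist.log_prob([0.5, 1.0, 2.0])
```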
Example #4
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='MultivariateNormalLinearOperator'):
        """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
        parameters = dict(locals())
        if scale is None:
            raise ValueError('Missing required `scale` parameter.')
        if not dtype_util.is_floating(scale.dtype):
            raise TypeError(
                '`scale` parameter must have floating-point dtype.')

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            dtype_hint=tf.float32)
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            loc = tensor_util.convert_nonref_to_tensor(loc,
                                                       dtype=dtype,
                                                       name='loc')
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale)

            super(MultivariateNormalLinearOperator, self).__init__(
                distribution=normal.Normal(loc=tf.zeros([], dtype=dtype),
                                           scale=tf.ones([], dtype=dtype)),
                bijector=affine_linear_operator_bijector.AffineLinearOperator(
                    shift=loc, scale=scale, validate_args=validate_args),
                batch_shape=batch_shape,
                event_shape=event_shape,
                validate_args=validate_args,
                name=name)
            self._parameters = parameters
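A minimal usage sketch via `tfd.MultivariateNormalLinearOperator`, using a lower-triangular `LinearOperator` as the `scale` (so `covariance = scale @ scale.T`):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

scale = tf.linalg.LinearOperatorLowerTriangular([[1.0, 0.0],
                                                 [0.5, 2.0]])
mvn = tfd.MultivariateNormalLinearOperator(loc=[1., -1.], scale=scale)
samples = mvn.sample(5)  # shape [5, 2]: five draws from R^2
```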
Example #5
def quadrature_scheme_softmaxnormal_quantiles(normal_loc,
                                              normal_scale,
                                              quadrature_size,
                                              validate_args=False,
                                              name=None):
    """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex.
    probs: Shape `[quadrature_size]` `Tensor` representing the quadrature
      weight associated with each grid point (constant, `1 / quadrature_size`).
  """
    with tf.name_scope(name or "softmax_normal_grid_and_probs"):
        normal_loc = tf.convert_to_tensor(normal_loc, name="normal_loc")
        dt = dtype_util.base_dtype(normal_loc.dtype)
        normal_scale = tf.convert_to_tensor(normal_scale,
                                            dtype=dt,
                                            name="normal_scale")

        normal_scale = maybe_check_quadrature_param(normal_scale,
                                                    "normal_scale",
                                                    validate_args)

        dist = normal.Normal(loc=normal_loc, scale=normal_scale)

        def _get_batch_ndims():
            """Helper to get rank(dist.batch_shape), statically if possible."""
            ndims = tensorshape_util.rank(dist.batch_shape)
            if ndims is None:
                ndims = tf.shape(dist.batch_shape_tensor())[0]
            return ndims

        batch_ndims = _get_batch_ndims()

        def _get_final_shape(qs):
            """Helper to build `TensorShape`."""
            bs = tensorshape_util.with_rank_at_least(dist.batch_shape, 1)
            num_components = tf.compat.dimension_value(bs[-1])
            if num_components is not None:
                num_components += 1
            tail = tf.TensorShape([num_components, qs])
            return bs[:-1].concatenate(tail)

        def _compute_quantiles():
            """Helper to build quantiles."""
            # Omit {0, 1} since they might lead to Inf/NaN.
            zero = tf.zeros([], dtype=dist.dtype)
            edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
            # Expand edges so it broadcasts across batch dims.
            edges = tf.reshape(
                edges,
                shape=tf.concat(
                    [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
            quantiles = dist.quantile(edges)
            quantiles = softmax_centered_bijector.SoftmaxCentered().forward(
                quantiles)
            # Cyclically permute left by one.
            perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
            quantiles = tf.transpose(a=quantiles, perm=perm)
            tensorshape_util.set_shape(quantiles,
                                       _get_final_shape(quadrature_size + 1))
            return quantiles

        quantiles = _compute_quantiles()

        # Compute grid as quantile midpoints.
        grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
        # Set shape hints.
        tensorshape_util.set_shape(grid, _get_final_shape(quadrature_size))

        # By construction probs is constant, i.e., `1 / quadrature_size`. This is
        # important, because non-constant probs leads to non-reparameterizable
        # samples.
        probs = tf.fill(dims=[quadrature_size],
                        value=1. / tf.cast(quadrature_size, dist.dtype))

        return grid, probs
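A minimal sketch, assuming the function is exported as `tfp.distributions.quadrature_scheme_softmaxnormal_quantiles` (as in the TFP releases that ship `VectorDiffeomixture`):

```python
import tensorflow as tf
import tensorflow_probability as tfp

# K - 1 = 2, so the grid points live on the 2-simplex (K = 3 components).
grid, probs = tfp.distributions.quadrature_scheme_softmaxnormal_quantiles(
    normal_loc=tf.zeros([2]),
    normal_scale=tf.ones([2]),
    quadrature_size=8)
# grid: shape [3, 8]; probs: shape [8], each entry 1/8.
```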
Example #6
def quadrature_scheme_lognormal_quantiles(
    loc, scale, quadrature_size,
    validate_args=False, name=None):
  """Use LogNormal quantiles to form quadrature on positive-reals.

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
      the LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: (Batch of) length-`quadrature_size` vectors representing the
      `log_rate` parameters of a `Poisson`.
    probs: (Batch of) length-`quadrature_size` vectors representing the
      weight associated with each `grid` value.
  """
  with tf.name_scope(name or 'quadrature_scheme_lognormal_quantiles'):
    # Create a LogNormal distribution.
    dist = transformed_distribution.TransformedDistribution(
        distribution=normal.Normal(loc=loc, scale=scale),
        bijector=exp_bijector.Exp(),
        validate_args=validate_args)
    batch_ndims = tensorshape_util.rank(dist.batch_shape)
    if batch_ndims is None:
      batch_ndims = tf.shape(dist.batch_shape_tensor())[0]

    def _compute_quantiles():
      """Helper to build quantiles."""
      # Omit {0, 1} since they might lead to Inf/NaN.
      zero = tf.zeros([], dtype=dist.dtype)
      edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
      # Expand edges so it broadcasts across batch dims.
      edges = tf.reshape(
          edges,
          shape=tf.concat(
              [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
      quantiles = dist.quantile(edges)
      # Cyclically permute left by one.
      perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
      quantiles = tf.transpose(a=quantiles, perm=perm)
      return quantiles
    quantiles = _compute_quantiles()

    # Compute grid as quantile midpoints.
    grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
    # Set shape hints.
    new_shape = tensorshape_util.concatenate(dist.batch_shape,
                                             [quadrature_size])
    tensorshape_util.set_shape(grid, new_shape)

    # By construction probs is constant, i.e., `1 / quadrature_size`. This is
    # important, because non-constant probs leads to non-reparameterizable
    # samples.
    probs = tf.fill(
        dims=[quadrature_size], value=1. / tf.cast(quadrature_size, dist.dtype))

    return grid, probs
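A minimal sketch, assuming exposure as `tfp.distributions.quadrature_scheme_lognormal_quantiles` (it serves as the default quadrature scheme of `tfd.PoissonLogNormalQuadratureCompound`):

```python
import tensorflow_probability as tfp

# Ten quantile midpoints of LogNormal(loc=0, scale=1), with uniform weights.
grid, probs = tfp.distributions.quadrature_scheme_lognormal_quantiles(
    loc=0., scale=1., quadrature_size=10)
# grid: shape [10]; probs: constant 1/10, so samples stay reparameterizable.
```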
Example #7
    def __init__(self,
                 loc,
                 scale,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="SinhArcsinh"):
        """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Default is `tfd.Normal(0., 1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                            tf.float32)
            self._loc = tensor_util.convert_nonref_to_tensor(loc,
                                                             name="loc",
                                                             dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name="scale",
                                                               dtype=dtype)
            tailweight = 1. if tailweight is None else tailweight
            has_default_skewness = skewness is None
            skewness = 0. if has_default_skewness else skewness
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name="tailweight", dtype=dtype)
            self._skewness = tensor_util.convert_nonref_to_tensor(
                skewness, name="skewness", dtype=dtype)

            batch_shape = distribution_util.get_broadcast_shape(
                self._loc, self._scale, self._tailweight, self._skewness)

            # Recall, with Z a random variable,
            #   Y := loc + scale * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
            #   C := 2 / F_0(2)
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            if distribution is None:
                distribution = normal.Normal(loc=tf.zeros([], dtype=dtype),
                                             scale=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats,
                                             validate_args=validate_args)
            else:
                asserts = distribution_util.maybe_check_scalar_distribution(
                    distribution, dtype, validate_args)
                if asserts:
                    self._loc = distribution_util.with_dependencies(
                        asserts, self._loc)

            # Make the SAS bijector, 'F'.
            f = sinh_arcsinh_bijector.SinhArcsinh(skewness=self._skewness,
                                                  tailweight=self._tailweight,
                                                  validate_args=validate_args)

            # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
            affine = affine_scalar_bijector.AffineScalar(
                shift=self._loc,
                scale=self._scale,
                validate_args=validate_args)

            bijector = chain_bijector.Chain([affine, f])

            super(SinhArcsinh, self).__init__(distribution=distribution,
                                              bijector=bijector,
                                              batch_shape=batch_shape,
                                              validate_args=validate_args,
                                              name=name)
            self._parameters = parameters
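A minimal usage sketch via the public `tfd.SinhArcsinh`:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# Right skew (skewness > 0) and heavier tails (tailweight > 1) relative to
# a Normal with the same loc/scale.
dist = tfd.SinhArcsinh(loc=0., scale=1., skewness=0.5, tailweight=1.5)
samples = dist.sample(1000)
```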
    def __init__(self,
                 loc=None,
                 scale_diag=None,
                 scale_identity_multiplier=None,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorSinhArcsinhDiag"):
        """Construct VectorSinhArcsinhDiag distribution on `R^k`.

    The arguments `scale_diag` and `scale_identity_multiplier` combine to
    define the diagonal `scale` referred to in this class docstring:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scale-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
        is the `Identity`.
      skewness:  Skewness parameter.  Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      tailweight:  Tailweight parameter.  Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      distribution: `tf.Distribution`-like instance. Distribution from which `k`
        iid samples are used as input to transformation `F`.  Default is
        `tfd.Normal(loc=0., scale=1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorSinhArcsinhDiag sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
    """
        parameters = dict(locals())

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([
                loc, scale_diag, scale_identity_multiplier, skewness,
                tailweight
            ], tf.float32)
            loc = loc if loc is None else tf.convert_to_tensor(
                loc, name="loc", dtype=dtype)
            tailweight = 1. if tailweight is None else tailweight
            skewness = 0. if skewness is None else skewness

            # Recall, with Z a random variable,
            #   Y := loc + C * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            #   C := 2 * scale / F_0(2)

            # Construct shapes and 'scale' out of the scale_* and loc kwargs.
            # scale_linop is only an intermediary to:
            #  1. get shapes from looking at loc and the two scale args.
            #  2. combine scale_diag with scale_identity_multiplier, which gives us
            #     'scale', which in turn gives us 'C'.
            scale_linop = distribution_util.make_diag_scale(
                loc=loc,
                scale_diag=scale_diag,
                scale_identity_multiplier=scale_identity_multiplier,
                validate_args=False,
                assert_positive=False,
                dtype=dtype)
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale_linop)
            # scale_linop.diag_part() is efficient since it is a diag type linop.
            scale_diag_part = scale_linop.diag_part()
            dtype = scale_diag_part.dtype

            if distribution is None:
                distribution = normal.Normal(loc=tf.zeros([], dtype=dtype),
                                             scale=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats)
            else:
                asserts = distribution_util.maybe_check_scalar_distribution(
                    distribution, dtype, validate_args)
                if asserts:
                    scale_diag_part = distribution_util.with_dependencies(
                        asserts, scale_diag_part)

            # Make the SAS bijector, 'F'.
            skewness = tf.convert_to_tensor(skewness,
                                            dtype=dtype,
                                            name="skewness")
            tailweight = tf.convert_to_tensor(tailweight,
                                              dtype=dtype,
                                              name="tailweight")
            f = sinh_arcsinh_bijector.SinhArcsinh(skewness=skewness,
                                                  tailweight=tailweight)
            affine = affine_bijector.Affine(shift=loc,
                                            scale_diag=scale_diag_part,
                                            validate_args=validate_args)

            bijector = chain_bijector.Chain([affine, f])

            super(VectorSinhArcsinhDiag,
                  self).__init__(distribution=distribution,
                                 bijector=bijector,
                                 batch_shape=batch_shape,
                                 event_shape=event_shape,
                                 validate_args=validate_args,
                                 name=name)
        self._parameters = parameters
        self._loc = loc
        self._scale = scale_linop
        self._tailweight = tailweight
        self._skewness = skewness
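A minimal usage sketch (`tfd.VectorSinhArcsinhDiag` is available in the TFP releases this code is drawn from; it was deprecated in later versions):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# Elementwise SinhArcsinh transform of a diagonal-scale Gaussian on R^2.
dist = tfd.VectorSinhArcsinhDiag(
    loc=[0., 0.], scale_diag=[1., 2.], skewness=0.3, tailweight=1.2)
samples = dist.sample(4)  # shape [4, 2]
```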