Code example #1
def _max_precision_sum(a, b):
  """Coerces `a` or `b` to the higher-precision dtype, and returns the sum."""
  if not dtype_util.base_equal(a.dtype, b.dtype):
    if dtype_util.size(a.dtype) >= dtype_util.size(b.dtype):
      b = tf.cast(b, a.dtype)
    else:
      a = tf.cast(a, b.dtype)
  return a + b
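
Usage sketch (not from the project): `dtype_util` above is TFP's internal dtype helper; the same widening behavior can be sketched with public `tf.DType` attributes (`base_dtype`, and `size` in bytes). This is an assumed equivalent, not the project's own code.

import tensorflow as tf

def max_precision_sum(a, b):
  # Widen the lower-precision operand so that, e.g., float32 + float64
  # yields float64 instead of a dtype error.
  if a.dtype.base_dtype != b.dtype.base_dtype:
    if a.dtype.size >= b.dtype.size:
      b = tf.cast(b, a.dtype)
    else:
      a = tf.cast(a, b.dtype)
  return a + b

print(max_precision_sum(tf.constant(1., tf.float32),
                        tf.constant(2., tf.float64)).dtype)  # float64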
Code example #2
  def _maybe_assert_dtype(self, x):
    """Helper to check dtype when self.dtype is known."""
    if SKIP_DTYPE_CHECKS:
      return
    if (self.dtype is not None and
        not dtype_util.base_equal(self.dtype, x.dtype)):
      raise TypeError(
          'Input had dtype %s but expected %s.' % (x.dtype, self.dtype))
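
Standalone sketch of the same guard, with raw `tf.DType` comparison standing in for the internal `dtype_util.base_equal`; the class name below is hypothetical.

import tensorflow as tf

class _DtypeChecked(object):
  """Hypothetical holder mirroring the dtype guard above."""

  def __init__(self, dtype):
    self.dtype = dtype

  def _maybe_assert_dtype(self, x):
    # Compare base dtypes; raise on mismatch, as in the snippet above.
    if (self.dtype is not None and
        self.dtype.base_dtype != x.dtype.base_dtype):
      raise TypeError(
          'Input had dtype %s but expected %s.' % (x.dtype, self.dtype))

_DtypeChecked(tf.float32)._maybe_assert_dtype(tf.constant(1.))  # passes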
Code example #3
  def __init__(self,
               shift=None,
               scale=None,
               adjoint=False,
               validate_args=False,
               name="affine_linear_operator"):
    """Instantiates the `AffineLinearOperator` bijector.

    Args:
      shift: Floating-point `Tensor`.
      scale: Subclass of `LinearOperator`. Represents the (batch) positive
        definite matrix `M` in `R^{k x k}`.
      adjoint: Python `bool` indicating whether to use the `scale` matrix as
        specified or its adjoint.
        Default value: `False`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      TypeError: if `scale` is not a `LinearOperator`.
      TypeError: if `shift.dtype` does not match `scale.dtype`.
      ValueError: if not `scale.is_non_singular`.
    """
    with tf.name_scope(name) as name:
      # In the absence of `shift` and `scale`, we'll assume `dtype` is `float32`.
      dtype = tf.float32
      if shift is not None:
        shift = tf.convert_to_tensor(value=shift, name="shift")
        dtype = dtype_util.base_dtype(shift.dtype)
      self._shift = shift
      if scale is not None:
        if (shift is not None and
            not dtype_util.base_equal(shift.dtype, scale.dtype)):
          raise TypeError(
              "shift.dtype ({}) is incompatible with scale.dtype ({})."
              .format(shift.dtype, scale.dtype))
        if not isinstance(scale, tf.linalg.LinearOperator):
          raise TypeError(
              "scale is not an instance of tf.linalg.LinearOperator")
        if validate_args and not scale.is_non_singular:
          raise ValueError("Scale matrix must be non-singular.")
        if scale.dtype is not None:
          dtype = dtype_util.base_dtype(scale.dtype)
      self._scale = scale
      self._adjoint = adjoint
      super(AffineLinearOperator, self).__init__(
          forward_min_event_ndims=1,
          is_constant_jacobian=True,
          dtype=dtype,
          validate_args=validate_args,
          name=name)
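
Usage sketch, assuming an older TFP release that still exposes `tfp.bijectors.AffineLinearOperator` (newer releases split this into `Shift` and `ScaleMatvecLinearOperator`):

import tensorflow as tf
import tensorflow_probability as tfp

scale = tf.linalg.LinearOperatorLowerTriangular([[2., 0.], [1., 3.]])
bij = tfp.bijectors.AffineLinearOperator(shift=[1., -1.], scale=scale)
print(bij.forward([1., 1.]))  # scale @ x + shift = [3., 3.]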
Code example #4
  def poisson_and_mixture_distributions(self):
    """Returns the Poisson and Mixture distribution parameterized by the quadrature grid and weights."""
    loc = tf.convert_to_tensor(self.loc)
    scale = tf.convert_to_tensor(self.scale)
    quadrature_grid, quadrature_probs = tuple(self._quadrature_fn(
        loc, scale, self.quadrature_size, self.validate_args))
    dt = quadrature_grid.dtype
    if not dtype_util.base_equal(dt, quadrature_probs.dtype):
      raise TypeError('Quadrature grid dtype ({}) does not match quadrature '
                      'probs dtype ({}).'.format(
                          dtype_util.name(dt),
                          dtype_util.name(quadrature_probs.dtype)))

    dist = poisson.Poisson(
        log_rate=quadrature_grid,
        validate_args=self.validate_args,
        allow_nan_stats=self.allow_nan_stats)

    mixture_dist = categorical.Categorical(
        logits=tf.math.log(quadrature_probs),
        validate_args=self.validate_args,
        allow_nan_stats=self.allow_nan_stats)
    return dist, mixture_dist
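
The returned pair is the usual building block for a quadrature compound. A sketch of the same construction via the public API, with made-up stand-ins for the grid and weights:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

quadrature_grid = tf.constant([-1., 0., 1.])       # stand-in log-rates
quadrature_probs = tf.constant([0.25, 0.5, 0.25])  # stand-in weights
compound = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=tf.math.log(quadrature_probs)),
    components_distribution=tfd.Poisson(log_rate=quadrature_grid))
print(compound.mean())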
Code example #5
  def __init__(self,
               power,
               dtype=tf.int32,
               interpolate_nondiscrete=True,
               sample_maximum_iterations=100,
               validate_args=False,
               allow_nan_stats=False,
               name='Zipf'):
    """Initialize a batch of Zipf distributions.

    Args:
      power: `Float` like `Tensor` representing the power parameter. Must be
        strictly greater than `1`.
      dtype: The `dtype` of `Tensor` returned by `sample`.
        Default value: `tf.int32`.
      interpolate_nondiscrete: Python `bool`. When `False`, `log_prob` returns
        `-inf` (and `prob` returns `0`) for non-integer inputs. When `True`,
        `log_prob` evaluates the continuous function `-power log(k) -
        log(zeta(power))`, which matches the Zipf pmf at integer arguments `k`
        (note that this function is not itself a normalized probability
        log-density).
        Default value: `True`.
      sample_maximum_iterations: Maximum number of iterations allowed in
        `sample`. When `validate_args=True`, samples which fail to
        reach convergence (subject to this cap) are masked out with
        `self.dtype.min` or `nan` depending on `self.dtype.is_integer`.
        Default value: `100`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Zipf'`.

    Raises:
      TypeError: if `power` is not `float` like.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._power = tensor_util.convert_nonref_to_tensor(
          power,
          name='power',
          dtype=dtype_util.common_dtype([power], dtype_hint=tf.float32))
      if (not dtype_util.is_floating(self._power.dtype) or
          dtype_util.base_equal(self._power.dtype, tf.float16)):
        raise TypeError(
            'power.dtype ({}) is not a supported `float` type.'.format(
                dtype_util.name(self._power.dtype)))
      self._interpolate_nondiscrete = interpolate_nondiscrete
      self._sample_maximum_iterations = sample_maximum_iterations
      super(Zipf, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
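
Usage sketch: `power` must be a float strictly greater than 1, and per the dtype check above float16 is rejected.

import tensorflow as tf
import tensorflow_probability as tfp

zipf = tfp.distributions.Zipf(power=tf.constant(2., tf.float64))
print(zipf.prob(1))    # 1 / zeta(2) = 6 / pi**2 ~= 0.6079
print(zipf.sample(5))  # int32 samples, per the `dtype` default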
Code example #6
File: affine.py Project: xzxzmmnn/probability
  def __init__(self,
               shift=None,
               scale_identity_multiplier=None,
               scale_diag=None,
               scale_tril=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               adjoint=False,
               validate_args=False,
               name="affine",
               dtype=None):
    """Instantiates the `Affine` bijector.

    This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments,
    giving the forward operation:

    ```none
    Y = g(X) = scale @ X + shift
    ```

    where the `scale` term is logically equivalent to:

    ```python
    scale = (
      scale_identity_multiplier * tf.diag(tf.ones(d)) +
      tf.diag(scale_diag) +
      scale_tril +
      scale_perturb_factor @ diag(scale_perturb_diag) @
        tf.transpose([scale_perturb_factor])
    )
    ```

    If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are
    specified then `scale += IdentityMatrix`. Otherwise specifying a
    `scale` argument has the semantics of `scale += Expand(arg)`, i.e.,
    `scale_diag != None` means `scale += tf.diag(scale_diag)`.

    Args:
      shift: Floating-point `Tensor`. If this is set to `None`, no shift is
        applied.
      scale_identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
        When `scale_identity_multiplier = scale_diag = scale_tril = None` then
        `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
        to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape `[N1, N2, ...  k]`, which represents a k x k
        diagonal matrix.
        When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape `[N1, N2, ...  k, k]`, which represents a
        k x k lower triangular matrix.
        When `None` no `scale_tril` term is added to `scale`.
        The upper triangular elements above the diagonal are ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor matrix
        with last two dimensions of shape `(k, r)`. When `None`, no rank-r
        update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape `[N1, N2, ...  r]`, which
        represents an `r x r` diagonal matrix. When `None` low rank updates will
        take the form `scale_perturb_factor * scale_perturb_factor.T`.
      adjoint: Python `bool` indicating whether to use the `scale` matrix as
        specified or its adjoint.
        Default value: `False`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
      dtype: `tf.DType` to prefer when converting args to `Tensor`s. Else, we
        fall back to a common dtype inferred from the args, finally falling back
        to float32.

    Raises:
      ValueError: if `scale_perturb_diag` is specified but not
        `scale_perturb_factor`.
      TypeError: if `shift` has different `dtype` from `scale` arguments.
    """
    # Ambiguous definition of low rank update.
    if scale_perturb_diag is not None and scale_perturb_factor is None:
      raise ValueError("When scale_perturb_diag is specified, "
                       "scale_perturb_factor must be specified.")

    # Special case, only handling a scaled identity matrix. We don't know its
    # dimensions, so this is special cased.
    # We don't check identity_multiplier, since below we set it to 1. if all
    # other scale args are None.
    self._is_only_identity_multiplier = (scale_tril is None and
                                         scale_diag is None and
                                         scale_perturb_factor is None)

    with tf.name_scope(name) as name:
      self._name = name
      self._validate_args = validate_args

      if dtype is None:
        dtype = dtype_util.common_dtype([
            shift, scale_identity_multiplier, scale_diag, scale_tril,
            scale_perturb_diag, scale_perturb_factor
        ], tf.float32)

      if shift is not None:
        shift = tf.convert_to_tensor(shift, name="shift", dtype=dtype)
      self._shift = shift

      # When no args are specified, pretend the scale matrix is the identity
      # matrix.
      if (self._is_only_identity_multiplier and
          scale_identity_multiplier is None):
        scale_identity_multiplier = tf.convert_to_tensor(1., dtype=dtype)

      # self._create_scale_operator returns a LinearOperator in all cases
      # except if self._is_only_identity_multiplier; in which case it
      # returns a scalar Tensor.
      scale = self._create_scale_operator(
          identity_multiplier=scale_identity_multiplier,
          diag=scale_diag,
          tril=scale_tril,
          perturb_diag=scale_perturb_diag,
          perturb_factor=scale_perturb_factor,
          shift=shift,
          validate_args=validate_args,
          dtype=dtype)

      if (scale is not None and not self._is_only_identity_multiplier and
          not dtype_util.SKIP_DTYPE_CHECKS):
        if (shift is not None and
            not dtype_util.base_equal(shift.dtype, scale.dtype)):
          raise TypeError(
              "shift.dtype ({}) is incompatible with scale.dtype ({})."
              .format(shift.dtype, scale.dtype))

      self._scale = scale
      self._adjoint = adjoint
      super(Affine, self).__init__(
          forward_min_event_ndims=1,
          is_constant_jacobian=True,
          dtype=dtype,
          validate_args=validate_args,
          name=name)
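
Usage sketch, again assuming an older TFP release that still ships `tfp.bijectors.Affine`. With only `shift` and `scale_diag` set, the forward map reduces to `Y = scale_diag * X + shift`:

import tensorflow_probability as tfp

affine = tfp.bijectors.Affine(shift=[1., -1.], scale_diag=[2., 0.5])
print(affine.forward([1., 1.]))  # [2.*1. + 1., 0.5*1. - 1.] = [3., -0.5]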
Code example #7
File: special_math.py Project: xzxzmmnn/probability
def log_ndtr(x, series_order=3, name="log_ndtr"):
  """Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates `(log o ndtr)(x)` by either calling `log(ndtr(x))` or
  using an asymptotic series. Specifically:
  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    `log(1-x) ~= -x, x << 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
    and take a log.
  - For `x <= lower_segment`, we use the series approximation of erf to compute
    the log CDF directly.

  The `lower_segment` is set based on the precision of the input:

  ```
  lower_segment = { -20,  x.dtype=float64
                  { -10,  x.dtype=float32
  upper_segment = {   8,  x.dtype=float64
                  {   5,  x.dtype=float32
  ```

  When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

  ```
     ndtr(x) = scale * (1 + sum) + R_N
     scale   = exp(-0.5 x**2) / (-x sqrt(2 pi))
     sum     = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1:N}
     R_N     = O(exp(-0.5 x**2) (2N+1)!! / |x|^{2N+3})
  ```

  where `(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  [double-factorial](https://en.wikipedia.org/wiki/Double_factorial).

  Args:
    x: `Tensor` of type `float32`, `float64`.
    series_order: Positive Python `integer`. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.
    name: Python string. A name for the operation (default="log_ndtr").

  Returns:
    log_ndtr: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is not a Python `integer`.
    ValueError: if `series_order` is not in `[0, 30]`.
  """
  if not isinstance(series_order, int):
    raise TypeError("series_order must be a Python integer.")
  if series_order < 0:
    raise ValueError("series_order must be non-negative.")
  if series_order > 30:
    raise ValueError("series_order must be <= 30.")

  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name="x")

    if dtype_util.base_equal(x.dtype, tf.float64):
      lower_segment = np.array(LOGNDTR_FLOAT64_LOWER, np.float64)
      upper_segment = np.array(LOGNDTR_FLOAT64_UPPER, np.float64)
    elif dtype_util.base_equal(x.dtype, tf.float32):
      lower_segment = np.array(LOGNDTR_FLOAT32_LOWER, np.float32)
      upper_segment = np.array(LOGNDTR_FLOAT32_UPPER, np.float32)
    else:
      raise TypeError("x.dtype=%s is not supported." % x.dtype)

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # We copy the main idea, with a few changes
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #     which extends the range of validity of this function.
    # * We use one fixed series_order for all of 'x', rather than adaptive.
    # * Our docstring properly reflects that this is an asymptotic series, not a
    #   Taylor series. We also provided a correct bound on the remainder.
    # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
    #   x=0. This happens even though the branch is unchosen because when x=0
    #   the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
    #   regardless of whether dy is finite. Note that the minimum is a NOP if
    #   the branch is chosen.
    return tf.where(
        x > upper_segment,
        -_ndtr(-x),  # log(1-x) ~= -x, x << 1
        tf.where(
            x > lower_segment,
            tf.math.log(_ndtr(tf.maximum(x, lower_segment))),
            _log_ndtr_lower(tf.minimum(x, lower_segment), series_order)))
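
Usage sketch; the import path below follows the file header above and is an internal module, so treat it as an implementation detail rather than public API:

import tensorflow as tf
from tensorflow_probability.python.internal import special_math

x = tf.constant([-100., -1., 0., 6.], dtype=tf.float64)
print(special_math.log_ndtr(x))
# At x = -100 a naive log(ndtr(x)) underflows to log(0) = -inf, while the
# asymptotic-series branch returns a finite value near -5005.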
Code example #8
  def __init__(self,
               loc,
               scale,
               quadrature_size=8,
               quadrature_fn=quadrature_scheme_lognormal_quantiles,
               validate_args=False,
               allow_nan_stats=True,
               name="PoissonLogNormalQuadratureCompound"):
    """Constructs the `PoissonLogNormalQuadratureCompound`.

    Note: `probs` returned by (optional) `quadrature_fn` are presumed to be
    either a length-`quadrature_size` vector or a batch of vectors in 1-to-1
    correspondence with the returned `grid`. (I.e., broadcasting is only
    partially supported.)

    Args:
      loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
        the LogNormal prior.
      scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
        the LogNormal prior.
      quadrature_size: Python `int` scalar representing the number of quadrature
        points.
      quadrature_fn: Python callable taking `loc`, `scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the LogNormal grid and corresponding normalized weight.
        Default value: `quadrature_scheme_lognormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `quadrature_grid` and `quadrature_probs` have different base
        `dtype`.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale], tf.float32)
      if loc is not None:
        loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
      if scale is not None:
        scale = tf.convert_to_tensor(scale, dtype=dtype, name="scale")
      self._quadrature_grid, self._quadrature_probs = tuple(
          quadrature_fn(loc, scale, quadrature_size, validate_args))

      dt = self._quadrature_grid.dtype
      if not dtype_util.base_equal(dt, self._quadrature_probs.dtype):
        raise TypeError(
            "Quadrature grid dtype ({}) does not match quadrature "
            "probs dtype ({}).".format(
                dtype_util.name(dt),
                dtype_util.name(self._quadrature_probs.dtype)))

      self._distribution = poisson.Poisson(
          log_rate=self._quadrature_grid,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)

      self._mixture_distribution = categorical.Categorical(
          logits=tf.math.log(self._quadrature_probs),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)

      self._loc = loc
      self._scale = scale
      self._quadrature_size = quadrature_size

      super(PoissonLogNormalQuadratureCompound, self).__init__(
          dtype=dt,
          reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[loc, scale],
          name=name)
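
Usage sketch, assuming a TFP release that exposes `tfp.distributions.PoissonLogNormalQuadratureCompound` (the class this constructor belongs to):

import tensorflow_probability as tfp

pln = tfp.distributions.PoissonLogNormalQuadratureCompound(
    loc=0., scale=0.5, quadrature_size=10)
print(pln.mean())    # compound mean, approximated on the quadrature grid
print(pln.sample(3))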