  def _default_event_space_bijector(self):
    # Chain applies right to left: Softplus maps R^k onto the positive
    # orthant, then the affine map `loc + scale @ y` carries it onto the
    # distribution's support.
    return chain_bijector.Chain([
        shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args),
        scale_matvec_linear_operator.ScaleMatvecLinearOperator(
            scale=self.scale, validate_args=self.validate_args),
        softplus_bijector.Softplus(validate_args=self.validate_args)
    ], validate_args=self.validate_args)
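The `Chain` above composes right to left, which is easy to sanity-check with the public `tfp.bijectors` aliases. A minimal sketch, using scalar `Shift`/`Scale` stand-ins with arbitrary values rather than the distribution's actual `loc`/`scale`:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Chain([f, g, h]).forward(x) == f(g(h(x))): Softplus runs first, then the
# scale, then the shift.
chain = tfb.Chain([tfb.Shift(1.), tfb.Scale(2.), tfb.Softplus()])
x = tf.constant(0.)
# softplus(0) = log(2), so we expect 1 + 2 * log(2) ~= 2.386.
print(chain.forward(x))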
Example 2
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='MultivariateNormalLinearOperator'):
        """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc.) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if `scale` does not have a floating-point `dtype`.
    """
        parameters = dict(locals())
        if scale is None:
            raise ValueError('Missing required `scale` parameter.')
        if not dtype_util.is_floating(scale.dtype):
            raise TypeError(
                '`scale` parameter must have floating-point dtype.')

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            dtype_hint=tf.float32)
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            loc = tensor_util.convert_nonref_to_tensor(loc,
                                                       dtype=dtype,
                                                       name='loc')
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale)
        self._loc = loc
        self._scale = scale

        bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
            scale, validate_args=validate_args)
        if loc is not None:
            bijector = shift_bijector.Shift(
                shift=loc, validate_args=validate_args)(bijector)

        super(MultivariateNormalLinearOperator, self).__init__(
            distribution=normal.Normal(loc=tf.zeros([], dtype=dtype),
                                       scale=tf.ones([], dtype=dtype)),
            bijector=bijector,
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
            name=name)
        self._parameters = parameters
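A short usage sketch for the class above; `tfd` and `tf.linalg.LinearOperatorLowerTriangular` are the standard public names, and the numbers are arbitrary:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# 2-d MVN with covariance scale @ scale.T, scale lower triangular.
scale = tf.linalg.LinearOperatorLowerTriangular([[1., 0.], [0.5, 2.]])
mvn = tfd.MultivariateNormalLinearOperator(loc=[1., -1.], scale=scale)
print(mvn.event_shape)         # [2]
print(mvn.covariance())        # == scale @ scale.T as a dense matrix
print(mvn.log_prob([0., 0.]))  # scalar log-density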
Example 3
  def __init__(self,
               loc=None,
               scale=None,
               validate_args=False,
               allow_nan_stats=True,
               experimental_use_kahan_sum=False,
               name='MultivariateNormalLinearOperator'):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc.) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
        summation to aggregate independent underlying log_prob values. For best
        results, Kahan summation should also be applied when computing the
        log-determinant of the `LinearOperator` representing the scale matrix.
        Kahan summation improves on the precision of a naive float32 sum.
        This can be noticeable in particular for large dimensions in float32.
        See CPU caveat on `tfp.math.reduce_kahan_sum`.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if `scale` does not have a floating-point `dtype`.
    """
    parameters = dict(locals())
    self._experimental_use_kahan_sum = experimental_use_kahan_sum
    if scale is None:
      raise ValueError('Missing required `scale` parameter.')
    if not dtype_util.is_floating(scale.dtype):
      raise TypeError('`scale` parameter must have floating-point dtype.')

    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
      # Since expand_dims doesn't preserve constant-ness, we obtain the
      # non-dynamic value if possible.
      loc = tensor_util.convert_nonref_to_tensor(
          loc, dtype=dtype, name='loc')
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale)
    self._loc = loc
    self._scale = scale

    bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
        scale, validate_args=validate_args)
    if loc is not None:
      bijector = shift_bijector.Shift(
          shift=loc, validate_args=validate_args)(bijector)
    super(MultivariateNormalLinearOperator, self).__init__(
        # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
        # shape instead of tf.zeros.
        # We use `Sample` instead of `Independent` because `Independent`
        # requires concatenating `batch_shape` and `event_shape`, which loses
        # static `batch_shape` information when `event_shape` is not statically
        # known.
        distribution=sample.Sample(
            normal.Normal(
                loc=tf.zeros(batch_shape, dtype=dtype),
                scale=tf.ones([], dtype=dtype)),
            event_shape,
            experimental_use_kahan_sum=experimental_use_kahan_sum),
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
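The `experimental_use_kahan_sum` flag relies on compensated summation; `tfp.math.reduce_kahan_sum` is the public helper the docstring points to. A minimal sketch of the effect (values chosen so float32 rounding is visible; the exact printed values depend on the backend's reduction order, per the CPU caveat above):

import tensorflow as tf
import tensorflow_probability as tfp

# Many small terms behind one large term: a naive float32 sum drops bits.
x = tf.concat([[1e8], tf.fill([10000], 0.01)], axis=0)
naive = tf.reduce_sum(x)
kahan = tfp.math.reduce_kahan_sum(x, axis=-1)  # carries .total and .correction
exact = tf.reduce_sum(tf.cast(x, tf.float64))  # float64 reference
print(naive.numpy(), kahan.total.numpy(), exact.numpy())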
Example 4
  def __init__(self,
               loc=None,
               precision_factor=None,
               precision=None,
               validate_args=False,
               allow_nan_stats=True,
               name='MultivariateNormalPrecisionFactorLinearOperator'):
    """Initialize distribution.

    Precision is the inverse of the covariance matrix, and
    `precision_factor @ precision_factor.T = precision`.

    The `batch_shape` of this distribution is the broadcast of
    `loc.shape[:-1]` and `precision_factor.batch_shape`.

    The `event_shape` of this distribution is determined by `loc.shape[-1:]`
    or `precision_factor.shape[-1:]`, which must match.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      precision_factor: Required nonsingular `tf.linalg.LinearOperator` instance
        with same `dtype` and shape compatible with `loc`.
      precision: Optional square `tf.linalg.LinearOperator` instance with same
        `dtype` and shape compatible with `loc` and `precision_factor`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc.) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      if precision_factor is None:
        raise ValueError(
            'Argument `precision_factor` must be provided. Found `None`')

      dtype = dtype_util.common_dtype([loc, precision_factor, precision],
                                      dtype_hint=tf.float32)
      loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')

      self._loc = loc
      self._precision_factor = precision_factor
      self._precision = precision

      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, precision_factor)

      # Proof of factors (used throughout code):
      # Let,
      #   C = covariance,
      #   P = inv(covariance) = precision
      #   P = F @ F.T  (so F is the `precision_factor`).
      #
      # Then, the log prob term is
      #  x.T @ inv(C) @ x
      #  = x.T @ P @ x
      #  = x.T @ F @ F.T @ x
      #  = || F.T @ x ||**2
      # notice it involves F.T, which is why we set adjoint=True in various
      # places.
      #
      # Also, if w ~ Normal(0, I), then we can sample by setting
      #  x = inv(F.T) @ w + loc,
      # since then
      #  E[(x - loc) @ (x - loc).T]
      #  = E[inv(F.T) @ w @ w.T @ inv(F)]
      #  = inv(F.T) @ inv(F)
      #  = inv(F @ F.T)
      #  = inv(P)
      #  = C.

      if precision is not None:
        precision.shape.assert_is_compatible_with(precision_factor.shape)

      bijector = invert.Invert(
          scale_matvec_linear_operator.ScaleMatvecLinearOperator(
              scale=precision_factor,
              validate_args=validate_args,
              adjoint=True)
      )
      if loc is not None:
        shift = shift_bijector.Shift(shift=loc, validate_args=validate_args)
        bijector = shift(bijector)

      super(MultivariateNormalPrecisionFactorLinearOperator, self).__init__(
          distribution=mvn_diag.MultivariateNormalDiag(
              loc=tf.zeros(
                  ps.concat([batch_shape, event_shape], axis=0), dtype=dtype)),
          bijector=bijector,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
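The identities in the "Proof of factors" comment above can be spot-checked numerically. A NumPy sketch (the random nonsingular factor `F` is an arbitrary choice for illustration):

import numpy as np

rng = np.random.default_rng(0)
k = 3
F = rng.normal(size=(k, k)) + 3. * np.eye(k)  # nonsingular precision factor
P = F @ F.T                                   # precision
C = np.linalg.inv(P)                          # covariance
x = rng.normal(size=k)

# Log-prob quadratic form: x.T @ P @ x == ||F.T @ x||**2.
assert np.allclose(x @ P @ x, np.sum((F.T @ x) ** 2))

# Sampling: inv(F.T) @ w has covariance inv(F.T) @ inv(F) = inv(F @ F.T) = C.
assert np.allclose(np.linalg.inv(F.T) @ np.linalg.inv(F), C)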
Example 5
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VectorExponentialLinearOperator'):
        """Construct Vector Exponential distribution supported on a subset of `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc.) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if `scale` does not have a floating-point `dtype`.
    """
        parameters = dict(locals())
        if loc is None:
            loc = 0.0  # Implicit value for backwards compatibility.
        if scale is None:
            raise ValueError('Missing required `scale` parameter.')
        if not dtype_util.is_floating(scale.dtype):
            raise TypeError(
                '`scale` parameter must have floating-point dtype.')

        with tf.name_scope(name) as name:
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            # `loc` was defaulted to 0.0 above, so it is never `None` here.
            loc = tf.convert_to_tensor(loc, name='loc', dtype=scale.dtype)
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale)
            self._loc = loc
            self._scale = scale
            super(VectorExponentialLinearOperator, self).__init__(
                # TODO(b/137665504): Use batch-adding meta-distribution to set the
                # batch shape instead of tf.ones.
                # We use `Sample` instead of `Independent` because `Independent`
                # requires concatenating `batch_shape` and `event_shape`, which loses
                # static `batch_shape` information when `event_shape` is not
                # statically known.
                distribution=sample.Sample(
                    exponential.Exponential(rate=tf.ones(batch_shape,
                                                         dtype=scale.dtype),
                                            allow_nan_stats=allow_nan_stats),
                    event_shape),
                bijector=shift_bijector.Shift(shift=loc)(
                    scale_matvec_linear_operator.ScaleMatvecLinearOperator(
                        scale=scale, validate_args=validate_args)),
                validate_args=validate_args,
                name=name)
            self._parameters = parameters
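The `super().__init__` call above is the standard transformed-distribution pattern, and it can be written out directly with public TFP names. A minimal sketch (the diagonal scale, unit rate, and loc values are arbitrary):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

k = 2
scale = tf.linalg.LinearOperatorDiag([1., 2.])
loc = tf.constant([0., -1.])

# Shift(loc) composed with ScaleMatvec, applied to k iid Exponential draws,
# mirroring the construction above.
vex = tfd.TransformedDistribution(
    distribution=tfd.Sample(tfd.Exponential(rate=tf.ones([])), k),
    bijector=tfb.Shift(loc)(tfb.ScaleMatvecLinearOperator(scale)))
print(vex.sample(3).shape)  # (3, 2); support is loc + scale @ R_+^k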
Example 6
  def _forward(self, x):
    # `smlo` here is presumably an alias for the scale_matvec_linear_operator
    # bijector module. Reflect `x` across the hyperplane orthogonal to
    # `self.reflection_axis`.
    scale = tf.linalg.LinearOperatorHouseholder(self.reflection_axis)
    reflection = smlo.ScaleMatvecLinearOperator(scale)
    return reflection.forward(x)
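A Householder reflection is an involution, so `_forward` applied twice recovers its input. A quick check against `tf.linalg.LinearOperatorHouseholder` directly (the axis values are arbitrary):

import tensorflow as tf

axis = tf.constant([3., 4.]) / 5.  # unit reflection axis
op = tf.linalg.LinearOperatorHouseholder(axis)
x = tf.constant([1., 2.])
y = op.matvec(x)     # reflect once
print(op.matvec(y))  # reflect again: ~[1., 2.], back to x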