def make_diag_scale(loc, scale_diag, scale_identity_multiplier,
                    validate_args, assert_positive, name=None):
  """Returns a diagonal-type `LinearOperator` built from the `scale_*` args.

  Which operator is produced:
    * `scale_diag` given: a `LinearOperatorDiag` whose diagonal is
      `scale_diag`, with `scale_identity_multiplier` (expanded by one trailing
      axis) added to it when that is also given.
    * else `scale_identity_multiplier` given: a `LinearOperatorScaledIdentity`
      whose size is `util.dimension_size(loc, -1)`.
    * else: a `LinearOperatorIdentity` of that same size, using the base
      dtype of `loc`.

  Args:
    loc: optional value convertible to a `Tensor`. Its last dimension sizes
      the operator whenever `scale_diag` is absent.
    scale_diag: optional value convertible to a `Tensor`; the diagonal.
    scale_identity_multiplier: optional value convertible to a `Tensor`.
    validate_args: Python `bool`. When `True`, a runtime assertion is attached
      to the diagonal (or multiplier): positivity if `assert_positive`,
      otherwise non-zero. Also forwarded as `assert_proper_shapes` to the
      identity-based operators.
    assert_positive: Python `bool`. Chooses the positivity assertion and is
      forwarded as `is_positive_definite` for the diag and scaled-identity
      operators.
    name: optional name scope; "make_diag_scale" is used when `None`.

  Returns:
    A `LinearOperatorDiag`, `LinearOperatorScaledIdentity`, or
    `LinearOperatorIdentity`.

  Raises:
    ValueError: if `scale_diag` is `None` and `loc` is `None`.
  """
  def _as_tensor_or_none(value, tensor_name):
    if value is None:
      return None
    return ops.convert_to_tensor(value, name=tensor_name)

  def _guarded(diag_part):
    # Without validation the tensor is passed through untouched.
    if not validate_args:
      return diag_part
    if assert_positive:
      check = check_ops.assert_positive(
          diag_part, message="diagonal part must be positive")
    else:
      check = check_ops.assert_none_equal(
          diag_part,
          array_ops.zeros([], diag_part.dtype),
          message="diagonal part must be non-zero")
    # Gate the returned tensor on the assertion op.
    return control_flow_ops.with_dependencies([check], diag_part)

  with ops.name_scope(name, "make_diag_scale",
                      values=[loc, scale_diag, scale_identity_multiplier]):
    loc = _as_tensor_or_none(loc, "loc")
    scale_diag = _as_tensor_or_none(scale_diag, "scale_diag")
    scale_identity_multiplier = _as_tensor_or_none(
        scale_identity_multiplier, "scale_identity_multiplier")

    if scale_diag is not None:
      if scale_identity_multiplier is not None:
        # A trailing axis lets the multiplier line up against the diagonal.
        scale_diag += scale_identity_multiplier[..., array_ops.newaxis]
      return linalg.LinearOperatorDiag(
          diag=_guarded(scale_diag),
          is_non_singular=True,
          is_self_adjoint=True,
          is_positive_definite=assert_positive)

    # TODO(b/35290280): Consider inferring shape from scale_perturb_factor.
    if loc is None:
      raise ValueError(
          "Cannot infer `event_shape` unless `loc` is specified.")
    num_rows = util.dimension_size(loc, -1)

    if scale_identity_multiplier is not None:
      return linalg.LinearOperatorScaledIdentity(
          num_rows=num_rows,
          multiplier=_guarded(scale_identity_multiplier),
          is_non_singular=True,
          is_self_adjoint=True,
          is_positive_definite=assert_positive,
          assert_proper_shapes=validate_args)

    return linalg.LinearOperatorIdentity(
        num_rows=num_rows,
        dtype=loc.dtype.base_dtype,
        is_self_adjoint=True,
        is_positive_definite=True,
        assert_proper_shapes=validate_args)
Beispiel #2
0
    def __init__(self,
                 loc=None,
                 scale_tril=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MultivariateNormalTriL"):
        """Construct Multivariate Normal distribution on `R^k`.

        The `batch_shape` is the broadcast shape between `loc` and `scale`
        arguments.

        The `event_shape` is given by last dimension of the matrix implied by
        `scale`. The last dimension of `loc` (if provided) must broadcast with
        this.

        Recall that `covariance = scale @ scale.T`. A (non-batch) `scale`
        matrix is:

        ```none
        scale = scale_tril
        ```

        where `scale_tril` is lower-triangular `k x k` matrix with non-zero
        diagonal, i.e., `tf.diag_part(scale_tril) != 0`.

        Additional leading dimensions (if any) will index batches.

        Args:
          loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
            implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]`
            where `b >= 0` and `k` is the event size.
          scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
            diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]`
            where `b >= 0` and `k` is the event size. When `None`, an identity
            scale sized by the last dimension of `loc` is used.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: if neither `loc` nor `scale_tril` are specified.
        """
        # Snapshot the arguments with an explicit copy. In CPython (< 3.13)
        # `locals()` inside a function returns the frame's own locals mapping,
        # which can be refreshed in place later (another `locals()` call, or a
        # tracer/debugger syncing the frame). Storing it directly would let
        # converted tensors and helper functions leak into `_parameters`.
        parameters = dict(locals())

        def _convert_to_tensor(x, name):
            return None if x is None else ops.convert_to_tensor(x, name=name)

        if loc is None and scale_tril is None:
            raise ValueError(
                "Must specify one or both of `loc`, `scale_tril`.")
        with ops.name_scope(name):
            with ops.name_scope("init", values=[loc, scale_tril]):
                loc = _convert_to_tensor(loc, name="loc")
                scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
                if scale_tril is None:
                    scale = linalg.LinearOperatorIdentity(
                        num_rows=distribution_util.dimension_size(loc, -1),
                        dtype=loc.dtype,
                        is_self_adjoint=True,
                        is_positive_definite=True,
                        assert_proper_shapes=validate_args)
                else:
                    # No need to validate that scale_tril is non-singular.
                    # LinearOperatorLowerTriangular has an assert_non_singular
                    # method that is called by the Bijector.
                    scale = linalg.LinearOperatorLowerTriangular(
                        scale_tril,
                        is_non_singular=True,
                        is_self_adjoint=False,
                        is_positive_definite=False)
        super(MultivariateNormalTriL,
              self).__init__(loc=loc,
                             scale=scale,
                             validate_args=validate_args,
                             allow_nan_stats=allow_nan_stats,
                             name=name)
        self._parameters = parameters
Beispiel #3
0
  def __init__(self,
               loc=None,
               scale_tril=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalTriL"):
    """Initializes a multivariate normal on `R^k` parameterized by `scale_tril`.

    The covariance is `scale @ scale.T`. Ignoring batch dimensions,
    `scale = scale_tril`, a `k x k` lower-triangular matrix with a non-zero
    diagonal (`tf.diag_part(scale_tril) != 0`). Any additional leading
    dimensions index batches.

    `batch_shape` is the broadcast of the batch shapes of `loc` and `scale`.
    `event_shape` is the last dimension of the matrix implied by `scale`, and
    the last dimension of `loc` (when given) must broadcast with it.

    Args:
      loc: optional floating-point `Tensor` of shape `[B1, ..., Bb, k]`, with
        `b >= 0` and `k` the event size. `None` means an implicit mean of `0`.
      scale_tril: optional floating-point, lower-triangular `Tensor` of shape
        `[B1, ..., Bb, k, k]` (`b >= 0`) with non-zero diagonal elements.
        When `None`, an identity scale sized by the last dimension of `loc`
        is used.
      validate_args: Python `bool`, default `False`. If `True`, parameters are
        checked for validity at a possible runtime cost; if `False`, invalid
        inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
        statistics (e.g., mean, mode, variance) are reported as "`NaN`"; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to ops created by this class.

    Raises:
      ValueError: if both `loc` and `scale_tril` are `None`.
    """
    # Must stay the first statement so only the constructor arguments are
    # captured.
    parameters = dict(locals())
    if scale_tril is None and loc is None:
      raise ValueError("Must specify one or both of `loc`, `scale_tril`.")
    with ops.name_scope(name) as scope_name:
      with ops.name_scope("init", values=[loc, scale_tril]):
        if loc is not None:
          loc = ops.convert_to_tensor(loc, name="loc")
        if scale_tril is None:
          scale = linalg.LinearOperatorIdentity(
              num_rows=distribution_util.dimension_size(loc, -1),
              dtype=loc.dtype,
              is_self_adjoint=True,
              is_positive_definite=True,
              assert_proper_shapes=validate_args)
        else:
          scale_tril = ops.convert_to_tensor(scale_tril, name="scale_tril")
          # Non-singularity is not checked here: LinearOperatorLowerTriangular
          # provides assert_non_singular, which the Bijector invokes.
          scale = linalg.LinearOperatorLowerTriangular(
              scale_tril,
              is_non_singular=True,
              is_self_adjoint=False,
              is_positive_definite=False)
    super(MultivariateNormalTriL, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=scope_name)
    self._parameters = parameters
Beispiel #4
0
def make_diag_scale(loc,
                    scale_diag,
                    scale_identity_multiplier,
                    validate_args,
                    assert_positive,
                    name=None):
    """Returns a diagonal-type `LinearOperator` built from the `scale_*` args.

    The result depends on which arguments are provided:
      * `scale_diag`: `LinearOperatorDiag` over `scale_diag`, with
        `scale_identity_multiplier` (given one extra trailing axis) added to
        the diagonal when present.
      * only `scale_identity_multiplier`: `LinearOperatorScaledIdentity`
        sized by `util.dimension_size(loc, -1)`.
      * neither: `LinearOperatorIdentity` of that size, using the base dtype
        of `loc`.

    Args:
      loc: optional value convertible to a `Tensor`. Its last dimension sizes
        the operator when `scale_diag` is absent.
      scale_diag: optional value convertible to a `Tensor`; the diagonal.
      scale_identity_multiplier: optional value convertible to a `Tensor`.
      validate_args: Python `bool`. If `True`, a runtime assertion (positive
        when `assert_positive`, else non-zero) is attached to the diagonal or
        multiplier. Also forwarded as `assert_proper_shapes` to the
        identity-based operators.
      assert_positive: Python `bool`. Selects the positivity assertion and is
        forwarded as `is_positive_definite` for the diag and scaled-identity
        operators.
      name: optional name scope; defaults to "make_diag_scale".

    Returns:
      A `LinearOperatorDiag`, `LinearOperatorScaledIdentity`, or
      `LinearOperatorIdentity`.

    Raises:
      ValueError: if `scale_diag` is `None` and `loc` is `None`.
    """

    def _to_tensor(value, label):
        return (None if value is None
                else ops.convert_to_tensor(value, name=label))

    def _with_check(tensor):
        if not validate_args:
            return tensor
        if assert_positive:
            assertion = check_ops.assert_positive(
                tensor, message="diagonal part must be positive")
        else:
            zero = array_ops.zeros([], tensor.dtype)
            assertion = check_ops.assert_none_equal(
                tensor, zero, message="diagonal part must be non-zero")
        # The returned tensor carries a control dependency on the assertion.
        return control_flow_ops.with_dependencies([assertion], tensor)

    with ops.name_scope(name,
                        "make_diag_scale",
                        values=[loc, scale_diag, scale_identity_multiplier]):
        loc, scale_diag, scale_identity_multiplier = (
            _to_tensor(loc, "loc"),
            _to_tensor(scale_diag, "scale_diag"),
            _to_tensor(scale_identity_multiplier, "scale_identity_multiplier"),
        )
        have_multiplier = scale_identity_multiplier is not None

        if scale_diag is not None:
            if have_multiplier:
                # Expand so the multiplier lines up against the diagonal axis.
                scale_diag += scale_identity_multiplier[..., array_ops.newaxis]
            return linalg.LinearOperatorDiag(
                diag=_with_check(scale_diag),
                is_non_singular=True,
                is_self_adjoint=True,
                is_positive_definite=assert_positive)

        # TODO(b/35290280): Consider inferring shape from scale_perturb_factor.
        if loc is None:
            raise ValueError(
                "Cannot infer `event_shape` unless `loc` is specified.")
        size = util.dimension_size(loc, -1)

        if have_multiplier:
            return linalg.LinearOperatorScaledIdentity(
                num_rows=size,
                multiplier=_with_check(scale_identity_multiplier),
                is_non_singular=True,
                is_self_adjoint=True,
                is_positive_definite=assert_positive,
                assert_proper_shapes=validate_args)

        return linalg.LinearOperatorIdentity(
            num_rows=size,
            dtype=loc.dtype.base_dtype,
            is_self_adjoint=True,
            is_positive_definite=True,
            assert_proper_shapes=validate_args)