Example #1
0
 def test_non_positive_definite_matrix_raises(self):
   # Square-root diagonal [1.0, 0.0] yields a singular matrix (one positive and
   # one zero eigenvalue), so it is not positive definite.  With the default
   # constructor arguments, evaluating the dense matrix must trip the
   # 'assert_positive' check.
   sqrt_diag_entries = [1.0, 0.0]
   with self.test_session():
     sqrt_diag_op = operator_pd_diag.OperatorPDSqrtDiag(sqrt_diag_entries)
     with self.assertRaisesOpError('assert_positive'):
       dense = sqrt_diag_op.to_dense()
       dense.eval()
Example #2
0
 def testNonPositiveDefiniteMatrixDoesNotRaiseIfNotVerifyPd(self):
     # A singular square-root diagonal (eigenvalues 1 and 0) is not positive
     # definite, but with verify_pd=False evaluating its dense form must not
     # raise.
     sqrt_diag_entries = [1.0, 0.0]
     with self.test_session():
         sqrt_diag_op = operator_pd_diag.OperatorPDSqrtDiag(
             sqrt_diag_entries, verify_pd=False)
         dense = sqrt_diag_op.to_dense()
         dense.eval()  # Must complete without raising.
 def test_non_positive_definite_matrix_does_not_raise_if_not_verify_pd(self):
     # The square-root diagonal [1.0, 0.0] gives a singular matrix (eigenvalues
     # 1 and 0).  Because verify_pd=False, building and evaluating the dense
     # matrix must succeed.
     # NOTE(review): same scenario as
     # testNonPositiveDefiniteMatrixDoesNotRaiseIfNotVerifyPd; keep only one
     # if both end up in the same test class.
     with self.test_session():
         op = operator_pd_diag.OperatorPDSqrtDiag([1.0, 0.0], verify_pd=False)
         op.to_dense().eval()  # Should not raise.
    def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
        """Build an `OperatorPDSqrtDiag` together with its reference matrix.

        One random diagonal of shape `batch_shape + [k]` is drawn.  The same
        values are given both to `OperatorPDSqrtDiag`, which treats them as a
        square root, and to `self._diag_to_matrix`, whose output is the
        explicit matrix the operator is expected to agree with.
        NOTE(review): `_diag_to_matrix` is defined elsewhere; presumably it
        accounts for the square-root convention — confirm.

        Args:
          batch_shape: Iterable of ints giving the leading batch dimensions.
          k: Size of the trailing (diagonal) dimension.
          dtype: NumPy dtype that both the diagonal and the matrix are cast to.

        Returns:
          `(operator, mat)` pair.
        """
        diag_shape = list(batch_shape) + [k]
        # These values are the *square root* fed to the operator.
        sqrt_diag = self._random_pd_diag(diag_shape).astype(dtype)
        mat = self._diag_to_matrix(sqrt_diag).astype(dtype)
        return operator_pd_diag.OperatorPDSqrtDiag(sqrt_diag), mat
Example #5
0
    def __init__(self,
                 mu,
                 diag_stddev,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MultivariateNormalDiag"):
        """Multivariate Normal distributions on `R^k`.

        User must provide means `mu` and standard deviations `diag_stddev`.
        Each batch member represents a random vector `(X_1,...,X_k)` of
        independent random normals.
        The mean of `X_i` is `mu[i]`, and the standard deviation is
        `diag_stddev[i]`.

        Args:
          mu: Floating point `Tensor` with shape `[N1,...,Nb, k]`, `b >= 0`
            (i.e. rank `b + 1`).
          diag_stddev: `Tensor` with same `dtype` and shape as `mu`,
            representing the standard deviations.  Must be positive.
          validate_args: `Boolean`, default `False`.  Whether to validate
            input with asserts.  If `validate_args` is `False`,
            and the inputs are invalid, correct behavior is not guaranteed.
            Also forwarded as `verify_pd` to the covariance operator.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this
            statistic.
          name: The name to give Ops created by the initializer.

        Raises:
          TypeError: If `mu` and `diag_stddev` are different dtypes
            (presumably checked by the base-class constructor).
        """
        # Snapshot the constructor arguments.  `dict(...)` makes an
        # independent copy: on CPython < 3.13 `locals()` returns the frame's
        # own locals dict, which can be re-synced later (e.g. under a tracer or
        # debugger) and would then re-acquire `self`, `parameters`, `ns`, `cov`.
        parameters = dict(locals())
        parameters.pop("self")
        with ops.name_scope(name, values=[diag_stddev]) as ns:
            # `OperatorPDSqrtDiag` is parameterized by the square root of the
            # covariance diagonal, i.e. the standard deviations themselves.
            cov = operator_pd_diag.OperatorPDSqrtDiag(diag_stddev,
                                                      verify_pd=validate_args)
        super(MultivariateNormalDiag, self).__init__(
            mu,
            cov,
            allow_nan_stats=allow_nan_stats,
            validate_args=validate_args,
            name=ns)
        self._parameters = parameters
    def _create_scale_operator(self, identity_multiplier, diag, tril,
                               perturb_diag, perturb_factor, event_ndims,
                               validate_args):
        """Construct `scale` from various components.

        The base matrix comes from the most general component supplied, checked
        in the order `tril`, `diag`, `identity_multiplier`.  If
        `perturb_factor` is given, a low-rank update `V D V^T` built from
        `perturb_factor` (`V`) and `perturb_diag` (`D`) is applied on top.

        Args:
          identity_multiplier: Floating point rank 0 `Tensor` representing a
            scaling done to the identity matrix.  When `tril` or `diag` is
            given, it is handed to `_preprocess_tril` / `_preprocess_diag`
            together with them (presumably combined there — not visible here).
          diag: Floating-point `Tensor` representing the diagonal matrix.
            `diag` has shape `[N1, N2, ..., k]`, which represents a `k x k`
            diagonal matrix.  When `tril` is also given, `diag` is passed to
            `_preprocess_tril` rather than used as the base matrix.
          tril: Floating-point `Tensor` representing the lower triangular
            matrix.  `tril` has shape `[N1, N2, ..., k, k]`, whose trailing two
            dimensions form a `k x k` lower triangular matrix.
          perturb_diag: Floating-point `Tensor` representing the diagonal
            matrix `D` of the low rank update.  Ignored unless
            `perturb_factor` is also given.
          perturb_factor: Floating-point `Tensor` representing the factor
            matrix `V` of the low rank update `V D V^T`.
          event_ndims: Scalar `int32` `Tensor` indicating the number of
            dimensions associated with a particular draw from the
            distribution. Must be 0 or 1.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.

        Returns:
          scale. In the case of scaling by a constant (only
          `identity_multiplier` given and no `perturb_factor`), scale is a
          floating point `Tensor`. Otherwise, scale is an `OperatorPD`.

        Raises:
          ValueError: if all of `tril`, `diag` and `identity_multiplier` are
            `None`.
        """
        identity_multiplier = _as_tensor(identity_multiplier,
                                         "identity_multiplier")
        diag = _as_tensor(diag, "diag")
        tril = _as_tensor(tril, "tril")
        perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
        perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

        identity_multiplier = self._maybe_validate_identity_multiplier(
            identity_multiplier, validate_args)

        if perturb_factor is not None:
            perturb_factor = self._process_matrix(perturb_factor,
                                                  min_rank=2,
                                                  event_ndims=event_ndims)

        if perturb_diag is not None:
            perturb_diag = self._process_matrix(perturb_diag,
                                                min_rank=1,
                                                event_ndims=event_ndims)

        # The following if-statements are ordered by increasingly stronger
        # assumptions on the base matrix, i.e., we process in the order:
        # TriL, Diag, Identity.

        if tril is not None:
            tril = self._preprocess_tril(identity_multiplier, diag, tril,
                                         event_ndims)
            if perturb_factor is None:
                return operator_pd_cholesky.OperatorPDCholesky(
                    tril, verify_pd=validate_args)
            return _TriLPlusVDVTLightweightOperatorPD(
                tril=tril,
                v=perturb_factor,
                diag=perturb_diag,
                validate_args=validate_args)

        if diag is not None:
            diag = self._preprocess_diag(identity_multiplier, diag,
                                         event_ndims)
            if perturb_factor is None:
                return operator_pd_diag.OperatorPDSqrtDiag(
                    diag, verify_pd=validate_args)
            # The update operator takes the base matrix itself (not its
            # square root), hence OperatorPDDiag rather than OperatorPDSqrtDiag.
            return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator=operator_pd_diag.OperatorPDDiag(
                    diag, verify_pd=validate_args),
                v=perturb_factor,
                diag=perturb_diag,
                verify_pd=validate_args)

        if identity_multiplier is not None:
            if perturb_factor is None:
                return identity_multiplier
            # Infer the shape from the V and D: drop V's last dimension and
            # repeat its second-to-last one, giving a square identity whose
            # leading dimensions match V's.
            v_shape = array_ops.shape(perturb_factor)
            identity_shape = array_ops.concat([v_shape[:-1], [v_shape[-2]]], 0)
            scaled_identity = operator_pd_identity.OperatorPDIdentity(
                identity_shape,
                perturb_factor.dtype.base_dtype,
                scale=identity_multiplier,
                verify_pd=validate_args)
            return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
                operator=scaled_identity,
                v=perturb_factor,
                diag=perturb_diag,
                verify_pd=validate_args)

        raise ValueError(
            "One of tril, diag and/or identity_multiplier must be "
            "specified.")