def test_non_positive_definite_matrix_raises(self):
    """A singular `OperatorPDSqrtDiag` fails its positivity assert on eval."""
    # Square-root diagonal [1, 0]: the represented matrix is singular, with
    # one positive eigenvalue and one zero eigenvalue, so it is not PD.
    sqrt_diag = [1.0, 0.0]
    with self.test_session():
        singular_op = operator_pd_diag.OperatorPDSqrtDiag(sqrt_diag)
        # No `verify_pd` argument is passed, so evaluation is expected to
        # trip the 'assert_positive' check.
        with self.assertRaisesOpError('assert_positive'):
            singular_op.to_dense().eval()
def testNonPositiveDefiniteMatrixDoesNotRaiseIfNotVerifyPd(self):
    """With `verify_pd=False`, a singular operator densifies without error."""
    # Diagonal [1, 0] describes a singular matrix (one positive and one zero
    # eigenvalue); the positive-definiteness check is turned off here.
    with self.test_session():
        unchecked_op = operator_pd_diag.OperatorPDSqrtDiag(
            [1.0, 0.0], verify_pd=False)
        unchecked_op.to_dense().eval()  # Must not raise.
def test_non_positive_definite_matrix_does_not_raise_if_not_verify_pd(
        self):
    """Disabling `verify_pd` lets a singular sqrt-diag operator evaluate."""
    # A singular matrix: one positive eigenvalue and one zero eigenvalue.
    singular_sqrt_diag = [1.0, 0.0]
    with self.test_session():
        op = operator_pd_diag.OperatorPDSqrtDiag(
            singular_sqrt_diag, verify_pd=False)
        # Evaluating the dense form is expected to succeed silently.
        op.to_dense().eval()
def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
    """Build an `OperatorPDSqrtDiag` and a dense reference matrix.

    A random diagonal is drawn once and used for both the explicit matrix
    and the operator, so the operator should behave like the matrix.

    Args:
      batch_shape: Iterable of batch dimensions; it is copied into a list.
      k: Length of the diagonal (the trailing dimension).
      dtype: NumPy dtype both the diagonal and the matrix are cast to.

    Returns:
      A tuple `(operator, mat)`. `mat` comes from `self._diag_to_matrix`
      applied to the same diagonal (presumably the dense matrix the operator
      represents -- confirm against that helper).
    """
    # The sampled diagonal is the square root, as `OperatorPDSqrtDiag`
    # expects.
    shape_with_k = list(batch_shape) + [k]
    sqrt_diag = self._random_pd_diag(shape_with_k).astype(dtype)
    dense = self._diag_to_matrix(sqrt_diag).astype(dtype)
    return operator_pd_diag.OperatorPDSqrtDiag(sqrt_diag), dense
def __init__(self,
             mu,
             diag_stddev,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalDiag"):
    """Multivariate Normal distributions on `R^k` with diagonal scale.

    The caller supplies means `mu` and standard deviations `diag_stddev`.
    Each batch member represents a random vector `(X_1, ..., X_k)` of
    independent normal random variables, where `X_i` has mean `mu[i]` and
    standard deviation `diag_stddev[i]`.

    Args:
      mu: Rank `b + 1` floating point `Tensor` with shape
        `[N1, ..., Nb, k]`, `b >= 0`.
      diag_stddev: Rank `b + 1` `Tensor` with the same `dtype` and shape as
        `mu`, holding the standard deviations. Must be positive.
      validate_args: `Boolean`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False` and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for
        any batch member. If `True`, batch members with valid parameters
        leading to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: If `mu` and `diag_stddev` are different dtypes.
    """
    # Snapshot the constructor arguments before any other local is bound.
    init_args = locals()
    del init_args["self"]
    with ops.name_scope(name, values=[diag_stddev]) as scope_name:
        # `diag_stddev` is the square-root diagonal of the covariance
        # operator; `validate_args` also drives its PD verification.
        sqrt_cov = operator_pd_diag.OperatorPDSqrtDiag(
            diag_stddev, verify_pd=validate_args)
    super(MultivariateNormalDiag, self).__init__(
        mu,
        sqrt_cov,
        allow_nan_stats=allow_nan_stats,
        validate_args=validate_args,
        name=scope_name)
    # Set after the parent initializer so this subclass's arguments win.
    self._parameters = init_args
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, event_ndims,
                           validate_args):
    """Construct `scale` from various components.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
        diagonal matrix.
      tril: Floating-point `Tensor` representing the lower triangular matrix.
        `scale_tril` has shape [N1, N2, ..., k, k], which represents a k x k
        lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix
        `D` of the low rank update `V D V^T`. Ignored when `perturb_factor`
        is `None`.
      perturb_factor: Floating-point `Tensor` representing the factor matrix
        `V` of the low rank update `V D V^T`.
      event_ndims: Scalar `int32` `Tensor` indicating the number of
        dimensions associated with a particular draw from the distribution.
        Must be 0 or 1.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. When only `identity_multiplier` is given (no `tril`, `diag` or
      `perturb_factor`), scale is that floating point `Tensor`. Otherwise,
      scale is an `OperatorPD`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`.
    """
    identity_multiplier = _as_tensor(identity_multiplier,
                                     "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    identity_multiplier = self._maybe_validate_identity_multiplier(
        identity_multiplier, validate_args)

    if perturb_factor is not None:
        perturb_factor = self._process_matrix(
            perturb_factor, min_rank=2, event_ndims=event_ndims)

    if perturb_diag is not None:
        perturb_diag = self._process_matrix(
            perturb_diag, min_rank=1, event_ndims=event_ndims)

    # The following if-statements are ordered by increasingly stronger
    # assumptions on the base matrix, i.e., we process in the order:
    # TriL, Diag, Identity. Note that `perturb_diag` is only consumed
    # together with `perturb_factor`.

    if tril is not None:
        tril = self._preprocess_tril(
            identity_multiplier, diag, tril, event_ndims)
        if perturb_factor is None:
            return operator_pd_cholesky.OperatorPDCholesky(
                tril, verify_pd=validate_args)
        return _TriLPlusVDVTLightweightOperatorPD(
            tril=tril,
            v=perturb_factor,
            diag=perturb_diag,
            validate_args=validate_args)

    if diag is not None:
        diag = self._preprocess_diag(identity_multiplier, diag, event_ndims)
        if perturb_factor is None:
            return operator_pd_diag.OperatorPDSqrtDiag(
                diag, verify_pd=validate_args)
        # NOTE(review): the perturbed path wraps `diag` in `OperatorPDDiag`
        # while the unperturbed path uses `OperatorPDSqrtDiag`; presumably
        # `OperatorPDSqrtVDVTUpdate` uses the wrapped operator's matrix
        # directly as the base of the scale -- confirm.
        return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator=operator_pd_diag.OperatorPDDiag(
                diag, verify_pd=validate_args),
            v=perturb_factor,
            diag=perturb_diag,
            verify_pd=validate_args)

    if identity_multiplier is not None:
        if perturb_factor is None:
            return identity_multiplier
        # Infer the identity's shape from V alone: keep every dimension but
        # the last and append V's second-to-last dimension, so a V of shape
        # [..., k, r] gives an identity of shape [..., k, k].
        v_shape = array_ops.shape(perturb_factor)
        identity_shape = array_ops.concat([v_shape[:-1], [v_shape[-2]]], 0)
        scaled_identity = operator_pd_identity.OperatorPDIdentity(
            identity_shape,
            perturb_factor.dtype.base_dtype,
            scale=identity_multiplier,
            verify_pd=validate_args)
        return operator_pd_vdvt_update.OperatorPDSqrtVDVTUpdate(
            operator=scaled_identity,
            v=perturb_factor,
            diag=perturb_diag,
            verify_pd=validate_args)

    raise ValueError(
        "One of tril, diag and/or identity_multiplier must be "
        "specified.")