def _set_diag_operators(self, diag, is_diag_positive):
    """Set attributes self._diag_operator and self._diag_inv_operator.

    Args:
      diag: Only compared against `None` to choose which operators to build;
        the diagonal values themselves are read from `self._diag`.
        NOTE(review): presumably `self._diag` holds a converted copy of `diag`
        assigned by the caller before this method runs -- confirm.
      is_diag_positive: Forwarded as `is_positive_definite` to both
        `LinearOperatorDiag` instances. Ignored when `diag` is None.
    """
    if diag is not None:
        self._diag_operator = linear_operator_diag.LinearOperatorDiag(
            self._diag, is_positive_definite=is_diag_positive)
        # The inverse of a diagonal operator is the diagonal of reciprocals.
        self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag(
            1. / self._diag, is_positive_definite=is_diag_positive)
    else:
        # No diagonal given: use an identity whose size is the last dimension
        # of self.u. Prefer the statically known size; when it is None, fall
        # back to the dynamic shape from array_ops.shape.
        static_r = self.u.get_shape()[-1].value
        if static_r is not None:
            r = static_r
        else:
            r = array_ops.shape(self.u)[-1]
        self._diag_operator = linear_operator_identity.LinearOperatorIdentity(
            num_rows=r, dtype=self.dtype)
        # The identity is its own inverse, so the same object is reused.
        self._diag_inv_operator = self._diag_operator
Esempio n. 2
0
 def linop(self, num_rows=None, multiplier=None, diag=None):
     """Build a positive-definite test operator from one argument combination.

     Exactly one of these combinations is accepted:
       * `num_rows` and `multiplier`: a `LinearOperatorScaledIdentity`.
       * `num_rows` alone: a `LinearOperatorIdentity`.
       * `diag` alone: a `LinearOperatorDiag`.
     Every returned operator is constructed with `is_positive_definite=True`
     (a hint passed to the constructor; the values are not checked here).

     Raises:
       ValueError: if an argument is passed that the selected combination
         does not use, or if neither `num_rows` nor `diag` is given.
     """
     rows_given = num_rows is not None
     mult_given = multiplier is not None
     diag_given = diag is not None

     if rows_given:
         if mult_given:
             if diag_given:
                 raise ValueError("Found extra args for scaled identity.")
             return linop_identity_lib.LinearOperatorScaledIdentity(
                 num_rows=num_rows,
                 multiplier=multiplier,
                 is_positive_definite=True)
         if diag_given:
             raise ValueError("Found extra args for identity.")
         return linop_identity_lib.LinearOperatorIdentity(
             num_rows=num_rows, is_positive_definite=True)

     # From here on num_rows is None, so only `diag` can select an operator.
     if not diag_given:
         raise ValueError("Must specify at least one arg.")
     if mult_given:
         raise ValueError("Found extra args for diag.")
     return linop_diag_lib.LinearOperatorDiag(
         diag=diag, is_positive_definite=True)