def _set_diag_operators(self, diag, is_diag_positive):
    """Populate ``self._diag_operator`` and ``self._diag_inv_operator``.

    The ``diag`` argument is only compared against ``None``. The diagonal
    operators themselves are built from ``self._diag``, which this method
    reads but does not assign (presumably set by the caller beforehand --
    confirm).

    Args:
      diag: Diagonal values, or ``None`` to fall back to an identity operator.
      is_diag_positive: Forwarded as ``is_positive_definite`` to both
        diagonal operators (unused on the identity path).
    """
    if diag is None:
        # No diagonal supplied: use an identity whose size is the last
        # dimension of self.u -- the static value when known, otherwise the
        # dynamic shape tensor.
        static_r = self.u.get_shape()[-1].value
        if static_r is None:
            num_rows = array_ops.shape(self.u)[-1]
        else:
            num_rows = static_r
        identity_op = linear_operator_identity.LinearOperatorIdentity(
            num_rows=num_rows, dtype=self.dtype)
        # The identity is its own inverse, so both attributes share one object.
        self._diag_operator = identity_op
        self._diag_inv_operator = identity_op
        return

    self._diag_operator = linear_operator_diag.LinearOperatorDiag(
        self._diag, is_positive_definite=is_diag_positive)
    # Inverse of a diagonal operator is the diagonal of reciprocals.
    self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag(
        1. / self._diag, is_positive_definite=is_diag_positive)
def linop(self, num_rows=None, multiplier=None, diag=None):
    """Helper to create non-singular, symmetric, positive definite matrices.

    Exactly one of these argument combinations is accepted:

    * ``num_rows`` and ``multiplier`` -> ``LinearOperatorScaledIdentity``.
    * ``num_rows`` alone -> ``LinearOperatorIdentity``.
    * ``diag`` alone -> ``LinearOperatorDiag``.

    Every returned operator is constructed with ``is_positive_definite=True``.

    Args:
      num_rows: Number of rows of the (optionally scaled) identity operator.
      multiplier: Scale of the identity; only valid together with ``num_rows``.
      diag: Diagonal entries of a diagonal operator.

    Returns:
      The constructed linear operator.

    Raises:
      ValueError: If no argument is given, or an unsupported combination of
        arguments is given.
    """
    if num_rows is not None and multiplier is not None:
        if diag is not None:
            raise ValueError("Found extra args for scaled identity.")
        return linop_identity_lib.LinearOperatorScaledIdentity(
            num_rows=num_rows, multiplier=multiplier, is_positive_definite=True)
    if num_rows is not None:
        # multiplier is necessarily None here (the combined case returned above),
        # so diag is the only possible extra argument.
        if diag is not None:
            raise ValueError("Found extra args for identity.")
        return linop_identity_lib.LinearOperatorIdentity(
            num_rows=num_rows, is_positive_definite=True)
    if diag is not None:
        # num_rows is necessarily None here, so multiplier is the only extra.
        if multiplier is not None:
            raise ValueError("Found extra args for diag.")
        return linop_diag_lib.LinearOperatorDiag(
            diag=diag, is_positive_definite=True)
    if multiplier is not None:
        # A lone multiplier used to fall through to "Must specify at least one
        # arg.", which is misleading because an argument *was* given.
        raise ValueError("multiplier requires num_rows for scaled identity.")
    raise ValueError("Must specify at least one arg.")