def _check_domain_range_possibly_add_asserts(self):
    """Validate init args `num_rows` / `num_columns`, optionally adding asserts.

    When `self._assert_proper_shapes` is true, both `self._num_rows` and
    `self._num_columns` are replaced by versions gated on a rank-0 assertion
    and a non-negativity assertion.  Static checks are then performed: the
    dtype must be an integer type, and, whenever a static value is available
    (`self._num_rows_static` / `self._num_columns_static` not None), it must
    be 0-D and non-negative.

    Raises:
      TypeError: if `num_rows` or `num_columns` does not have an integer dtype.
      ValueError: if a statically known value is not 0-D or is negative.
    """

    def _with_scalar_guards(tensor, rank_message, sign_message):
        # Re-emit `tensor` with control dependencies on both assertions,
        # created in the order rank check, then sign check.
        guards = [
            check_ops.assert_rank(tensor, 0, message=rank_message),
            check_ops.assert_non_negative(tensor, message=sign_message),
        ]
        return control_flow_ops.with_dependencies(guards, tensor)

    if self._assert_proper_shapes:
        self._num_rows = _with_scalar_guards(
            self._num_rows,
            "Argument num_rows must be a 0-D Tensor.",
            "Argument num_rows must be non-negative.")
        self._num_columns = _with_scalar_guards(
            self._num_columns,
            "Argument num_columns must be a 0-D Tensor.",
            "Argument num_columns must be non-negative.")

    # Static dtype checks: rows are checked before columns.
    dtype_checks = (
        (self._num_rows,
         "Argument num_rows must be integer type. Found: %s"),
        (self._num_columns,
         "Argument num_columns must be integer type. Found: %s"),
    )
    for tensor, type_message in dtype_checks:
        if not np.issubdtype(tensor.dtype, np.integer):
            raise TypeError(type_message % tensor)

    # Static value checks, only possible when the value is known statically.
    rows_static = self._num_rows_static
    columns_static = self._num_columns_static
    value_checks = (
        (rows_static,
         "Argument num_rows must be a 0-D Tensor. Found: %s",
         "Argument num_rows must be non-negative. Found: %s"),
        (columns_static,
         "Argument num_columns must be a 0-D Tensor. Found: %s",
         "Argument num_columns must be non-negative. Found: %s"),
    )
    for static_value, rank_message, sign_message in value_checks:
        if static_value is None:
            continue
        if static_value.ndim != 0:
            raise ValueError(rank_message % static_value)
        if static_value < 0:
            raise ValueError(sign_message % static_value)
def _check_batch_shape_possibly_add_asserts(self):
    """Static check of init arg `batch_shape`, possibly add asserts.

    Does nothing when `self._batch_shape_arg` is None.  Otherwise, when
    `self._assert_proper_shapes` is true, `self._batch_shape_arg` is replaced
    by a version gated on a rank-1 assertion and a non-negativity assertion.
    Static checks follow: the dtype must be an integer type, and, if
    `self._batch_shape_static` is not None, it must be 1-D with no negative
    entries.

    Raises:
      TypeError: if `batch_shape` does not have an integer dtype.
      ValueError: if the statically known `batch_shape` is not 1-D or has a
        negative entry.
    """
    if self._batch_shape_arg is None:
        return

    # Possibly add runtime asserts.
    if self._assert_proper_shapes:
        self._batch_shape_arg = control_flow_ops.with_dependencies([
            check_ops.assert_rank(
                self._batch_shape_arg,
                1,
                message="Argument batch_shape must be a 1-D Tensor."),
            check_ops.assert_non_negative(
                self._batch_shape_arg,
                message="Argument batch_shape must be non-negative."),
        ], self._batch_shape_arg)

    # Static checks.
    if not np.issubdtype(self._batch_shape_arg.dtype, np.integer):
        raise TypeError("Argument batch_shape must be integer type. Found:"
                        " %s" % self._batch_shape_arg)

    if self._batch_shape_static is None:
        return  # Cannot do any other static checks.

    if self._batch_shape_static.ndim != 1:
        raise ValueError("Argument batch_shape must be a 1-D Tensor. Found:"
                         " %s" % self._batch_shape_static)

    if np.any(self._batch_shape_static < 0):
        # Leading space added so the message reads "Found: [...]", matching
        # every other error message in this file (was "Found:[...]").
        raise ValueError("Argument batch_shape must be non-negative. Found:"
                         " %s" % self._batch_shape_static)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Return the product of this all-zeros operator with `x`.

    The result is a tensor of zeros with dtype `x.dtype`.  Its shape is the
    batch dims of `x`, followed by `[n, m]`:
      * `n` is the number of rows of the (possibly adjointed) operator.  For a
        non-square operator this is `self._num_columns` if `adjoint` else
        `self._num_rows`.  For a square operator it is taken from `x`.
      * `m` is the number of columns of the (possibly adjointed) argument.
    The result is passed through `self._possibly_broadcast_batch_shape`
    (presumably broadcasting to the operator's batch shape; confirm).

    Args:
      x: Tensor right-hand side.
      adjoint: If true, use the operator's adjoint.
      adjoint_arg: If true, use the adjoint of `x`.

    Returns:
      A zeros tensor of the product's shape.
    """
    if self._assert_proper_shapes:
        # The compatibility assertion needs `x` in non-adjointed layout.
        if adjoint_arg:
            x = linalg.adjoint(x)
            # `x` is now already adjointed.  Clear the flag so the shape
            # logic below does not swap its trailing dims a second time.
            # Previously this produced a wrong output shape whenever
            # asserts were enabled together with adjoint_arg=True.
            adjoint_arg = False
        aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
        x = control_flow_ops.with_dependencies([aps], x)

    if self.is_square:
        # Note that adjoint has no effect since this matrix is self-adjoint.
        if adjoint_arg:
            output_shape = array_ops.concat([
                array_ops.shape(x)[:-2],
                [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]], axis=0)
        else:
            output_shape = array_ops.shape(x)
        return self._possibly_broadcast_batch_shape(
            array_ops.zeros(shape=output_shape, dtype=x.dtype))

    x_shape = array_ops.shape(x)
    n = self._num_columns if adjoint else self._num_rows
    m = x_shape[-2] if adjoint_arg else x_shape[-1]
    output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)

    zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
    return self._possibly_broadcast_batch_shape(zeros)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiply `x` by this identity operator.

    `adjoint` is ignored because the identity is self-adjoint.  `x` is
    adjointed first when `adjoint_arg` is true.  When
    `self._assert_proper_shapes` is set, the result is gated on a
    matrix-dimension compatibility assertion.  The (possibly adjointed) `x`
    is returned via `self._possibly_broadcast_batch_shape`.
    """
    operand = linalg.adjoint(x) if adjoint_arg else x
    if not self._assert_proper_shapes:
        return self._possibly_broadcast_batch_shape(operand)
    shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
        self, operand)
    guarded = control_flow_ops.with_dependencies([shape_check], operand)
    return self._possibly_broadcast_batch_shape(guarded)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve against this scaled-identity operator by elementwise division.

    `rhs` is adjointed first when `adjoint_arg` is true, and it is
    optionally gated on a matrix-dimension compatibility assertion.  It is
    then divided by `self._make_multiplier_matrix(conjugate=adjoint)`
    (presumably the conjugated multiplier when solving with the adjoint;
    confirm).
    """
    if adjoint_arg:
        rhs = linalg.adjoint(rhs)
    if self._assert_proper_shapes:
        shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
            self, rhs)
        rhs = control_flow_ops.with_dependencies([shape_check], rhs)
    multiplier = self._make_multiplier_matrix(conjugate=adjoint)
    return rhs / multiplier
def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiply `x` by this scaled-identity operator elementwise.

    `x` is adjointed first when `adjoint_arg` is true, and it is optionally
    gated on a matrix-dimension compatibility assertion.  It is then
    multiplied by `self._make_multiplier_matrix(conjugate=adjoint)`
    (presumably the conjugated multiplier for the adjoint product; confirm).
    """
    if adjoint_arg:
        x = linalg.adjoint(x)
    if self._assert_proper_shapes:
        shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
            self, x)
        x = control_flow_ops.with_dependencies([shape_check], x)
    multiplier = self._make_multiplier_matrix(conjugate=adjoint)
    return x * multiplier