def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r.shape.dims[-2].value != r.shape.dims[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b
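# Illustrative usage sketch, not part of the original module: exercising the
# registered QR gradient through the public TF2 API. Assumes TensorFlow 2.x;
# the gradient above only supports square, statically shaped, real inputs.
def _demo_qr_gradient():
  import tensorflow as tf
  a = tf.Variable(tf.random.normal([4, 4], dtype=tf.float64))
  with tf.GradientTape() as tape:
    q, r = tf.linalg.qr(a)
    # The gradient flows through both outputs of the decomposition.
    loss = tf.reduce_sum(tf.square(q)) + tf.reduce_sum(tf.square(r))
  grad = tape.gradient(loss, a)
  assert grad.shape == a.shape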
def _test_solve(self, with_batch):
  for use_placeholder in self._use_placeholder_options:
    for build_info in self._operator_build_infos:
      for dtype in self._dtypes_to_test:
        for adjoint in self._adjoint_options:
          for adjoint_arg in self._adjoint_arg_options:
            with self.test_session(graph=ops.Graph()) as sess:
              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                  build_info, dtype, use_placeholder=use_placeholder)
              rhs = self._make_rhs(
                  operator, adjoint=adjoint, with_batch=with_batch)
              # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
              if adjoint_arg:
                op_solve = operator.solve(
                    linalg.adjoint(rhs),
                    adjoint=adjoint,
                    adjoint_arg=adjoint_arg)
              else:
                op_solve = operator.solve(
                    rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
              mat_solve = linear_operator_util.matrix_solve_with_broadcast(
                  mat, rhs, adjoint=adjoint)
              if not use_placeholder:
                self.assertAllEqual(op_solve.get_shape(),
                                    mat_solve.get_shape())
              op_solve_v, mat_solve_v = sess.run(
                  [op_solve, mat_solve], feed_dict=feed_dict)
              self.assertAC(op_solve_v, mat_solve_v)
def _test_matmul_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    with_batch):
  # If batch dimensions are omitted, but there are
  # no batch dimensions for the linear operator, then
  # skip the test case. This is already checked with
  # with_batch=True.
  if not with_batch and len(shapes_info.shape) <= 2:
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    x = self.make_x(operator, adjoint=adjoint, with_batch=with_batch)
    # If adjoint_arg, compute A X^H^H = A X.
    if adjoint_arg:
      op_matmul = operator.matmul(
          linalg.adjoint(x),
          adjoint=adjoint,
          adjoint_arg=adjoint_arg)
    else:
      op_matmul = operator.matmul(x, adjoint=adjoint)
    mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
    if not use_placeholder:
      self.assertAllEqual(op_matmul.get_shape(), mat_matmul.get_shape())
    op_matmul_v, mat_matmul_v = sess.run([op_matmul, mat_matmul])
    self.assertAC(op_matmul_v, mat_matmul_v)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  # Note that adjoint has no effect since this matrix is self-adjoint.
  x = linalg.adjoint(x) if adjoint_arg else x
  if self._assert_proper_shapes:
    aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
    x = control_flow_ops.with_dependencies([aps], x)
  return self._possibly_broadcast_batch_shape(x)
def _test_matmul(self, with_batch):
  for use_placeholder in self._use_placeholder_options:
    for build_info in self._operator_build_infos:
      for dtype in self._dtypes_to_test:
        for adjoint in self._adjoint_options:
          for adjoint_arg in self._adjoint_arg_options:
            with self.test_session(graph=ops.Graph()) as sess:
              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                  build_info, dtype, use_placeholder=use_placeholder)
              x = self._make_x(
                  operator, adjoint=adjoint, with_batch=with_batch)
              # If adjoint_arg, compute A X^H^H = A X.
              if adjoint_arg:
                op_matmul = operator.matmul(
                    linalg.adjoint(x),
                    adjoint=adjoint,
                    adjoint_arg=adjoint_arg)
              else:
                op_matmul = operator.matmul(x, adjoint=adjoint)
              mat_matmul = linear_operator_util.matmul_with_broadcast(
                  mat, x, adjoint_a=adjoint)
              if not use_placeholder:
                self.assertAllEqual(op_matmul.get_shape(),
                                    mat_matmul.get_shape())
              op_matmul_v, mat_matmul_v = sess.run(
                  [op_matmul, mat_matmul], feed_dict=feed_dict)
              self.assertAC(op_matmul_v, mat_matmul_v)
def _test_solve(self, with_batch):
  for use_placeholder in self._use_placeholder_options:
    for build_info in self._operator_build_infos:
      # If batch dimensions are omitted, but there are
      # no batch dimensions for the linear operator, then
      # skip the test case. This is already checked with
      # with_batch=True.
      if not with_batch and len(build_info.shape) <= 2:
        continue
      for dtype in self._dtypes_to_test:
        for adjoint in self._adjoint_options:
          for adjoint_arg in self._adjoint_arg_options:
            with self.test_session(graph=ops.Graph()) as sess:
              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                  build_info, dtype, use_placeholder=use_placeholder)
              rhs = self._make_rhs(
                  operator, adjoint=adjoint, with_batch=with_batch)
              # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
              if adjoint_arg:
                op_solve = operator.solve(
                    linalg.adjoint(rhs),
                    adjoint=adjoint,
                    adjoint_arg=adjoint_arg)
              else:
                op_solve = operator.solve(
                    rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
              mat_solve = linear_operator_util.matrix_solve_with_broadcast(
                  mat, rhs, adjoint=adjoint)
              if not use_placeholder:
                self.assertAllEqual(op_solve.get_shape(),
                                    mat_solve.get_shape())
              op_solve_v, mat_solve_v = sess.run(
                  [op_solve, mat_solve], feed_dict=feed_dict)
              self.assertAC(op_solve_v, mat_solve_v)
def _test_matmul(self, with_batch):
  for use_placeholder in self._use_placeholder_options:
    for build_info in self._operator_build_infos:
      # If batch dimensions are omitted, but there are
      # no batch dimensions for the linear operator, then
      # skip the test case. This is already checked with
      # with_batch=True.
      if not with_batch and len(build_info.shape) <= 2:
        continue
      for dtype in self._dtypes_to_test:
        for adjoint in self._adjoint_options:
          for adjoint_arg in self._adjoint_arg_options:
            with self.session(graph=ops.Graph()) as sess:
              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
              operator, mat = self._operator_and_matrix(
                  build_info, dtype, use_placeholder=use_placeholder)
              x = self._make_x(
                  operator, adjoint=adjoint, with_batch=with_batch)
              # If adjoint_arg, compute A X^H^H = A X.
              if adjoint_arg:
                op_matmul = operator.matmul(
                    linalg.adjoint(x),
                    adjoint=adjoint,
                    adjoint_arg=adjoint_arg)
              else:
                op_matmul = operator.matmul(x, adjoint=adjoint)
              mat_matmul = linear_operator_util.matmul_with_broadcast(
                  mat, x, adjoint_a=adjoint)
              if not use_placeholder:
                self.assertAllEqual(op_matmul.get_shape(),
                                    mat_matmul.get_shape())
              op_matmul_v, mat_matmul_v = sess.run([op_matmul, mat_matmul])
              self.assertAC(op_matmul_v, mat_matmul_v)
def test_solve(self):
  self._skip_if_tests_to_skip_contains("solve")
  for use_placeholder in self._use_placeholder_options:
    for shape in self._shapes_to_test:
      for dtype in self._dtypes_to_test:
        for adjoint in self._adjoint_options:
          for adjoint_arg in self._adjoint_arg_options:
            with self.test_session(graph=ops.Graph()) as sess:
              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                  shape, dtype, use_placeholder=use_placeholder)
              rhs = self._make_rhs(operator, adjoint=adjoint)
              # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
              if adjoint_arg:
                op_solve = operator.solve(
                    linalg.adjoint(rhs),
                    adjoint=adjoint,
                    adjoint_arg=adjoint_arg)
              else:
                op_solve = operator.solve(
                    rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
              mat_solve = linalg_ops.matrix_solve(mat, rhs, adjoint=adjoint)
              if not use_placeholder:
                self.assertAllEqual(op_solve.get_shape(),
                                    mat_solve.get_shape())
              op_solve_v, mat_solve_v = sess.run(
                  [op_solve, mat_solve], feed_dict=feed_dict)
              self.assertAC(op_solve_v, mat_solve_v)
def test_matmul(self):
  self._skip_if_tests_to_skip_contains("matmul")
  for use_placeholder in self._use_placeholder_options:
    for shape in self._shapes_to_test:
      for dtype in self._dtypes_to_test:
        for adjoint in self._adjoint_options:
          for adjoint_arg in self._adjoint_arg_options:
            with self.test_session(graph=ops.Graph()) as sess:
              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                  shape, dtype, use_placeholder=use_placeholder)
              x = self._make_x(operator, adjoint=adjoint)
              # If adjoint_arg, compute A X^H^H = A X.
              if adjoint_arg:
                op_matmul = operator.matmul(
                    linalg.adjoint(x),
                    adjoint=adjoint,
                    adjoint_arg=adjoint_arg)
              else:
                op_matmul = operator.matmul(x, adjoint=adjoint)
              mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
              if not use_placeholder:
                self.assertAllEqual(op_matmul.get_shape(),
                                    mat_matmul.get_shape())
              op_matmul_v, mat_matmul_v = sess.run(
                  [op_matmul, mat_matmul], feed_dict=feed_dict)
              self.assertAC(op_matmul_v, mat_matmul_v)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  if self._assert_proper_shapes:
    x = linalg.adjoint(x) if adjoint_arg else x
    aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
    x = control_flow_ops.with_dependencies([aps], x)
  if self.is_square:
    # Note that adjoint has no effect since this matrix is self-adjoint.
    if adjoint_arg:
      output_shape = array_ops.concat([
          array_ops.shape(x)[:-2],
          [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]], axis=0)
    else:
      output_shape = array_ops.shape(x)
    return self._possibly_broadcast_batch_shape(
        array_ops.zeros(shape=output_shape, dtype=x.dtype))

  x_shape = array_ops.shape(x)
  n = self._num_columns if adjoint else self._num_rows
  m = x_shape[-2] if adjoint_arg else x_shape[-1]
  output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)
  zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
  return self._possibly_broadcast_batch_shape(zeros)
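# Illustrative sketch, not part of the original module (assumes the public
# TF2 wrapper `tf.linalg.LinearOperatorZeros` around the implementation
# above): matmul ignores the values of `x` and only produces a correctly
# shaped, possibly batch-broadcast, zero tensor.
def _demo_zeros_matmul():
  import tensorflow as tf
  op = tf.linalg.LinearOperatorZeros(num_rows=2)
  x = tf.ones([2, 4])
  y = op.matmul(x)  # all zeros, shape [2, 4]
  assert y.shape == (2, 4)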
def _assert_self_adjoint(self):
  dense = self._get_cached_dense_matrix()
  logging.warn(
      "Using (possibly slow) default implementation of assert_self_adjoint."
      " Requires conversion to a dense matrix.")
  return check_ops.assert_equal(
      dense,
      linalg.adjoint(dense),
      message="Matrix was not equal to its adjoint.")
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  x = linalg.adjoint(x) if adjoint_arg else x
  if adjoint:
    matrix = self._multiplier_matrix_conj
  else:
    matrix = self._multiplier_matrix
  if self._assert_proper_shapes:
    aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
    x = control_flow_ops.with_dependencies([aps], x)
  return x * matrix
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  if adjoint:
    matrix = self._multiplier_matrix_conj
  else:
    matrix = self._multiplier_matrix
  if self._assert_proper_shapes:
    aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)
    rhs = control_flow_ops.with_dependencies([aps], rhs)
  return rhs / matrix
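# Illustrative sketch, not part of the original module (public TF2 API
# assumed): for a scaled identity operator, solve reduces to elementwise
# division by the multiplier, matching the implementation above.
def _demo_scaled_identity_solve():
  import tensorflow as tf
  op = tf.linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=3.)
  rhs = tf.constant([[6.], [9.]])
  tf.debugging.assert_near(op.solve(rhs), rhs / 3.)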
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  spectrum = self._conj_spectrum if adjoint else self._spectrum_complex

  rhs, spectrum = self._broadcast_batch_dims(rhs, spectrum)

  rhs_vb = self._vectorize_then_blockify(rhs)
  fft_rhs_vb = self._fft(rhs_vb)
  solution_vb = self._ifft(fft_rhs_vb / spectrum)
  x = self._unblockify_then_matricize(solution_vb)
  return math_ops.cast(x, self.dtype)
def _test_adjoint(use_placeholder, shapes_info, dtype):
  def test_adjoint(self):
    with self.test_session(graph=ops.Graph()) as sess:
      sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
      operator, mat = self.operator_and_matrix(
          shapes_info, dtype, use_placeholder=use_placeholder)
      op_adjoint = operator.adjoint().to_dense()
      op_adjoint_h = operator.H.to_dense()
      mat_adjoint = linalg.adjoint(mat)
      op_adjoint_v, op_adjoint_h_v, mat_adjoint_v = sess.run(
          [op_adjoint, op_adjoint_h, mat_adjoint])
      self.assertAC(mat_adjoint_v, op_adjoint_v)
      self.assertAC(mat_adjoint_v, op_adjoint_h_v)
  return test_adjoint
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Default implementation of _solve."""
  if self.is_square is False:
    raise NotImplementedError(
        "Solve is not yet implemented for non-square operators.")
  logging.warn(
      "Using (possibly slow) default implementation of solve."
      " Requires conversion to a dense matrix and O(N^3) operations.")
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  if self._can_use_cholesky():
    return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
  return linalg_ops.matrix_solve(
      self._get_cached_dense_matrix(), rhs, adjoint=adjoint)
def test_adjoint(self):
  self._skip_if_tests_to_skip_contains("adjoint")
  for use_placeholder in self._use_placeholder_options:
    for build_info in self._operator_build_infos:
      for dtype in self._dtypes_to_test:
        with self.test_session(graph=ops.Graph()) as sess:
          sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
          operator, mat = self._operator_and_matrix(
              build_info, dtype, use_placeholder=use_placeholder)
          op_adjoint = operator.adjoint().to_dense()
          op_adjoint_h = operator.H.to_dense()
          mat_adjoint = linalg.adjoint(mat)
          op_adjoint_v, op_adjoint_h_v, mat_adjoint_v = sess.run(
              [op_adjoint, op_adjoint_h, mat_adjoint])
          self.assertAC(mat_adjoint_v, op_adjoint_v)
          self.assertAC(mat_adjoint_v, op_adjoint_h_v)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  x = linalg.adjoint(x) if adjoint_arg else x
  # With F the matrix of a DFT, and F^{-1}, F^H the inverse and Hermitian
  # transpose, one can show that F^{-1} = F^{H} is the IDFT matrix. Therefore
  #   matmul(x) = F^{-1} diag(spectrum) F x,
  #             = F^{H} diag(spectrum) F x,
  # so that
  #   matmul(x, adjoint=True) = F^{H} diag(conj(spectrum)) F x.
  spectrum = self._conj_spectrum if adjoint else self._spectrum_complex

  x, spectrum = self._broadcast_batch_dims(x, spectrum)

  x_vb = self._vectorize_then_blockify(x)
  fft_x_vb = self._fft(x_vb)
  block_vector_result = self._ifft(spectrum * fft_x_vb)
  y = self._unblockify_then_matricize(block_vector_result)

  return math_ops.cast(y, self.dtype)
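# Illustrative sketch, not part of the original module (public TF2 API
# assumed): a circulant matrix is diagonalized by the DFT, so the FFT-based
# matmul above agrees with applying the dense circulant matrix directly.
def _demo_circulant_matmul():
  import tensorflow as tf
  spectrum = tf.constant([6., 4., 2.], dtype=tf.complex64)
  op = tf.linalg.LinearOperatorCirculant(spectrum)
  x = tf.constant([[1.], [2.], [3.]], dtype=tf.complex64)
  tf.debugging.assert_near(
      tf.math.real(op.matmul(x)),
      tf.math.real(tf.matmul(op.to_dense(), x)),
      atol=1e-4)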
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  # Given a vector `v`, we would like to reflect `x` about the hyperplane
  # orthogonal to `v` going through the origin. We first project `x` onto `v`
  # to get v * dot(v, x) / dot(v, v). After we project, we can reflect the
  # projection about the hyperplane by flipping sign to get
  # -v * dot(v, x) / dot(v, v). Finally, we can add back the component that
  # is orthogonal to v. This is invariant under reflection, since the whole
  # hyperplane is invariant. This component is equal to
  # x - v * dot(v, x) / dot(v, v), giving the formula
  # x - 2 * v * dot(v, x) / dot(v, v) for the reflection.
  #
  # Note that because this is a reflection, it lies in O(n) (for real vector
  # spaces) or U(n) (for complex vector spaces), and thus is its own adjoint.
  x = linalg.adjoint(x) if adjoint_arg else x
  normalized_axis = self.reflection_axis / linalg.norm(
      self.reflection_axis, axis=-1, keepdims=True)
  mat = normalized_axis[..., array_ops.newaxis]
  x_dot_normalized_v = math_ops.matmul(mat, x, adjoint_a=True)

  return x - 2 * mat * x_dot_normalized_v
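# Illustrative sketch, not part of the original module (public TF2 API
# assumed): a Householder reflection is involutory and self-adjoint, so
# applying it twice recovers the input, as the comment above explains.
def _demo_householder_involution():
  import tensorflow as tf
  v = tf.constant([1., 2., 2.], dtype=tf.float64)
  op = tf.linalg.LinearOperatorHouseholder(v)
  x = tf.random.normal([3, 2], dtype=tf.float64)
  # H(Hx) == x since H is its own inverse.
  tf.debugging.assert_near(op.matmul(op.matmul(x)), x)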
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only
      # defined up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(array_ops.matrix_diag(grad_e), v, adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here
    # we symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(
        grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
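# Illustrative sketch, not part of the original module (public TF2 API
# assumed): the gradient above contains 1/(e_i - e_j) factors, so it is
# finite only when the eigenvalues are distinct, as the comment in the code
# warns.
def _demo_eigh_gradient():
  import tensorflow as tf
  a = tf.Variable([[2., 1.], [1., 3.]], dtype=tf.float64)  # distinct e
  with tf.GradientTape() as tape:
    e, v = tf.linalg.eigh(a)
    loss = tf.reduce_sum(e) + tf.reduce_sum(v * v)
  # Finite here; would be NaN if the eigenvalues coincided and grad_v != 0.
  grad = tape.gradient(loss, a)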
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""
  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(
      l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(
      middle, 0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5
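# Illustrative sketch, not part of the original module (public TF2 API
# assumed): the Cholesky gradient above is exercised whenever
# `tf.linalg.cholesky` appears on the tape; the input must be symmetric
# positive definite.
def _demo_cholesky_gradient():
  import tensorflow as tf
  x = tf.Variable([[4., 2.], [2., 3.]], dtype=tf.float64)
  with tf.GradientTape() as tape:
    l = tf.linalg.cholesky(x)
    loss = tf.reduce_sum(tf.linalg.diag_part(l))
  # Symmetric by construction, via the grad_a + adjoint(grad_a) step above.
  grad = tape.gradient(loss, x)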
def _matmul(  # pylint:disable=missing-docstring
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    name=None):
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")
  if not isinstance(a, LinearOperator):
    # We use the identity (B^HA^H)^H = AB
    adjoint_matmul = b.matmul(
        a,
        adjoint=(not adjoint_b),
        adjoint_arg=(not adjoint_a),
        name=name)
    return linalg.adjoint(adjoint_matmul)
  return a.matmul(
      b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)
def _test_solve_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    with_batch):
  # If batch dimensions are omitted, but there are
  # no batch dimensions for the linear operator, then
  # skip the test case. This is already checked with
  # with_batch=True.
  if not with_batch and len(shapes_info.shape) <= 2:
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    rhs = self.make_rhs(operator, adjoint=adjoint, with_batch=with_batch)
    # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
    if adjoint_arg:
      op_solve = operator.solve(
          linalg.adjoint(rhs),
          adjoint=adjoint,
          adjoint_arg=adjoint_arg)
    else:
      op_solve = operator.solve(
          rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    mat_solve = linear_operator_util.matrix_solve_with_broadcast(
        mat, rhs, adjoint=adjoint)
    if not use_placeholder:
      self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape())
    op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
    self.assertAC(op_solve_v, mat_solve_v)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  return linear_operator_util.matrix_triangular_solve_with_broadcast(
      self._tril, rhs, lower=True, adjoint=adjoint)
def _TriangularSolve(x, r):
  """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
  return _linalg.adjoint(
      linalg_ops.matrix_triangular_solve(
          r, _linalg.adjoint(x), lower=False, adjoint=False))
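# Illustrative sketch, not part of the original module (public TF2 API
# assumed), checking the identity in the helper's docstring: for an
# upper-triangular r, adjoint(triangular_solve(r, adjoint(x))) equals
# matmul(x, adjoint(inv(r))).
def _demo_triangular_solve_identity():
  import tensorflow as tf
  r = tf.constant([[2., 1.], [0., 3.]], dtype=tf.float64)  # upper-triangular
  x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float64)
  via_solve = tf.linalg.adjoint(
      tf.linalg.triangular_solve(r, tf.linalg.adjoint(x), lower=False))
  via_inverse = tf.matmul(x, tf.linalg.adjoint(tf.linalg.inv(r)))
  tf.debugging.assert_near(via_solve, via_inverse)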
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  return linalg.triangular_solve(
      self._get_tril(), rhs, lower=True, adjoint=adjoint)
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough. Why?
  # Assume adjoint/transpose = False. Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory. But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a. This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in
  # calling code. Why? To avoid having to write separate complex code for
  # each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  return linalg_ops.matrix_triangular_solve(
      self._tril, rhs, lower=True, adjoint=adjoint)
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
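# Illustrative sketch, not part of the original module (public TF2 API
# assumed): the SVD gradient above is reached through `tf.linalg.svd` on the
# tape. Per the NOTICE comment, it is only finite when the singular values
# are distinct.
def _demo_svd_gradient():
  import tensorflow as tf
  a = tf.Variable([[3., 0.], [0., 1.]], dtype=tf.float64)  # distinct s
  with tf.GradientTape() as tape:
    s, u, v = tf.linalg.svd(a)  # compute_uv=True is the default
    loss = tf.reduce_sum(s)
  grad = tape.gradient(loss, a)  # equals u @ adjoint(v) for distinct s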
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  if self._assert_proper_shapes:
    aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)
    rhs = control_flow_ops.with_dependencies([aps], rhs)
  return rhs / self._make_multiplier_matrix(conjugate=adjoint)
def _test_matmul_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    blockwise_arg,
    with_batch):
  # If batch dimensions are omitted, but there are
  # no batch dimensions for the linear operator, then
  # skip the test case. This is already checked with
  # with_batch=True.
  if not with_batch and len(shapes_info.shape) <= 2:
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    x = self.make_x(operator, adjoint=adjoint, with_batch=with_batch)
    # If adjoint_arg, compute A X^H^H = A X.
    if adjoint_arg:
      op_matmul = operator.matmul(
          linalg.adjoint(x),
          adjoint=adjoint,
          adjoint_arg=adjoint_arg)
    else:
      op_matmul = operator.matmul(x, adjoint=adjoint)
    mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
    if not use_placeholder:
      self.assertAllEqual(op_matmul.shape, mat_matmul.shape)

    # If the operator is blockwise, test both blockwise `x` and `Tensor` `x`;
    # else test only `Tensor` `x`. In both cases, evaluate all results in a
    # single `sess.run` call to avoid re-sampling the random `x` in graph
    # mode.
    if blockwise_arg and len(operator.operators) > 1:
      # pylint: disable=protected-access
      block_dimensions = (
          operator._block_range_dimensions() if adjoint
          else operator._block_domain_dimensions())
      block_dimensions_fn = (
          operator._block_range_dimension_tensors if adjoint
          else operator._block_domain_dimension_tensors)
      # pylint: enable=protected-access
      split_x = linear_operator_util.split_arg_into_blocks(
          block_dimensions,
          block_dimensions_fn,
          x, axis=-2)
      if adjoint_arg:
        split_x = [linalg.adjoint(y) for y in split_x]
      split_matmul = operator.matmul(
          split_x, adjoint=adjoint, adjoint_arg=adjoint_arg)
      self.assertEqual(len(split_matmul), len(operator.operators))
      split_matmul = linear_operator_util.broadcast_matrix_batch_dims(
          split_matmul)
      fused_block_matmul = array_ops.concat(split_matmul, axis=-2)
      op_matmul_v, mat_matmul_v, fused_block_matmul_v = sess.run([
          op_matmul, mat_matmul, fused_block_matmul])

      # Check that the operator applied to blockwise input gives the same
      # result as matrix multiplication.
      self.assertAC(fused_block_matmul_v, mat_matmul_v)
    else:
      op_matmul_v, mat_matmul_v = sess.run([op_matmul, mat_matmul])

    # Check that the operator applied to a `Tensor` gives the same result as
    # matrix multiplication.
    self.assertAC(op_matmul_v, mat_matmul_v)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  x = linalg.adjoint(x) if adjoint_arg else x
  if self._assert_proper_shapes:
    aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
    x = control_flow_ops.with_dependencies([aps], x)
  return x * self._make_multiplier_matrix(conjugate=adjoint)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  diag_term = math_ops.conj(self._diag) if adjoint else self._diag
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
  return rhs * inv_diag_mat
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  # Here we follow the same use of Roth's column lemma as in `matmul`, with
  # the key difference that we replace all `matmul` instances with `solve`.
  # This follows from the property that inv(A x B) = inv(A) x inv(B).

  # Below we document the shape manipulation for adjoint=False,
  # adjoint_arg=False, but the general case of different adjoints is still
  # handled.

  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  # Always add a batch dimension to enable broadcasting to work.
  batch_shape = array_ops.concat(
      [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
  rhs += array_ops.zeros(batch_shape, dtype=rhs.dtype.base_dtype)

  # rhs has shape [B, R, C], where B represents some number of batch
  # dimensions, R represents the number of rows, and C represents the number
  # of columns.
  # In order to apply Roth's column lemma, we need to operate on a batch of
  # column vectors, so we reshape into a batch of column vectors. We put it
  # at the front to ensure that broadcasting between operators to the batch
  # dimensions B still works.
  output = _rotate_last_dim(rhs, rotate_right=True)

  # Also expand the shape to be [A, C, B, R]. The first dimension will be
  # used to accumulate dimensions from each operator matmul.
  output = output[array_ops.newaxis, ...]

  # In this loop, A is going to refer to the value of the accumulated
  # dimension. A = 1 at the start, and will end up being
  # self.range_dimension. V will refer to the last dimension. V = R at the
  # start, and will end up being 1 in the end.
  for operator in self.operators[:-1]:
    # Reshape output from [A, C, B, V] to be
    # [A, C, B, V / op.domain_dimension, op.domain_dimension]
    if adjoint:
      operator_dimension = operator.range_dimension_tensor()
    else:
      operator_dimension = operator.domain_dimension_tensor()

    output = _unvec_by(output, operator_dimension)

    # We are computing (XA^-1^T) = (A^-1 X^T)^T.
    # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
    # which is being converted to:
    # [A, C, B, V / op.domain_dimension, op.range_dimension]
    output = array_ops.matrix_transpose(output)
    output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
    output = array_ops.matrix_transpose(output)
    # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=True)

  # After the loop, we will have
  # A = self.range_dimension / op[-1].range_dimension
  # V = op[-1].domain_dimension

  # We convert that using matvec to get:
  # [A, C, B, op[-1].range_dimension]
  output = self.operators[-1].solvevec(output, adjoint=adjoint)
  # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
  output = _rotate_last_dim(output, rotate_right=False)
  output = _vec(output)
  output = _rotate_last_dim(output, rotate_right=False)

  if rhs.shape.is_fully_defined():
    column_dim = rhs.shape[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        rhs.shape[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]
    output.set_shape(
        broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`. The returned `Tensor` will be close to an exact solution if `A` is well conditioned. Otherwise closeness will vary. See class docstring for details. Given the blockwise `n + 1`-by-`n + 1` linear operator: op = [[A_00 0 ... 0 ... 0], [A_10 A_11 ... 0 ... 0], ... [A_k0 A_k1 ... A_kk ... 0], ... [A_n0 A_n1 ... A_nk ... A_nn]] we find `x = op.solve(y)` by observing that `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)` and therefore `x_k = A_kk.solve(y_k - A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))` where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x` and `y` along their appropriate axes. We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`. The adjoint case is solved similarly, beginning with `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards. Examples: ```python # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] operator = LinearOperator(...) operator.shape = [..., M, N] # Solve R > 0 linear systems for every member of the batch. RHS = ... # shape [..., M, R] X = operator.solve(RHS) # X[..., :, r] is the solution to the r'th linear system # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r] operator.matmul(X) ==> RHS ``` Args: rhs: `Tensor` with same `dtype` as this operator and compatible shape, or a list of `Tensor`s. `Tensor`s are treated like a [batch] matrices meaning for every set of leading dimensions, the last two dimensions defines a matrix. See class docstring for definition of compatibility. adjoint: Python `bool`. If `True`, solve the system involving the adjoint of this `LinearOperator`: `A^H X = rhs`. adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H` is the hermitian transpose (transposition and complex conjugation). name: A name scope to use for ops added by this method. Returns: `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. Raises: NotImplementedError: If `self.is_non_singular` or `is_square` is False. """ if self.is_non_singular is False: raise NotImplementedError( "Exact solve not implemented for an operator that is expected to " "be singular.") if self.is_square is False: raise NotImplementedError( "Exact solve not implemented for an operator that is expected to " "not be square.") if isinstance(rhs, linear_operator.LinearOperator): left_operator = self.adjoint() if adjoint else self right_operator = rhs.adjoint() if adjoint_arg else rhs if (right_operator.range_dimension is not None and left_operator.domain_dimension is not None and right_operator.range_dimension != left_operator.domain_dimension): raise ValueError( "Operators are incompatible. 
Expected `rhs` to have dimension" " {} but got {}.".format( left_operator.domain_dimension, right_operator.range_dimension)) with self._name_scope(name): # pylint: disable=not-callable return linear_operator_algebra.solve(left_operator, right_operator) with self._name_scope(name): # pylint: disable=not-callable block_dimensions = (self._block_domain_dimensions() if adjoint else self._block_range_dimensions()) arg_dim = -1 if adjoint_arg else -2 blockwise_arg = linear_operator_util.arg_is_blockwise( block_dimensions, rhs, arg_dim) if blockwise_arg: for i, block in enumerate(rhs): if not isinstance(block, linear_operator.LinearOperator): block = ops.convert_to_tensor_v2_with_dispatch(block) self._check_input_dtype(block) block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim]) rhs[i] = block if adjoint_arg: split_rhs = [linalg.adjoint(y) for y in rhs] else: split_rhs = rhs else: rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs") self._check_input_dtype(rhs) op_dimension = (self.domain_dimension if adjoint else self.range_dimension) op_dimension.assert_is_compatible_with(rhs.shape[arg_dim]) rhs = linalg.adjoint(rhs) if adjoint_arg else rhs split_rhs = linear_operator_util.split_arg_into_blocks( self._block_domain_dimensions(), self._block_domain_dimension_tensors, rhs, axis=-2) solution_list = [] if adjoint: # For an adjoint blockwise lower-triangular linear operator, the system # must be solved bottom to top. Iterate backwards over rows of the # adjoint (i.e. columns of the non-adjoint operator). for index in reversed(range(len(self.operators))): y = split_rhs[index] # Iterate top to bottom over the operators in the off-diagonal portion # of the column-partition (i.e. row-partition of the adjoint), apply # the operator to the respective block of the solution found in # previous iterations, and subtract the result from the `rhs` block. # For example,let `A`, `B`, and `D` be the linear operators in the top # row-partition of the adjoint of # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`, # and `x_1` and `x_2` be blocks of the solution found in previous # iterations of the outer loop. The following loop (when `index == 0`) # expresses # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where # `y_0* = y_0 - Bx_1 - Dx_2`. for j in reversed(range(index + 1, len(self.operators))): y = y - self.operators[j][index].matmul( solution_list[len(self.operators) - 1 - j], adjoint=adjoint) # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`. solution_list.append( self._diagonal_operators[index].solve(y, adjoint=adjoint)) solution_list.reverse() else: # Iterate top to bottom over the row-partitions. for row, y in zip(self.operators, split_rhs): # Iterate left to right over the operators in the off-diagonal portion # of the row-partition, apply the operator to the block of the # solution found in previous iterations, and subtract the result from # the `rhs` block. For example, let `D`, `E`, and `F` be the linear # operators in the bottom row-partition of # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and # `x_0` and `x_1` be blocks of the solution found in previous # iterations of the outer loop. The following loop # (when `index == 2`), expresses # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where # `y_2* = y_2 - D_x0 - Ex_1`. for i, operator in enumerate(row[:-1]): y = y - operator.matmul(solution_list[i], adjoint=adjoint) # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`. 
solution_list.append(row[-1].solve(y, adjoint=adjoint)) if blockwise_arg: return solution_list solution_list = linear_operator_util.broadcast_matrix_batch_dims( solution_list) return array_ops.concat(solution_list, axis=-2)
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for Svd based on Giles' algorithm. Reference at top of file."""
  if op.get_attr("compute_uv") and not op.get_attr("full_matrices"):
    raise NotImplementedError(
        "SVD gradient is not implemented for compute_uv=True and "
        "full_matrices=False.")
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)

  if op.get_attr("compute_uv"):
    # TODO(rmlarsen): Make this work with complex types.
    if a.dtype.is_complex:
      raise NotImplementedError(
          "SVD gradient is not implemented for complex types and "
          "compute_uv=True.")
    grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
    grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
    m = a_shape[-2].merge_with(grad_u_shape[-2])
    n = a_shape[-1].merge_with(grad_v_shape[-2])
    batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
        grad_v_shape[:-2])
    a_shape = batch_shape.concatenate([m, n])

  m = a_shape[-2].value
  n = a_shape[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  if not op.get_attr("full_matrices") or not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True)
  else:
    s = op.outputs[0]
    u = op.outputs[1]
    v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    grad_s_mat = array_ops.matrix_diag(grad_s)
    if not op.get_attr("compute_uv"):
      if use_adjoint:
        grad_a = math_ops.matmul(
            v[..., :, :m], math_ops.matmul(u, grad_s_mat), adjoint_b=True)
      else:
        grad_a = math_ops.matmul(
            u, math_ops.matmul(grad_s_mat, v[..., :, :m], adjoint_b=True))
      grad_a.set_shape(a_shape)
      return grad_a

    # TODO(rmlarsen): Define a gradient that is numerically stable for
    # abs(m-n) > 1. Currently this does not work because there are
    # effectively multiple singular values with value zero. I am not sure if
    # this is a true instability or if it simply throws off the finite
    # difference gradient checker.
    if abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v, grad_v, adjoint_a=True)

    if m == n:
      f_u = f * u_gu
      f_v = f * v_gv
    else:
      dv2 = array_ops.matrix_transpose(
          v_gv[..., m:n, :m]) - v_gv[..., :m, m:n]
      f_u = f * u_gu
      f_v = f * v_gv[..., :m, :m]

    grad_a_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    if m != n:
      grad_a_nouv = array_ops.concat(
          [grad_a_nouv, math_ops.matmul(s_inv_mat, dv2)], -1)

    if use_adjoint:
      # Use (U X V^H)^H = V (U X)^H.
      grad_a = math_ops.matmul(
          v, math_ops.matmul(u, grad_a_nouv), adjoint_b=True)
    else:
      grad_a = math_ops.matmul(
          u, math_ops.matmul(grad_a_nouv, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  # Here we heavily rely on Roth's column lemma [1]:
  # (A x B) * vec X = vec BXA^T,
  # where vec stacks all the columns of the matrix under each other. In our
  # case, x represents a batch of vec X (i.e. we think of x as a batch of
  # column vectors, rather than a matrix). Each member of the batch can be
  # reshaped to a matrix (hence we get a batch of matrices).
  # We can iteratively apply this lemma by noting that if B is a Kronecker
  # product, then we can apply the lemma again.
  # [1] W. E. Roth, "On direct product matrices,"
  # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
  # 1934

  # Efficiency
  # Naively doing the Kronecker product, by calculating the dense matrix and
  # applying it, can take cubic time in the size of domain_dimension
  # (assuming a square matrix). The other issue is that calculating the dense
  # matrix can be prohibitively expensive, in that it can take a large amount
  # of memory.
  #
  # This implementation avoids this memory blow up by only computing matmuls
  # with the factors. In this way, we don't have to realize the dense matrix.
  # In terms of complexity, if we have Kronecker Factors of size:
  # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
  # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
  # With this approach (ignoring reshaping of tensors and transposes for
  # now), the time complexity can be O(M * (\sum n_i) * N). There is also the
  # benefit of batched multiplication (in this example, the batch size is
  # roughly M * N) so this can be much faster. However, not factored in are
  # the costs of the several transposings of tensors, which can affect cache
  # behavior.

  # Below we document the shape manipulation for adjoint=False,
  # adjoint_arg=False, but the general case of different adjoints is still
  # handled.

  if adjoint_arg:
    x = linalg.adjoint(x)

  # Always add a batch dimension to enable broadcasting to work.
  batch_shape = array_ops.concat(
      [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
  x += array_ops.zeros(batch_shape, dtype=x.dtype.base_dtype)

  # x has shape [B, R, C], where B represents some number of batch
  # dimensions, R represents the number of rows, and C represents the number
  # of columns.
  # In order to apply Roth's column lemma, we need to operate on a batch of
  # column vectors, so we reshape into a batch of column vectors. We put it
  # at the front to ensure that broadcasting between operators to the batch
  # dimensions B still works.
  output = _rotate_last_dim(x, rotate_right=True)

  # Also expand the shape to be [A, C, B, R]. The first dimension will be
  # used to accumulate dimensions from each operator matmul.
  output = output[array_ops.newaxis, ...]

  # In this loop, A is going to refer to the value of the accumulated
  # dimension. A = 1 at the start, and will end up being
  # self.range_dimension. V will refer to the last dimension. V = R at the
  # start, and will end up being 1 in the end.
  for operator in self.operators[:-1]:
    # Reshape output from [A, C, B, V] to be
    # [A, C, B, V / op.domain_dimension, op.domain_dimension]
    if adjoint:
      operator_dimension = operator.range_dimension_tensor()
    else:
      operator_dimension = operator.domain_dimension_tensor()

    output = _unvec_by(output, operator_dimension)

    # We are computing (XA^T) = (AX^T)^T.
    # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
    # which is being converted to:
    # [A, C, B, V / op.domain_dimension, op.range_dimension]
    output = array_ops.matrix_transpose(output)
    output = operator.matmul(output, adjoint=adjoint, adjoint_arg=False)
    output = array_ops.matrix_transpose(output)
    # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=True)

  # After the loop, we will have
  # A = self.range_dimension / op[-1].range_dimension
  # V = op[-1].domain_dimension

  # We convert that using matvec to get:
  # [A, C, B, op[-1].range_dimension]
  output = self.operators[-1].matvec(output, adjoint=adjoint)
  # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
  output = _rotate_last_dim(output, rotate_right=False)
  output = _vec(output)
  output = _rotate_last_dim(output, rotate_right=False)

  if x.shape.is_fully_defined():
    column_dim = x.shape[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        x.shape[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]
    output.set_shape(
        broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
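# Illustrative sketch, not part of the original module (public TF2 API
# assumed), of Roth's column lemma, the identity the loop above exploits:
# (A x B) vec(X) == vec(B X A^T), with vec stacking columns. TF tensors are
# row-major, so we vectorize via transpose-then-reshape.
def _demo_roths_column_lemma():
  import tensorflow as tf
  a = tf.random.normal([2, 2], dtype=tf.float64)
  b = tf.random.normal([3, 3], dtype=tf.float64)
  x = tf.random.normal([3, 2], dtype=tf.float64)
  kron = tf.linalg.LinearOperatorKronecker(
      [tf.linalg.LinearOperatorFullMatrix(a),
       tf.linalg.LinearOperatorFullMatrix(b)])
  vec_x = tf.reshape(tf.transpose(x), [-1, 1])  # column-stacked vec(X)
  lhs = kron.matmul(vec_x)
  rhs = tf.reshape(tf.transpose(b @ x @ tf.transpose(a)), [-1, 1])
  tf.debugging.assert_near(lhs, rhs)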
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough. Why?
  # Assume adjoint/transpose = False. Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory. But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a. This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in
  # calling code. Why? To avoid having to write separate complex code for
  # each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      array_ops.concat(
          (math_ops.range(b_extra_ndims, b.shape.ndims),
           math_ops.range(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    return array_ops.transpose(
        y_extra_on_end, perm=array_ops.invert_permutation(perm))

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  # Given the blockwise `n + 1`-by-`n + 1` linear operator:
  #
  # op = [[A_00     0  ...     0  ...    0],
  #       [A_10  A_11  ...     0  ...    0],
  #         ...
  #       [A_k0  A_k1  ...  A_kk  ...    0],
  #         ...
  #       [A_n0  A_n1  ...  A_nk  ... A_nn]]
  #
  # we find `x = op.solve(y)` by observing that
  #
  # `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`
  #
  # and therefore
  #
  # `x_k = A_kk.solve(y_k -
  #                   A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`
  #
  # where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
  # and `y` along their appropriate axes.
  #
  # We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
  # for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.
  #
  # The adjoint case is solved similarly, beginning with
  # `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.
  rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
  split_rhs = self._split_input_into_blocks(rhs, axis=-2)

  solution_list = []
  if adjoint:
    # For an adjoint blockwise lower-triangular linear operator, the system
    # must be solved bottom to top. Iterate backwards over rows of the
    # adjoint (i.e. columns of the non-adjoint operator).
    for index in reversed(range(len(self.operators))):
      y = split_rhs[index]
      # Iterate top to bottom over the operators in the off-diagonal portion
      # of the column-partition (i.e. row-partition of the adjoint), apply
      # the operator to the respective block of the solution found in
      # previous iterations, and subtract the result from the `rhs` block.
      # For example, let `A`, `B`, and `D` be the linear operators in the top
      # row-partition of the adjoint of
      # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
      # and `x_1` and `x_2` be blocks of the solution found in previous
      # iterations of the outer loop. The following loop (when `index == 0`)
      # expresses
      # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
      # `y_0* = y_0 - Bx_1 - Dx_2`.
      for j in reversed(range(index + 1, len(self.operators))):
        y -= self.operators[j][index].matmul(
            solution_list[len(self.operators) - 1 - j],
            adjoint=adjoint)
      # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
      solution_list.append(
          self.operators[index][index].solve(y, adjoint=adjoint))
    solution_list.reverse()
  else:
    # Iterate top to bottom over the row-partitions.
    for row, y in zip(self.operators, split_rhs):
      # Iterate left to right over the operators in the off-diagonal portion
      # of the row-partition, apply the operator to the block of the solution
      # found in previous iterations, and subtract the result from the `rhs`
      # block. For example, let `D`, `E`, and `F` be the linear operators in
      # the bottom row-partition of
      # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
      # `x_0` and `x_1` be blocks of the solution found in previous
      # iterations of the outer loop. The following loop (when `index == 2`)
      # expresses
      # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
      # `y_2* = y_2 - Dx_0 - Ex_1`.
      for i, operator in enumerate(row[:-1]):
        y -= operator.matmul(solution_list[i], adjoint=adjoint)
      # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
      solution_list.append(row[-1].solve(y, adjoint=adjoint))

  solution_list = linear_operator_util.broadcast_matrix_batch_dims(
      solution_list)
  return array_ops.concat(solution_list, axis=-2)
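# Illustrative usage sketch, not part of the original module (assumes the
# public TF2 `tf.linalg.LinearOperatorBlockLowerTriangular` wrapper): solving
# a 2x2-block lower-triangular system by the forward substitution above.
def _demo_block_triangular_solve():
  import tensorflow as tf
  d1 = tf.linalg.LinearOperatorDiag([2., 3.], is_non_singular=True)
  off = tf.linalg.LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  d2 = tf.linalg.LinearOperatorDiag([4., 5.], is_non_singular=True)
  op = tf.linalg.LinearOperatorBlockLowerTriangular([[d1], [off, d2]])
  rhs = tf.ones([4, 1])
  x = op.solve(rhs)
  tf.debugging.assert_near(op.matmul(x), rhs)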
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  # Here we follow the same use of Roth's column lemma as in `matmul`, with
  # the key difference that we replace all `matmul` instances with `solve`.
  # This follows from the property that, for Kronecker products,
  # inv(A x B) = inv(A) x inv(B).

  # Below we document the shape manipulation for adjoint=False,
  # adjoint_arg=False, but the general case of different adjoints is still
  # handled.

  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  # Always add a batch dimension to enable broadcasting to work.
  batch_shape = array_ops.concat(
      [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
  rhs += array_ops.zeros(batch_shape, dtype=rhs.dtype.base_dtype)

  # rhs has shape [B, R, C], where B represents some number of batch
  # dimensions, R represents the number of rows, and C represents the number
  # of columns.

  # In order to apply Roth's column lemma, we need to operate on a batch of
  # column vectors, so we reshape into a batch of column vectors. We put the
  # column dimension at the front to ensure that broadcasting against the
  # batch dimensions B still works.
  output = _rotate_last_dim(rhs, rotate_right=True)

  # Also expand the shape to be [A, C, B, R]. The first dimension will be
  # used to accumulate dimensions from each operator solve.
  output = output[array_ops.newaxis, ...]

  # In this loop, A is going to refer to the value of the accumulated
  # dimension. A = 1 at the start, and will end up being
  # self.range_dimension. V will refer to the last dimension. V = R at the
  # start, and will end up being 1 in the end.
  for operator in self.operators[:-1]:
    # Reshape output from [A, C, B, V] to be
    # [A, C, B, V / op.domain_dimension, op.domain_dimension]
    if adjoint:
      operator_dimension = operator.range_dimension_tensor()
    else:
      operator_dimension = operator.domain_dimension_tensor()

    output = _unvec_by(output, operator_dimension)

    # We are computing X A^-T = (A^-1 X^T)^T.
    # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
    # which is being converted to:
    # [A, C, B, V / op.domain_dimension, op.range_dimension]
    output = array_ops.matrix_transpose(output)
    output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
    output = array_ops.matrix_transpose(output)
    # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=True)

  # After the loop, we will have
  # A = self.range_dimension / op[-1].range_dimension
  # V = op[-1].domain_dimension

  # We convert that using solvevec to get:
  # [A, C, B, op[-1].range_dimension]
  output = self.operators[-1].solvevec(output, adjoint=adjoint)
  # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
  output = _rotate_last_dim(output, rotate_right=False)
  output = _vec(output)
  output = _rotate_last_dim(output, rotate_right=False)

  if rhs.shape.is_fully_defined():
    column_dim = rhs.shape[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        rhs.shape[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]

    output.set_shape(
        broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
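The reason a chain of small solves can replace one large solve is the vec identity behind Roth's column lemma. A NumPy sketch for two square factors, using row-major vec to match NumPy's reshape (all names are illustrative, not from the library):

import numpy as np

np.random.seed(0)
m, p = 3, 4
a = np.random.randn(m, m) + m * np.eye(m)
b = np.random.randn(p, p) + p * np.eye(p)
y = np.random.randn(m, p)

# Direct route: materialize the Kronecker product and solve once.
direct = np.linalg.solve(np.kron(a, b), y.reshape(-1))

# Vec-trick route: with row-major vec, (A kron B) vec(X) = vec(A X B^T),
# so the solution is vec(A^{-1} Y B^{-T}) -- two small solves, no kron.
x = np.linalg.solve(a, y)       # A^{-1} Y
x = np.linalg.solve(b, x.T).T   # (A^{-1} Y) B^{-T}
np.testing.assert_allclose(x.reshape(-1), direct)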
def _matmul(self, x, adjoint=False, adjoint_arg=False): diag_term = math_ops.conj(self._diag) if adjoint else self._diag x = linalg.adjoint(x) if adjoint_arg else x diag_mat = array_ops.expand_dims(diag_term, -1) return diag_mat * x
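For a diagonal operator, matmul never needs the dense matrix: broadcasting the diagonal against the rows of `x`, as the expand_dims above arranges, gives the same result. A minimal NumPy check:

import numpy as np

d = np.array([1.0, 2.0, 3.0])
x = np.random.randn(3, 4)

# Multiplying by a diagonal matrix is just row scaling.
np.testing.assert_allclose(np.diag(d) @ x, d[:, None] * x)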
def _solve(self, rhs, adjoint=False, adjoint_arg=False): diag_term = math_ops.conj(self._diag) if adjoint else self._diag rhs = linalg.adjoint(rhs) if adjoint_arg else rhs inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1) return rhs * inv_diag_mat
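Likewise, solving against a diagonal operator reduces to elementwise division by the diagonal, which is what the reciprocal-and-broadcast above computes. A minimal NumPy check:

import numpy as np

d = np.array([1.0, 2.0, 4.0])
rhs = np.random.randn(3, 2)

# Solving a diagonal system is elementwise division by the diagonal.
np.testing.assert_allclose(np.linalg.solve(np.diag(d), rhs), rhs / d[:, None])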
def _to_dense(self): if self.is_self_adjoint: return self.operator.to_dense() return linalg.adjoint(self.operator.to_dense())
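The early return relies on the defining property of self-adjointness: conjugate transposition is a no-op. A one-line NumPy check on a Hermitian matrix built by hand (purely illustrative):

import numpy as np

np.random.seed(0)
x = np.random.randn(3, 3) + 1j * np.random.randn(3, 3)
hermitian = x + x.conj().T   # self-adjoint by construction

# For a self-adjoint matrix, the conjugate transpose returns it unchanged,
# which is why the dense form can be returned as-is.
np.testing.assert_allclose(hermitian, hermitian.conj().T)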
def _test_solve_base( self, use_placeholder, shapes_info, dtype, adjoint, adjoint_arg, blockwise_arg, with_batch): # If batch dimensions are omitted, but there are # no batch dimensions for the linear operator, then # skip the test case. This is already checked with # with_batch=True. if not with_batch and len(shapes_info.shape) <= 2: return with self.session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat = self.operator_and_matrix( shapes_info, dtype, use_placeholder=use_placeholder) rhs = self.make_rhs( operator, adjoint=adjoint, with_batch=with_batch) # If adjoint_arg, solve A X = (rhs^H)^H = rhs. if adjoint_arg: op_solve = operator.solve( linalg.adjoint(rhs), adjoint=adjoint, adjoint_arg=adjoint_arg) else: op_solve = operator.solve( rhs, adjoint=adjoint, adjoint_arg=adjoint_arg) mat_solve = linear_operator_util.matrix_solve_with_broadcast( mat, rhs, adjoint=adjoint) if not use_placeholder: self.assertAllEqual(op_solve.shape, mat_solve.shape) # If the operator is blockwise, test both blockwise rhs and `Tensor` rhs; # else test only `Tensor` rhs. In both cases, evaluate all results in a # single `sess.run` call to avoid re-sampling the random rhs in graph mode. if blockwise_arg and len(operator.operators) > 1: # pylint: disable=protected-access block_dimensions = ( operator._block_range_dimensions() if adjoint else operator._block_domain_dimensions()) block_dimensions_fn = ( operator._block_range_dimension_tensors if adjoint else operator._block_domain_dimension_tensors) # pylint: enable=protected-access split_rhs = linear_operator_util.split_arg_into_blocks( block_dimensions, block_dimensions_fn, rhs, axis=-2) if adjoint_arg: split_rhs = [linalg.adjoint(y) for y in split_rhs] split_solve = operator.solve( split_rhs, adjoint=adjoint, adjoint_arg=adjoint_arg) self.assertEqual(len(split_solve), len(operator.operators)) split_solve = linear_operator_util.broadcast_matrix_batch_dims( split_solve) fused_block_solve = array_ops.concat(split_solve, axis=-2) op_solve_v, mat_solve_v, fused_block_solve_v = sess.run([ op_solve, mat_solve, fused_block_solve]) # Check that the operator and matrix give the same solution when the rhs # is blockwise. self.assertAC(mat_solve_v, fused_block_solve_v) else: op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve]) # Check that the operator and matrix give the same solution when the rhs is # a `Tensor`. self.assertAC(op_solve_v, mat_solve_v)
def _solve(self, rhs, adjoint=False, adjoint_arg=False): rhs = linalg.adjoint(rhs) if adjoint_arg else rhs return linear_operator_util.matrix_triangular_solve_with_broadcast( array_ops.matrix_set_diag(self._tril, math_ops.exp(self._diag)), rhs, lower=True, adjoint=adjoint)
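Exponentiating the stored diagonal before the triangular solve guarantees strictly positive diagonal entries, hence a nonsingular factor regardless of the unconstrained parameters. A NumPy sketch of the same construction (np.linalg.solve stands in for the broadcasting triangular solve; names are illustrative):

import numpy as np

np.random.seed(0)
t = np.random.randn(3, 3)
diag = np.random.randn(3)

# Keep the strictly lower part of t, replace the diagonal with exp(diag):
# the result is lower triangular with a strictly positive diagonal.
tril = np.tril(t, k=-1) + np.diag(np.exp(diag))
rhs = np.random.randn(3, 2)
x = np.linalg.solve(tril, rhs)
np.testing.assert_allclose(tril @ x, rhs)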
def _solve(self, rhs, adjoint=False, adjoint_arg=False): rhs = linalg.adjoint(rhs) if adjoint_arg else rhs return linalg_ops.matrix_triangular_solve( self._tril, rhs, lower=True, adjoint=adjoint)
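The `adjoint` flag asks the triangular-solve kernel to solve against `L^H` without materializing the transpose. A SciPy sketch of the equivalent call, assuming SciPy is available (for this real-valued example `L^H` is simply `L^T`):

import numpy as np
from scipy.linalg import solve_triangular

np.random.seed(0)
tril = np.tril(np.random.randn(3, 3)) + 3.0 * np.eye(3)
rhs = np.random.randn(3, 2)

# trans='T' solves L^T x = rhs while still exploiting the lower-triangular
# structure of L, analogous to passing adjoint=True above.
x = solve_triangular(tril, rhs, lower=True, trans='T')
np.testing.assert_allclose(tril.T @ x, rhs)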