def test_less_than_two_dims_raises_static(self):
  x = rng.rand(3)
  y = rng.rand(1, 1)

  with self.assertRaisesRegexp(ValueError, "at least two dimensions"):
    linear_operator_util.broadcast_matrix_batch_dims([x, y])

  with self.assertRaisesRegexp(ValueError, "at least two dimensions"):
    linear_operator_util.broadcast_matrix_batch_dims([y, x])
def _broadcast_batch_dims(self, x, spectrum):
  """Broadcast batch dims of batch matrix `x` and spectrum."""
  spectrum = ops.convert_to_tensor_v2_with_dispatch(spectrum, name="spectrum")
  # spectrum.shape = batch_shape + block_shape
  # First make spectrum a batch matrix with
  #   spectrum.shape = batch_shape + [prod(block_shape), 1]
  batch_shape = self._batch_shape_tensor(
      shape=self._shape_tensor(spectrum=spectrum))
  spec_mat = array_ops.reshape(
      spectrum, array_ops.concat((batch_shape, [-1, 1]), axis=0))
  # Second, broadcast, possibly requiring the addition of an array of zeros.
  x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims(
      (x, spec_mat))
  # Third, put the block shape back into spectrum.
  x_batch_shape = array_ops.shape(x)[:-2]
  spectrum_shape = array_ops.shape(spectrum)
  spectrum = array_ops.reshape(
      spec_mat,
      array_ops.concat(
          (x_batch_shape,
           self._block_shape_tensor(spectrum_shape=spectrum_shape)),
          axis=0))
  return x, spectrum
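# The following standalone NumPy sketch (an illustration with assumed shapes,
# not the TF implementation) walks through the same reshape / broadcast /
# reshape round-trip as `_broadcast_batch_dims` above. TF materializes the
# broadcast by adding zeros; NumPy can use `broadcast_to` since the values
# are unchanged.
import numpy as np

batch_shape, block_shape = (3, 1), (2, 2)
spectrum = np.random.rand(*(batch_shape + block_shape))
x = np.random.rand(3, 4, 4, 5)  # batch shape (3, 4) + matrix shape [4, 5]

# First: flatten the block shape into a [prod(block_shape), 1] matrix.
spec_mat = spectrum.reshape(batch_shape + (-1, 1))              # (3, 1, 4, 1)

# Second: broadcast the batch (leading) dims of `x` and `spec_mat`.
bcast = np.broadcast_shapes(x.shape[:-2], spec_mat.shape[:-2])  # (3, 4)
x = np.broadcast_to(x, bcast + x.shape[-2:])
spec_mat = np.broadcast_to(spec_mat, bcast + spec_mat.shape[-2:])

# Third: restore the block shape.
spectrum = spec_mat.reshape(bcast + block_shape)                # (3, 4, 2, 2)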
def test_one_batch_matrix_returned_after_tensor_conversion(self):
  arr = rng.rand(2, 3, 4)
  tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
  self.assertTrue(isinstance(tensor, ops.Tensor))

  with self.cached_session():
    self.assertAllClose(arr, self.evaluate(tensor))
def _to_dense(self):
  num_cols = 0
  rows = []
  broadcasted_blocks = [operator.to_dense() for operator in self.operators]
  broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
      broadcasted_blocks)
  for block in broadcasted_blocks:
    batch_row_shape = array_ops.shape(block)[:-1]

    zeros_to_pad_before_shape = array_ops.concat(
        [batch_row_shape, [num_cols]], axis=-1)
    zeros_to_pad_before = array_ops.zeros(
        shape=zeros_to_pad_before_shape, dtype=block.dtype)
    num_cols += array_ops.shape(block)[-1]
    zeros_to_pad_after_shape = array_ops.concat(
        [batch_row_shape,
         [self.domain_dimension_tensor() - num_cols]], axis=-1)
    zeros_to_pad_after = array_ops.zeros(
        shape=zeros_to_pad_after_shape, dtype=block.dtype)

    rows.append(array_ops.concat(
        [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

  mat = array_ops.concat(rows, axis=-2)
  mat.set_shape(self.shape)
  return mat
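# A minimal NumPy sketch (illustrative block shapes assumed) of the
# zero-padding scheme in `_to_dense` above: each block row is padded on the
# left with the columns consumed so far and on the right with the columns
# remaining, then all rows are stacked to form the dense block-diagonal
# matrix.
import numpy as np

blocks = [np.ones((2, 2)), 2. * np.ones((3, 3))]
domain_dim = sum(b.shape[-1] for b in blocks)

rows, num_cols = [], 0
for block in blocks:
  before = np.zeros(block.shape[:-1] + (num_cols,))
  num_cols += block.shape[-1]
  after = np.zeros(block.shape[:-1] + (domain_dim - num_cols,))
  rows.append(np.concatenate([before, block, after], axis=-1))

mat = np.concatenate(rows, axis=-2)  # shape (5, 5), blocks on the diagonal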
def test_one_batch_matrix_returned_after_tensor_conversion(self):
  arr = rng.rand(2, 3, 4)
  tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
  self.assertTrue(isinstance(tensor, ops.Tensor))

  with self.cached_session():
    self.assertAllClose(arr, tensor.eval())
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  arg_dim = -1 if adjoint_arg else -2
  block_dimensions = (self._block_range_dimensions() if adjoint
                      else self._block_domain_dimensions())
  blockwise_arg = linear_operator_util.arg_is_blockwise(
      block_dimensions, x, arg_dim)

  if blockwise_arg:
    split_x = x
  else:
    split_dim = -1 if adjoint_arg else -2
    # Split input by rows normally, by columns if adjoint_arg is True.
    split_x = linear_operator_util.split_arg_into_blocks(
        self._block_domain_dimensions(),
        self._block_domain_dimension_tensors,
        x, axis=split_dim)

  result_list = []
  for index, operator in enumerate(self.operators):
    result_list += [operator.matmul(
        split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

  if blockwise_arg:
    return result_list

  result_list = linear_operator_util.broadcast_matrix_batch_dims(
      result_list)
  return array_ops.concat(result_list, axis=-2)
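# Usage sketch against the public TF 2.x API (assuming a TF version with
# blockwise-argument support; the snippet above is the internals): a
# blockwise argument (a list of per-block Tensors) returns a list of
# per-block results, while a single Tensor returns the concatenated result,
# matching the two return paths of `_matmul` above.
import tensorflow as tf

op = tf.linalg.LinearOperatorBlockDiag([
    tf.linalg.LinearOperatorFullMatrix(2. * tf.eye(2)),
    tf.linalg.LinearOperatorFullMatrix(3. * tf.eye(3)),
])
y = op.matmul(tf.ones([5, 1]))                            # Tensor, [5, 1]
y_blocks = op.matmul([tf.ones([2, 1]), tf.ones([3, 1])])  # list of Tensors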
def _operator_and_matrix(self, build_info, dtype, use_placeholder):
  shape = list(build_info.shape)
  expected_factors = build_info.__dict__["factors"]
  matrices = [
      linear_operator_test_util.random_positive_definite_matrix(
          block_shape, dtype, force_well_conditioned=True)
      for block_shape in expected_factors
  ]

  lin_op_matrices = matrices

  if use_placeholder:
    lin_op_matrices = [
        array_ops.placeholder_with_default(m, shape=None) for m in matrices]

  operator = kronecker.LinearOperatorKronecker(
      [linalg.LinearOperatorFullMatrix(
          l, is_square=True) for l in lin_op_matrices])

  matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

  kronecker_dense = _kronecker_dense(matrices)

  if not use_placeholder:
    kronecker_dense.set_shape(shape)

  return operator, kronecker_dense
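# `_kronecker_dense` is a test helper defined elsewhere in the test file; for
# unbatched factors a plausible NumPy equivalent is simply a reduced
# Kronecker product (hypothetical helper, shown for illustration only).
import numpy as np
from functools import reduce

def kronecker_dense(matrices):
  # Dense Kronecker product of a list of unbatched matrices.
  return reduce(np.kron, matrices)

kronecker_dense([np.eye(2), np.arange(4.).reshape(2, 2)])  # shape (4, 4)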
def _to_dense(self):
  num_cols = 0
  rows = []
  broadcasted_blocks = [
      operator.to_dense() for operator in self.operators
  ]
  broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
      broadcasted_blocks)
  for block in broadcasted_blocks:
    batch_row_shape = array_ops.shape(block)[:-1]

    zeros_to_pad_before_shape = array_ops.concat(
        [batch_row_shape, [num_cols]], axis=-1)
    zeros_to_pad_before = array_ops.zeros(
        shape=zeros_to_pad_before_shape, dtype=block.dtype)
    num_cols += array_ops.shape(block)[-1]
    zeros_to_pad_after_shape = array_ops.concat(
        [batch_row_shape,
         [self.domain_dimension_tensor() - num_cols]], axis=-1)
    zeros_to_pad_after = array_ops.zeros(
        shape=zeros_to_pad_after_shape, dtype=block.dtype)

    rows.append(
        array_ops.concat(
            [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

  mat = array_ops.concat(rows, axis=-2)
  mat.set_shape(self.shape)
  return mat
def _test_matmul_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    blockwise_arg,
    with_batch):
  # If batch dimensions are omitted, but the operator has no batch dimensions
  # anyway, skip the test case: it is already covered by with_batch=True.
  if not with_batch and len(shapes_info.shape) <= 2:
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    x = self.make_x(operator, adjoint=adjoint, with_batch=with_batch)
    # If adjoint_arg, compute A X^H^H = A X.
    if adjoint_arg:
      op_matmul = operator.matmul(
          linalg.adjoint(x), adjoint=adjoint, adjoint_arg=adjoint_arg)
    else:
      op_matmul = operator.matmul(x, adjoint=adjoint)
    mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
    if not use_placeholder:
      self.assertAllEqual(op_matmul.shape, mat_matmul.shape)

    # If the operator is blockwise, test both blockwise `x` and `Tensor` `x`;
    # else test only `Tensor` `x`. In both cases, evaluate all results in a
    # single `sess.run` call to avoid re-sampling the random `x` in graph
    # mode.
    if blockwise_arg and len(operator.operators) > 1:
      # pylint: disable=protected-access
      block_dimensions = (
          operator._block_range_dimensions() if adjoint
          else operator._block_domain_dimensions())
      block_dimensions_fn = (
          operator._block_range_dimension_tensors if adjoint
          else operator._block_domain_dimension_tensors)
      # pylint: enable=protected-access
      split_x = linear_operator_util.split_arg_into_blocks(
          block_dimensions, block_dimensions_fn, x, axis=-2)
      if adjoint_arg:
        split_x = [linalg.adjoint(y) for y in split_x]
      split_matmul = operator.matmul(
          split_x, adjoint=adjoint, adjoint_arg=adjoint_arg)

      self.assertEqual(len(split_matmul), len(operator.operators))
      split_matmul = linear_operator_util.broadcast_matrix_batch_dims(
          split_matmul)
      fused_block_matmul = array_ops.concat(split_matmul, axis=-2)
      op_matmul_v, mat_matmul_v, fused_block_matmul_v = sess.run(
          [op_matmul, mat_matmul, fused_block_matmul])

      # Check that the operator applied to blockwise input gives the same
      # result as matrix multiplication.
      self.assertAC(fused_block_matmul_v, mat_matmul_v)
    else:
      op_matmul_v, mat_matmul_v = sess.run([op_matmul, mat_matmul])

    # Check that the operator applied to a `Tensor` gives the same result as
    # matrix multiplication.
    self.assertAC(op_matmul_v, mat_matmul_v)
def _diag_part(self):
  diag_list = []
  for operator in self.operators:
    # Extend the axis for broadcasting.
    diag_list += [operator.diag_part()[..., array_ops.newaxis]]
  diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
  diagonal = array_ops.concat(diag_list, axis=-2)
  return array_ops.squeeze(diagonal, axis=-1)
def _eigvals(self):
  eig_list = []
  for op in self._diagonal_operators:
    # Extend the axis for broadcasting.
    eig_list.append(op.eigvals()[..., array_ops.newaxis])
  eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
  eigs = array_ops.concat(eig_list, axis=-2)
  return array_ops.squeeze(eigs, axis=-1)
def _test_solve_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    blockwise_arg,
    with_batch):
  # If batch dimensions are omitted, but the operator has no batch dimensions
  # anyway, skip the test case: it is already covered by with_batch=True.
  if not with_batch and len(shapes_info.shape) <= 2:
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    rhs = self.make_rhs(operator, adjoint=adjoint, with_batch=with_batch)
    # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
    if adjoint_arg:
      op_solve = operator.solve(
          linalg.adjoint(rhs), adjoint=adjoint, adjoint_arg=adjoint_arg)
    else:
      op_solve = operator.solve(
          rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    mat_solve = linear_operator_util.matrix_solve_with_broadcast(
        mat, rhs, adjoint=adjoint)
    if not use_placeholder:
      self.assertAllEqual(op_solve.shape, mat_solve.shape)

    # If the operator is blockwise, test both blockwise rhs and `Tensor` rhs;
    # else test only `Tensor` rhs. In both cases, evaluate all results in a
    # single `sess.run` call to avoid re-sampling the random rhs in graph
    # mode.
    if blockwise_arg and len(operator.operators) > 1:
      split_rhs = linear_operator_util.split_arg_into_blocks(
          operator._block_domain_dimensions(),  # pylint: disable=protected-access
          operator._block_domain_dimension_tensors,  # pylint: disable=protected-access
          rhs, axis=-2)
      if adjoint_arg:
        split_rhs = [linalg.adjoint(y) for y in split_rhs]
      split_solve = operator.solve(
          split_rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
      self.assertEqual(len(split_solve), len(operator.operators))
      split_solve = linear_operator_util.broadcast_matrix_batch_dims(
          split_solve)
      fused_block_solve = array_ops.concat(split_solve, axis=-2)
      op_solve_v, mat_solve_v, fused_block_solve_v = sess.run(
          [op_solve, mat_solve, fused_block_solve])

      # Check that the operator and matrix give the same solution when the
      # rhs is blockwise.
      self.assertAC(mat_solve_v, fused_block_solve_v)
    else:
      op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])

    # Check that the operator and matrix give the same solution when the rhs
    # is a `Tensor`.
    self.assertAC(op_solve_v, mat_solve_v)
def _diag_part(self):
  diag_list = []
  for op in self._diagonal_operators:
    # Extend the axis, since `broadcast_matrix_batch_dims` treats all but the
    # final two dimensions as batch dimensions.
    diag_list.append(op.diag_part()[..., array_ops.newaxis])
  diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
  diagonal = array_ops.concat(diag_list, axis=-2)
  return array_ops.squeeze(diagonal, axis=-1)
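# Why the trailing `newaxis` above: `broadcast_matrix_batch_dims` broadcasts
# only the dimensions *before* the last two, so a length-n diagonal must
# first become an [n, 1] column matrix before broadcasting, and the extra
# axis is squeezed away at the end. A NumPy illustration with assumed shapes:
import numpy as np

diag_a = np.random.rand(4, 2)  # batch_shape (4,), block size 2
diag_b = np.random.rand(3)     # no batch dims, block size 3
col_a = diag_a[..., np.newaxis]                      # (4, 2, 1)
col_b = np.broadcast_to(diag_b[..., np.newaxis],
                        (4,) + diag_b.shape + (1,))  # (4, 3, 1)
diagonal = np.squeeze(np.concatenate([col_a, col_b], axis=-2), axis=-1)
# diagonal.shape == (4, 5)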
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  split_dim = -1 if adjoint_arg else -2
  # Split input by rows normally, by columns if adjoint_arg is True.
  split_x = self._split_input_into_blocks(x, axis=split_dim)

  result_list = []
  for index, operator in enumerate(self.operators):
    result_list += [operator.matmul(
        split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
  result_list = linear_operator_util.broadcast_matrix_batch_dims(
      result_list)
  return array_ops.concat(result_list, axis=-2)
def _eigvals(self):
  if not all(operator.is_square for operator in self.operators):
    raise NotImplementedError(
        "`eigvals` not implemented for an operator whose blocks are not "
        "square.")
  eig_list = []
  for operator in self.operators:
    # Extend the axis for broadcasting.
    eig_list += [operator.eigvals()[..., array_ops.newaxis]]
  eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
  eigs = array_ops.concat(eig_list, axis=-2)
  return array_ops.squeeze(eigs, axis=-1)
def _diag_part(self):
  if not all(operator.is_square for operator in self.operators):
    raise NotImplementedError(
        "`diag_part` not implemented for an operator whose blocks are not "
        "square.")
  diag_list = []
  for operator in self.operators:
    # Extend the axis for broadcasting.
    diag_list += [operator.diag_part()[..., array_ops.newaxis]]
  diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
  diagonal = array_ops.concat(diag_list, axis=-2)
  return array_ops.squeeze(diagonal, axis=-1)
def operator_and_matrix(
    self, shape_info, dtype, use_placeholder,
    ensure_self_adjoint_and_pd=False):
  expected_blocks = (
      shape_info.__dict__["blocks"] if "blocks" in shape_info.__dict__
      else [[list(shape_info.shape)]])

  matrices = []
  for i, row_shapes in enumerate(expected_blocks):
    row = []
    for j, block_shape in enumerate(row_shapes):
      if i == j:  # operator is on the diagonal
        row.append(
            linear_operator_test_util.random_positive_definite_matrix(
                block_shape, dtype, force_well_conditioned=True))
      else:
        row.append(
            linear_operator_test_util.random_normal(block_shape, dtype=dtype))
    matrices.append(row)

  lin_op_matrices = matrices

  if use_placeholder:
    lin_op_matrices = [[
        array_ops.placeholder_with_default(matrix, shape=None)
        for matrix in row] for row in matrices]

  operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
      [[linalg.LinearOperatorFullMatrix(  # pylint:disable=g-complex-comprehension
          l,
          is_square=True,
          is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
          is_positive_definite=True if ensure_self_adjoint_and_pd else None)
        for l in row] for row in lin_op_matrices])

  # Should be auto-set.
  self.assertTrue(operator.is_square)

  # Broadcast the shapes.
  expected_shape = list(shape_info.shape)
  broadcasted_matrices = linear_operator_util.broadcast_matrix_batch_dims(
      [op for row in matrices for op in row])  # pylint: disable=g-complex-comprehension
  matrices = [broadcasted_matrices[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
              for i in range(len(matrices))]

  block_lower_triangular_dense = _block_lower_triangular_dense(
      expected_shape, matrices)

  if not use_placeholder:
    block_lower_triangular_dense.set_shape(expected_shape)

  return operator, block_lower_triangular_dense
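# The triangular-number slicing above recovers the rows of the
# lower-triangular block structure from the flattened list: row i holds
# i + 1 blocks and starts at flat index i * (i + 1) // 2. A quick check:
flat = ["A00", "A10", "A11", "A20", "A21", "A22"]
rows = [flat[i * (i + 1) // 2:(i + 1) * (i + 2) // 2] for i in range(3)]
assert rows == [["A00"], ["A10", "A11"], ["A20", "A21", "A22"]]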
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  split_dim = -1 if adjoint_arg else -2
  # Split input by rows normally, by columns if adjoint_arg is True.
  split_rhs = self._split_input_into_blocks(rhs, axis=split_dim)

  solution_list = []
  for index, operator in enumerate(self.operators):
    solution_list += [operator.solve(
        split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
  solution_list = linear_operator_util.broadcast_matrix_batch_dims(
      solution_list)
  return array_ops.concat(solution_list, axis=-2)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  split_dim = -1 if adjoint_arg else -2
  # Split input by columns if adjoint_arg is True, else rows.
  split_x = self._split_input_into_blocks(x, axis=split_dim)

  result_list = []
  # Iterate over row-partitions (i.e. column-partitions of the adjoint).
  if adjoint:
    for index in range(len(self.operators)):
      # Begin with the operator on the diagonal and apply it to the
      # respective `rhs` block.
      result = self.operators[index][index].matmul(
          split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)

      # Iterate top to bottom over the operators in the remainder of the
      # column-partition (i.e. left to right over the row-partition of the
      # adjoint), apply the operator to the respective `rhs` block and
      # accumulate the sum. For example, given the
      # `LinearOperatorBlockLowerTriangular`:
      #
      # op = [[A, 0, 0],
      #       [B, C, 0],
      #       [D, E, F]]
      #
      # if `index = 1`, the following loop calculates:
      # `y_1 = C.matmul(x_1, adjoint=adjoint) +
      #        E.matmul(x_2, adjoint=adjoint)`,
      # where `x_1` and `x_2` are splits of `x`.
      for j in range(index + 1, len(self.operators)):
        result += self.operators[j][index].matmul(
            split_x[j], adjoint=adjoint, adjoint_arg=adjoint_arg)
      result_list.append(result)
  else:
    for row in self.operators:
      # Begin with the left-most operator in the row-partition and apply it
      # to the first `rhs` block.
      result = row[0].matmul(
          split_x[0], adjoint=adjoint, adjoint_arg=adjoint_arg)

      # Iterate left to right over the operators in the remainder of the row
      # partition, apply the operator to the respective `rhs` block, and
      # accumulate the sum.
      for j, operator in enumerate(row[1:]):
        result += operator.matmul(
            split_x[j + 1], adjoint=adjoint, adjoint_arg=adjoint_arg)
      result_list.append(result)

  result_list = linear_operator_util.broadcast_matrix_batch_dims(
      result_list)
  return array_ops.concat(result_list, axis=-2)
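# NumPy spot-check (small illustrative matrices) of the non-adjoint branch
# above for the operator [[A], [B, C]]: the second block of the result
# should equal B @ x0 + C @ x1.
import numpy as np

A, B, C = np.eye(2), np.ones((3, 2)), 2. * np.eye(3)
x0, x1 = np.random.rand(2, 1), np.random.rand(3, 1)

dense = np.block([[A, np.zeros((2, 3))],
                  [B, C]])
y = dense @ np.concatenate([x0, x1], axis=-2)
np.testing.assert_allclose(y[-3:], B @ x0 + C @ x1)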
def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
  shape = list(build_info.shape)
  expected_blocks = (
      build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
      else [shape])
  diag_matrices = [
      linear_operator_test_util.random_uniform(
          shape=block_shape[:-1], minval=1., maxval=20., dtype=dtype)
      for block_shape in expected_blocks
  ]

  if use_placeholder:
    diag_matrices_ph = [
        array_ops.placeholder(dtype=dtype) for _ in expected_blocks
    ]
    # Evaluate here because (i) you cannot feed a tensor, and (ii)
    # values are random and we want the same value used for both mat and
    # feed_dict.
    diag_matrices = self.evaluate(diag_matrices)
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorDiag(m_ph) for m_ph in diag_matrices_ph])
    feed_dict = {
        m_ph: m for (m_ph, m) in zip(diag_matrices_ph, diag_matrices)
    }
  else:
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorDiag(m) for m in diag_matrices])
    feed_dict = None

  # Should be auto-set.
  self.assertTrue(operator.is_square)

  # Broadcast the shapes.
  expected_shape = list(build_info.shape)

  matrices = linear_operator_util.broadcast_matrix_batch_dims(
      [array_ops.matrix_diag(diag_block) for diag_block in diag_matrices])

  block_diag_dense = _block_diag_dense(expected_shape, matrices)
  if not use_placeholder:
    block_diag_dense.set_shape(
        expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

  return operator, block_diag_dense, feed_dict
def _operator_and_matrix(
    self, build_info, dtype, use_placeholder,
    ensure_self_adjoint_and_pd=False):
  shape = list(build_info.shape)
  expected_blocks = (
      build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
      else [shape])
  matrices = [
      linear_operator_test_util.random_positive_definite_matrix(
          block_shape, dtype, force_well_conditioned=True)
      for block_shape in expected_blocks
  ]

  lin_op_matrices = matrices

  if use_placeholder:
    lin_op_matrices = [
        array_ops.placeholder_with_default(matrix, shape=None)
        for matrix in matrices
    ]

  operator = block_diag.LinearOperatorBlockDiag([
      linalg.LinearOperatorFullMatrix(
          l,
          is_square=True,
          is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
          is_positive_definite=True if ensure_self_adjoint_and_pd else None)
      for l in lin_op_matrices
  ])

  # Should be auto-set.
  self.assertTrue(operator.is_square)

  # Broadcast the shapes.
  expected_shape = list(build_info.shape)

  matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

  block_diag_dense = _block_diag_dense(expected_shape, matrices)

  if not use_placeholder:
    block_diag_dense.set_shape(
        expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

  return operator, block_diag_dense
def test_static_dims_broadcast_second_arg_higher_rank(self):
  # x.batch_shape = [1, 2]
  # y.batch_shape = [1, 3, 1]
  # broadcast batch shape = [1, 3, 2]
  x = rng.rand(1, 2, 1, 5)
  y = rng.rand(1, 3, 2, 3, 7)
  batch_of_zeros = np.zeros((1, 3, 2, 1, 1))
  x_bc_expected = x + batch_of_zeros
  y_bc_expected = y + batch_of_zeros

  x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

  self.assertAllEqual(x_bc_expected.shape, x_bc.shape)
  self.assertAllEqual(y_bc_expected.shape, y_bc.shape)
  x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
  self.assertAllClose(x_bc_expected, x_bc_)
  self.assertAllClose(y_bc_expected, y_bc_)
def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
  shape = list(build_info.shape)
  expected_blocks = (
      build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
      else [shape])
  matrices = [
      linear_operator_test_util.random_positive_definite_matrix(
          block_shape, dtype, force_well_conditioned=True)
      for block_shape in expected_blocks
  ]

  if use_placeholder:
    matrices_ph = [
        array_ops.placeholder(dtype=dtype) for _ in expected_blocks
    ]
    # Evaluate here because (i) you cannot feed a tensor, and (ii)
    # values are random and we want the same value used for both mat and
    # feed_dict.
    matrices = self.evaluate(matrices)
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(m_ph, is_square=True)
         for m_ph in matrices_ph],
        is_square=True)
    feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
  else:
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(m, is_square=True)
         for m in matrices])
    feed_dict = None

  # Should be auto-set.
  self.assertTrue(operator.is_square)

  # Broadcast the shapes.
  expected_shape = list(build_info.shape)

  matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

  block_diag_dense = _block_diag_dense(expected_shape, matrices)
  if not use_placeholder:
    block_diag_dense.set_shape(
        expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

  return operator, block_diag_dense, feed_dict
def test_static_dims_broadcast(self):
  # x.batch_shape = [3, 1, 2]
  # y.batch_shape = [4, 1]
  # broadcast batch shape = [3, 4, 2]
  x = rng.rand(3, 1, 2, 1, 5)
  y = rng.rand(4, 1, 3, 7)
  batch_of_zeros = np.zeros((3, 4, 2, 1, 1))
  x_bc_expected = x + batch_of_zeros
  y_bc_expected = y + batch_of_zeros

  x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

  with self.cached_session():
    self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
    self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
    self.assertAllClose(x_bc_expected, x_bc_)
    self.assertAllClose(y_bc_expected, y_bc_)
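# The `x + batch_of_zeros` trick used by these tests materializes the
# expected broadcast: adding zeros with the target batch shape leaves the
# values untouched while expanding the leading dimensions. Equivalently, in
# NumPy:
import numpy as np

x = np.random.rand(3, 1, 2, 1, 5)
x_bc = np.broadcast_to(x, (3, 4, 2, 1, 5))
np.testing.assert_allclose(x_bc, x + np.zeros((3, 4, 2, 1, 1)))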
def _broadcast_batch_dims(self, x, spectrum):
  """Broadcast batch dims of batch matrix `x` and spectrum."""
  # spectrum.shape = batch_shape + block_shape
  # First make spectrum a batch matrix with
  #   spectrum.shape = batch_shape + [prod(block_shape), 1]
  spec_mat = array_ops.reshape(
      spectrum,
      array_ops.concat((self.batch_shape_tensor(), [-1, 1]), axis=0))
  # Second, broadcast, possibly requiring the addition of an array of zeros.
  x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims(
      (x, spec_mat))
  # Third, put the block shape back into spectrum.
  batch_shape = array_ops.shape(x)[:-2]
  spectrum = array_ops.reshape(
      spec_mat,
      array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))
  return x, spectrum
def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
  # x.batch_shape = [1, 2]
  # y.batch_shape = [3, 4, 1]
  # broadcast batch shape = [3, 4, 2]
  x = rng.rand(1, 2, 1, 5).astype(np.float32)
  y = rng.rand(3, 4, 1, 3, 7).astype(np.float32)
  batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
  x_bc_expected = x + batch_of_zeros
  y_bc_expected = y + batch_of_zeros

  x_ph = array_ops.placeholder_with_default(x, shape=None)
  y_ph = array_ops.placeholder_with_default(y, shape=None)

  x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])

  x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
  self.assertAllClose(x_bc_expected, x_bc_)
  self.assertAllClose(y_bc_expected, y_bc_)
def test_static_dims_broadcast_second_arg_higher_rank(self):
  # x.batch_shape = [1, 2]
  # y.batch_shape = [1, 3, 1]
  # broadcast batch shape = [1, 3, 2]
  x = rng.rand(1, 2, 1, 5)
  y = rng.rand(1, 3, 2, 3, 7)
  batch_of_zeros = np.zeros((1, 3, 2, 1, 1))
  x_bc_expected = x + batch_of_zeros
  y_bc_expected = y + batch_of_zeros

  x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

  with self.cached_session() as sess:
    self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
    self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
    x_bc_, y_bc_ = sess.run([x_bc, y_bc])
    self.assertAllClose(x_bc_expected, x_bc_)
    self.assertAllClose(y_bc_expected, y_bc_)
def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
  shape = list(build_info.shape)
  expected_blocks = (
      build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
      else [shape])
  matrices = [
      linear_operator_test_util.random_positive_definite_matrix(
          block_shape, dtype, force_well_conditioned=True)
      for block_shape in expected_blocks
  ]

  if use_placeholder:
    matrices_ph = [
        array_ops.placeholder(dtype=dtype) for _ in expected_blocks
    ]
    # Evaluate here because (i) you cannot feed a tensor, and (ii)
    # values are random and we want the same value used for both mat and
    # feed_dict.
    matrices = self.evaluate(matrices)
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(m_ph, is_square=True)
         for m_ph in matrices_ph],
        is_square=True)
    feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
  else:
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(m, is_square=True)
         for m in matrices])
    feed_dict = None

  # Should be auto-set.
  self.assertTrue(operator.is_square)

  # Broadcast the shapes.
  expected_shape = list(build_info.shape)

  matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

  block_diag_dense = _block_diag_dense(expected_shape, matrices)
  if not use_placeholder:
    block_diag_dense.set_shape(
        expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

  return operator, block_diag_dense, feed_dict
def operator_and_matrix(
    self, shape_info, dtype, use_placeholder,
    ensure_self_adjoint_and_pd=False):
  shape = list(shape_info.shape)
  expected_blocks = (
      shape_info.__dict__["blocks"] if "blocks" in shape_info.__dict__
      else [shape])
  matrices = [
      linear_operator_test_util.random_positive_definite_matrix(
          block_shape, dtype, force_well_conditioned=True)
      for block_shape in expected_blocks
  ]

  lin_op_matrices = matrices

  if use_placeholder:
    lin_op_matrices = [
        array_ops.placeholder_with_default(matrix, shape=None)
        for matrix in matrices]

  operator = block_diag.LinearOperatorBlockDiag(
      [linalg.LinearOperatorFullMatrix(
          l,
          is_square=True,
          is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
          is_positive_definite=True if ensure_self_adjoint_and_pd else None)
       for l in lin_op_matrices])

  # Should be auto-set.
  self.assertTrue(operator.is_square)

  # Broadcast the shapes.
  expected_shape = list(shape_info.shape)

  matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

  block_diag_dense = _block_diag_dense(expected_shape, matrices)

  if not use_placeholder:
    block_diag_dense.set_shape(
        expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

  return operator, block_diag_dense
def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
  # x.batch_shape = [1, 2]
  # y.batch_shape = [3, 4, 1]
  # broadcast batch shape = [3, 4, 2]
  x = rng.rand(1, 2, 1, 5).astype(np.float32)
  y = rng.rand(3, 4, 1, 3, 7).astype(np.float32)
  batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
  x_bc_expected = x + batch_of_zeros
  y_bc_expected = y + batch_of_zeros

  x_ph = array_ops.placeholder(dtypes.float32)
  y_ph = array_ops.placeholder(dtypes.float32)

  x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])

  with self.cached_session() as sess:
    x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
    self.assertAllClose(x_bc_expected, x_bc_)
    self.assertAllClose(y_bc_expected, y_bc_)
def operator_and_matrix(self, shape_info, dtype, use_placeholder,
                        ensure_self_adjoint_and_pd=False):
  del ensure_self_adjoint_and_pd
  shape = list(shape_info.shape)
  expected_blocks = (
      shape_info.__dict__["blocks"] if "blocks" in shape_info.__dict__
      else [shape])
  matrices = [
      linear_operator_test_util.random_normal(block_shape, dtype=dtype)
      for block_shape in expected_blocks
  ]

  lin_op_matrices = matrices

  if use_placeholder:
    lin_op_matrices = [
        array_ops.placeholder_with_default(matrix, shape=None)
        for matrix in matrices
    ]

  blocks = []
  for l in lin_op_matrices:
    blocks.append(
        linalg.LinearOperatorFullMatrix(
            l,
            is_square=False,
            is_self_adjoint=False,
            is_positive_definite=False))
  operator = block_diag.LinearOperatorBlockDiag(blocks)

  # Broadcast the shapes.
  expected_shape = list(shape_info.shape)

  matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

  block_diag_dense = _block_diag_dense(expected_shape, matrices)

  if not use_placeholder:
    block_diag_dense.set_shape(expected_shape)

  return operator, block_diag_dense
def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
  shape = list(build_info.shape)
  expected_blocks = (
      build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
      else [shape])
  diag_matrices = [
      linear_operator_test_util.random_uniform(
          shape=block_shape[:-1], minval=1., maxval=20., dtype=dtype)
      for block_shape in expected_blocks
  ]

  if use_placeholder:
    diag_matrices_ph = [
        array_ops.placeholder(dtype=dtype) for _ in expected_blocks
    ]
    # Evaluate here because (i) you cannot feed a tensor, and (ii)
    # values are random and we want the same value used for both mat and
    # feed_dict.
    diag_matrices = self.evaluate(diag_matrices)
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorDiag(m_ph) for m_ph in diag_matrices_ph])
    feed_dict = {m_ph: m for (m_ph, m) in zip(
        diag_matrices_ph, diag_matrices)}
  else:
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorDiag(m) for m in diag_matrices])
    feed_dict = None

  # Should be auto-set.
  self.assertTrue(operator.is_square)

  # Broadcast the shapes.
  expected_shape = list(build_info.shape)

  matrices = linear_operator_util.broadcast_matrix_batch_dims(
      [array_ops.matrix_diag(diag_block) for diag_block in diag_matrices])

  block_diag_dense = _block_diag_dense(expected_shape, matrices)
  if not use_placeholder:
    block_diag_dense.set_shape(
        expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

  return operator, block_diag_dense, feed_dict
def operator_and_matrix(self, build_info, dtype, use_placeholder,
                        ensure_self_adjoint_and_pd=False):
  # Kronecker products constructed below will be from symmetric
  # positive-definite matrices.
  del ensure_self_adjoint_and_pd
  shape = list(build_info.shape)
  expected_factors = build_info.__dict__["factors"]
  matrices = [
      linear_operator_test_util.random_positive_definite_matrix(
          block_shape, dtype, force_well_conditioned=True)
      for block_shape in expected_factors
  ]

  lin_op_matrices = matrices

  if use_placeholder:
    lin_op_matrices = [
        array_ops.placeholder_with_default(m, shape=None) for m in matrices
    ]

  operator = kronecker.LinearOperatorKronecker([
      linalg.LinearOperatorFullMatrix(
          l, is_square=True, is_self_adjoint=True, is_positive_definite=True)
      for l in lin_op_matrices
  ])

  matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

  kronecker_dense = _kronecker_dense(matrices)

  if not use_placeholder:
    kronecker_dense.set_shape(shape)

  return operator, kronecker_dense
def _to_dense(self):
  num_cols = 0
  dense_rows = []
  flat_broadcast_operators = (
      linear_operator_util.broadcast_matrix_batch_dims(
          [op.to_dense() for row in self.operators for op in row]))  # pylint: disable=g-complex-comprehension
  broadcast_operators = [
      flat_broadcast_operators[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
      for i in range(len(self.operators))]
  for row_blocks in broadcast_operators:
    batch_row_shape = array_ops.shape(row_blocks[0])[:-1]
    num_cols += array_ops.shape(row_blocks[-1])[-1]
    zeros_to_pad_after_shape = array_ops.concat(
        [batch_row_shape,
         [self.domain_dimension_tensor() - num_cols]], axis=-1)
    zeros_to_pad_after = array_ops.zeros(
        shape=zeros_to_pad_after_shape, dtype=self.dtype)

    row_blocks.append(zeros_to_pad_after)
    dense_rows.append(array_ops.concat(row_blocks, axis=-1))

  mat = array_ops.concat(dense_rows, axis=-2)
  mat.set_shape(self.shape)
  return mat
def _operator_and_matrix(
    self, build_info, dtype, use_placeholder,
    ensure_self_adjoint_and_pd=False):
  # Kronecker products constructed below will be from symmetric
  # positive-definite matrices.
  del ensure_self_adjoint_and_pd
  shape = list(build_info.shape)
  expected_factors = build_info.__dict__["factors"]
  matrices = [
      linear_operator_test_util.random_positive_definite_matrix(
          block_shape, dtype, force_well_conditioned=True)
      for block_shape in expected_factors
  ]

  lin_op_matrices = matrices

  if use_placeholder:
    lin_op_matrices = [
        array_ops.placeholder_with_default(m, shape=None) for m in matrices]

  operator = kronecker.LinearOperatorKronecker(
      [linalg.LinearOperatorFullMatrix(
          l,
          is_square=True,
          is_self_adjoint=True,
          is_positive_definite=True)
       for l in lin_op_matrices])

  matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

  kronecker_dense = _kronecker_dense(matrices)

  if not use_placeholder:
    kronecker_dense.set_shape(shape)

  return operator, kronecker_dense
def test_zero_batch_matrices_returned_as_empty_list(self):
  self.assertAllEqual(
      [], linear_operator_util.broadcast_matrix_batch_dims([]))
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for
  details.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ... # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape,
      or a list of `Tensor`s (for blockwise operators). `Tensor`s are
      treated like [batch] matrices, meaning for every set of leading
      dimensions, the last two dimensions define a matrix. See class
      docstring for definition of compatibility.
    adjoint: Python `bool`. If `True`, solve the system involving the
      adjoint of this `LinearOperator`: `A^H X = rhs`.
    adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where
      `rhs^H` is the hermitian transpose (transposition and complex
      conjugation).
    name: A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError: If `self.is_non_singular` or `is_square` is False.
  """
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")
  if isinstance(rhs, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))
    with self._name_scope(name):
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    arg_dim = -1 if adjoint_arg else -2
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, rhs, arg_dim)
    if blockwise_arg:
      split_rhs = rhs
      for i, block in enumerate(split_rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor_v2_with_dispatch(block)
          self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
          split_rhs[i] = block
    else:
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
      split_dim = -1 if adjoint_arg else -2
      # Split input by rows normally, by columns if adjoint_arg is True.
      split_rhs = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          rhs, axis=split_dim)

    solution_list = []
    for index, operator in enumerate(self.operators):
      solution_list += [
          operator.solve(split_rhs[index], adjoint=adjoint,
                         adjoint_arg=adjoint_arg)]

    if blockwise_arg:
      return solution_list

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
    return array_ops.concat(solution_list, axis=-2)
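# Usage sketch against the public TF 2.x API (assuming a version with
# blockwise support): like `matmul`, `solve` accepts either a single Tensor
# rhs or a list of per-block rhs Tensors, returning a list in the latter
# case.
import tensorflow as tf

op = tf.linalg.LinearOperatorBlockDiag([
    tf.linalg.LinearOperatorFullMatrix(2. * tf.eye(2)),
    tf.linalg.LinearOperatorFullMatrix(4. * tf.eye(3)),
])
x = op.solve(tf.ones([5, 1]))                            # Tensor, [5, 1]
x_blocks = op.solve([tf.ones([2, 1]), tf.ones([3, 1])])  # list of Tensors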
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
  """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

  The returned `Tensor` will be close to an exact solution if `A` is well
  conditioned. Otherwise closeness will vary. See class docstring for
  details.

  Given the blockwise `n + 1`-by-`n + 1` linear operator:

  op = [[A_00     0  ...     0  ...    0],
        [A_10  A_11  ...     0  ...    0],
        ...
        [A_k0  A_k1  ...  A_kk  ...    0],
        ...
        [A_n0  A_n1  ...  A_nk  ...  A_nn]]

  we find `x = op.solve(y)` by observing that

  `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

  and therefore

  `x_k = A_kk.solve(y_k -
                    A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

  where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
  and `y` along their appropriate axes.

  We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
  for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

  The adjoint case is solved similarly, beginning with
  `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

  Examples:

  ```python
  # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  # Solve R > 0 linear systems for every member of the batch.
  RHS = ... # shape [..., M, R]

  X = operator.solve(RHS)
  # X[..., :, r] is the solution to the r'th linear system
  # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

  operator.matmul(X)
  ==> RHS
  ```

  Args:
    rhs: `Tensor` with same `dtype` as this operator and compatible shape,
      or a list of `Tensor`s. `Tensor`s are treated like [batch] matrices,
      meaning for every set of leading dimensions, the last two dimensions
      define a matrix. See class docstring for definition of compatibility.
    adjoint: Python `bool`. If `True`, solve the system involving the
      adjoint of this `LinearOperator`: `A^H X = rhs`.
    adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where
      `rhs^H` is the hermitian transpose (transposition and complex
      conjugation).
    name: A name scope to use for ops added by this method.

  Returns:
    `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`.

  Raises:
    NotImplementedError: If `self.is_non_singular` or `is_square` is False.
  """
  if self.is_non_singular is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "be singular.")
  if self.is_square is False:
    raise NotImplementedError(
        "Exact solve not implemented for an operator that is expected to "
        "not be square.")
  if isinstance(rhs, linear_operator.LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = rhs.adjoint() if adjoint_arg else rhs

    if (right_operator.range_dimension is not None and
        left_operator.domain_dimension is not None and
        right_operator.range_dimension != left_operator.domain_dimension):
      raise ValueError(
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.solve(left_operator, right_operator)

  with self._name_scope(name):  # pylint: disable=not-callable
    block_dimensions = (self._block_domain_dimensions() if adjoint
                        else self._block_range_dimensions())
    arg_dim = -1 if adjoint_arg else -2
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, rhs, arg_dim)

    if blockwise_arg:
      for i, block in enumerate(rhs):
        if not isinstance(block, linear_operator.LinearOperator):
          block = ops.convert_to_tensor_v2_with_dispatch(block)
          self._check_input_dtype(block)
          block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
          rhs[i] = block
      if adjoint_arg:
        split_rhs = [linalg.adjoint(y) for y in rhs]
      else:
        split_rhs = rhs
    else:
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])

      rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
      split_rhs = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          rhs, axis=-2)

    solution_list = []
    if adjoint:
      # For an adjoint blockwise lower-triangular linear operator, the
      # system must be solved bottom to top. Iterate backwards over rows of
      # the adjoint (i.e. columns of the non-adjoint operator).
      for index in reversed(range(len(self.operators))):
        y = split_rhs[index]
        # Iterate top to bottom over the operators in the off-diagonal
        # portion of the column-partition (i.e. row-partition of the
        # adjoint), apply the operator to the respective block of the
        # solution found in previous iterations, and subtract the result
        # from the `rhs` block. For example, let `A`, `B`, and `D` be the
        # linear operators in the top row-partition of the adjoint of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
        # and `x_1` and `x_2` be blocks of the solution found in previous
        # iterations of the outer loop. The following loop
        # (when `index == 0`) expresses
        # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
        # `y_0* = y_0 - Bx_1 - Dx_2`.
        for j in reversed(range(index + 1, len(self.operators))):
          y = y - self.operators[j][index].matmul(
              solution_list[len(self.operators) - 1 - j],
              adjoint=adjoint)
        # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
        solution_list.append(
            self._diagonal_operators[index].solve(y, adjoint=adjoint))
      solution_list.reverse()
    else:
      # Iterate top to bottom over the row-partitions.
      for row, y in zip(self.operators, split_rhs):
        # Iterate left to right over the operators in the off-diagonal
        # portion of the row-partition, apply the operator to the block of
        # the solution found in previous iterations, and subtract the
        # result from the `rhs` block. For example, let `D`, `E`, and `F`
        # be the linear operators in the bottom row-partition of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`
        # and `x_0` and `x_1` be blocks of the solution found in previous
        # iterations of the outer loop. The following loop
        # (when `index == 2`) expresses
        # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
        # `y_2* = y_2 - Dx_0 - Ex_1`.
        for i, operator in enumerate(row[:-1]):
          y = y - operator.matmul(solution_list[i], adjoint=adjoint)
        # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
        solution_list.append(row[-1].solve(y, adjoint=adjoint))

    if blockwise_arg:
      return solution_list

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
    return array_ops.concat(solution_list, axis=-2)
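# NumPy sketch of the forward block substitution implemented above (the
# non-adjoint branch), for the 2 x 2 blockwise lower-triangular system
# [[A, 0], [B, C]] x = y; matrices are illustrative assumptions.
import numpy as np

A, B, C = 2. * np.eye(2), np.ones((3, 2)), 4. * np.eye(3)
y0, y1 = np.random.rand(2, 1), np.random.rand(3, 1)

x0 = np.linalg.solve(A, y0)           # x_0 = A^{-1} y_0
x1 = np.linalg.solve(C, y1 - B @ x0)  # x_1 = C^{-1} (y_1 - B x_0)

dense = np.block([[A, np.zeros((2, 3))],
                  [B, C]])
np.testing.assert_allclose(
    dense @ np.concatenate([x0, x1], axis=-2),
    np.concatenate([y0, y1], axis=-2))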