def _matmul(self, x, adjoint=False, adjoint_arg=False):
  if self._assert_proper_shapes:
    x = linalg.adjoint(x) if adjoint_arg else x
    aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
    x = control_flow_ops.with_dependencies([aps], x)
  if self.is_square:
    # Note that adjoint has no effect since this matrix is self-adjoint.
    if adjoint_arg:
      output_shape = array_ops.concat([
          array_ops.shape(x)[:-2],
          [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]], axis=0)
    else:
      output_shape = array_ops.shape(x)
    return self._possibly_broadcast_batch_shape(
        array_ops.zeros(shape=output_shape, dtype=x.dtype))

  x_shape = array_ops.shape(x)
  n = self._num_columns if adjoint else self._num_rows
  m = x_shape[-2] if adjoint_arg else x_shape[-1]
  output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)
  zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
  return self._possibly_broadcast_batch_shape(zeros)
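# Hedged usage sketch (assumes TF2's public API; not part of this module): the
# public `tf.linalg.LinearOperatorZeros` routes through a `_matmul` like the
# one above, so matmul with any compatible `x` returns zeros whose shape
# broadcasts the operator and argument batch shapes.
import tensorflow as tf

zeros_op = tf.linalg.LinearOperatorZeros(num_rows=2)
x = tf.ones([2, 3])
assert zeros_op.matmul(x).shape == [2, 3]  # All entries are 0.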
def _to_dense(self):
  num_cols = 0
  rows = []
  broadcasted_blocks = [operator.to_dense() for operator in self.operators]
  broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
      broadcasted_blocks)
  for block in broadcasted_blocks:
    batch_row_shape = array_ops.shape(block)[:-1]

    zeros_to_pad_before_shape = array_ops.concat(
        [batch_row_shape, [num_cols]], axis=-1)
    zeros_to_pad_before = array_ops.zeros(
        shape=zeros_to_pad_before_shape, dtype=block.dtype)
    num_cols += array_ops.shape(block)[-1]
    zeros_to_pad_after_shape = array_ops.concat(
        [batch_row_shape,
         [self.domain_dimension_tensor() - num_cols]], axis=-1)
    zeros_to_pad_after = array_ops.zeros(
        shape=zeros_to_pad_after_shape, dtype=block.dtype)

    rows.append(array_ops.concat(
        [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

  mat = array_ops.concat(rows, axis=-2)
  mat.set_shape(_ops.TensorShape(self.shape))
  return mat
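# Hedged verification sketch (assumes TF2 and SciPy; illustration only): the
# dense form of a block-diagonal operator matches scipy.linalg.block_diag
# applied to the dense blocks.
import numpy as np
import scipy.linalg
import tensorflow as tf

blocks = [np.eye(2, dtype=np.float32), 2. * np.eye(3, dtype=np.float32)]
op = tf.linalg.LinearOperatorBlockDiag(
    [tf.linalg.LinearOperatorFullMatrix(b) for b in blocks])
np.testing.assert_allclose(op.to_dense().numpy(),
                           scipy.linalg.block_diag(*blocks))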
def _rotate_last_dim(x, rotate_right=False):
  """Rotate the last dimension either left or right."""
  ndims = array_ops.rank(x)
  if rotate_right:
    transpose_perm = array_ops.concat(
        [[ndims - 1], math_ops.range(0, ndims - 1)], axis=0)
  else:
    transpose_perm = array_ops.concat(
        [math_ops.range(1, ndims), [0]], axis=0)
  return array_ops.transpose(x, transpose_perm)
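# Hedged NumPy sketch of the same permutation (an illustration, not the TF
# implementation): rotating right moves the last axis to the front; rotating
# left moves the first axis to the back.
import numpy as np

x = np.zeros([2, 3, 4])
rotated_right = np.transpose(x, [2, 0, 1])  # Last axis to the front.
rotated_left = np.transpose(x, [1, 2, 0])   # First axis to the back.
assert rotated_right.shape == (4, 2, 3)
assert rotated_left.shape == (3, 4, 2)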
def _possibly_broadcast_batch_shape(self, x):
  """Return `x`, possibly after broadcasting the leading dimensions."""
  # If we have no batch shape, our batch shape broadcasts with everything!
  if self._batch_shape_arg is None:
    return x

  # Static attempt:
  #   If we determine that no broadcast is necessary, pass x through.
  #   If we need a broadcast, add to an array of zeros.
  #
  # special_shape is the shape that, when broadcast with x's shape, will give
  # the correct broadcast_shape. Note that we have already verified that the
  # second to last dimension of self.shape matches x's shape in
  # assert_compatible_matrix_dimensions. Also, the final dimension of `x` can
  # have any shape. Therefore, the final two dimensions of special_shape
  # are 1's.
  special_shape = self.batch_shape.concatenate([1, 1])
  bshape = _ops.broadcast_static_shape_as_tensorshape(
      _ops.TensorShape(x.shape), special_shape)
  if special_shape.is_fully_defined():
    # bshape.is_fully_defined iff special_shape.is_fully_defined.
    if bshape == _ops.TensorShape(x.shape):
      return x
    # Use the built-in broadcasting of addition.
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros

  # Dynamic broadcast:
  #   Always add to an array of zeros, rather than using a "cond", since a
  #   cond would require copying data from GPU --> CPU.
  special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
  zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
  return x + zeros
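# Hedged sketch of the add-zeros broadcast trick (illustration only): adding a
# zeros tensor shaped batch_shape + [1, 1] broadcasts x up to the operator's
# batch shape without changing its values.
import tensorflow as tf

x = tf.ones([2, 3, 3])          # One batch dim of size 2.
zeros = tf.zeros([4, 1, 1, 1])  # Operator batch shape [4, 1], plus [1, 1].
assert (x + zeros).shape == [4, 2, 3, 3]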
def _unvec_by(y, num_col):
  """Unstack vector to form a matrix, with a specified number of columns."""
  return _linalg.matrix_transpose(
      array_ops.reshape(
          y,
          array_ops.concat([array_ops.shape(y)[:-1], [num_col, -1]], axis=0)))
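# Hedged NumPy sketch (illustration only): _unvec_by inverts a column-major
# vec. Reshaping to [num_col, -1] puts each column on a row; the final
# transpose restores the matrix.
import numpy as np

x = np.array([[1, 2], [3, 4]])
vec_x = x.T.reshape(-1)         # Columns stacked: [1, 3, 2, 4].
unvec = vec_x.reshape(2, -1).T  # Same idea as _unvec_by(vec_x, 2).
np.testing.assert_array_equal(unvec, x)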
def _vectorize_then_blockify(self, matrix):
  """Shape batch matrix to batch vector, then blockify trailing dimensions."""
  # Suppose
  #   matrix.shape = [m0, m1, m2, m3],
  # and matrix is a matrix because the final two dimensions are matrix dims.
  #   self.block_depth = 2,
  #   self.block_shape = [b0, b1]  (note b0 * b1 = m2).
  # We will reshape matrix to
  #   [m3, m0, m1, b0, b1].

  # Vectorize: Reshape to batch vector.
  #   [m0, m1, m2, m3] --> [m3, m0, m1, m2]
  # This is called "vectorize" because we have taken the final two matrix dims
  # and turned this into a size m3 batch of vectors.
  vec = distribution_util.rotate_transpose(matrix, shift=1)

  # Blockify: Blockify trailing dimensions.
  #   [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
  if (_ops.TensorShape(vec.shape).is_fully_defined() and
      self.block_shape.is_fully_defined()):
    # vec_leading_shape = [m3, m0, m1],
    # the parts of vec that will not be blockified.
    vec_leading_shape = _ops.TensorShape(vec.shape)[:-1]
    final_shape = vec_leading_shape.concatenate(self.block_shape)
  else:
    vec_leading_shape = array_ops.shape(vec)[:-1]
    final_shape = array_ops.concat(
        (vec_leading_shape, self.block_shape_tensor()), 0)

  return array_ops.reshape(vec, final_shape)
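# Hedged NumPy sketch of the shape walk above (illustration only), with
# matrix.shape = [m0, m1, m2, m3] and block_shape = [b0, b1], b0 * b1 = m2.
import numpy as np

m0, m1, b0, b1, m3 = 2, 3, 4, 5, 6
matrix = np.zeros([m0, m1, b0 * b1, m3])
vec = np.moveaxis(matrix, -1, 0)           # [m3, m0, m1, m2]
blocked = vec.reshape(m3, m0, m1, b0, b1)  # [m3, m0, m1, b0, b1]
assert blocked.shape == (m3, m0, m1, b0, b1)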
def _shape_tensor(self):
  matrix_shape = array_ops.stack((self._num_rows, self._num_columns), axis=0)
  if self._batch_shape_arg is None:
    return matrix_shape
  return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)
def _shape_tensor(self):
  # See self.shape for explanation of steps.
  s_shape = array_ops.shape(self._spectrum)
  batch_shape = s_shape[:-self.block_depth]
  trailing_dims = s_shape[-self.block_depth:]
  n = math_ops.reduce_prod(trailing_dims)
  n_x_n = [n, n]
  return array_ops.concat((batch_shape, n_x_n), 0)
def _diag_part(self):
  diag_list = []
  for operator in self.operators:
    # Extend the axis for broadcasting.
    diag_list += [operator.diag_part()[..., array_ops.newaxis]]
  diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
  diagonal = array_ops.concat(diag_list, axis=-2)
  return array_ops.squeeze(diagonal, axis=-1)
def __init__(self,
             col,
             row,
             is_non_singular=None,
             is_self_adjoint=None,
             is_positive_definite=None,
             is_square=None,
             name="LinearOperatorToeplitz"):
  r"""Initialize a `LinearOperatorToeplitz`.

  Args:
    col: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0`, `N >= 0`.
      The first column of the operator. Allowed dtypes: `float16`, `float32`,
      `float64`, `complex64`, `complex128`. Note that the first entry of
      `col` is assumed to be the same as the first entry of `row`.
    row: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0`, `N >= 0`.
      The first row of the operator. Allowed dtypes: `float16`, `float32`,
      `float64`, `complex64`, `complex128`. Note that the first entry of
      `row` is assumed to be the same as the first entry of `col`.
    is_non_singular: Expect that this operator is non-singular.
    is_self_adjoint: Expect that this operator is equal to its Hermitian
      transpose.
    is_positive_definite: Expect that this operator is positive definite,
      meaning the quadratic form `x^H A x` has positive real part for all
      nonzero `x`. Note that we do not require the operator to be
      self-adjoint to be positive-definite. See:
      https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
    is_square: Expect that this operator acts like square [batch] matrices.
    name: A name for this `LinearOperator`.
  """
  with ops.name_scope(name, values=[row, col]):
    self._row = ops.convert_to_tensor(row, name="row")
    self._col = ops.convert_to_tensor(col, name="col")
    self._check_row_col(self._row, self._col)

    circulant_col = array_ops.concat(
        [self._col,
         array_ops.zeros_like(self._col[..., 0:1]),
         array_ops.reverse(self._row[..., 1:], axis=[-1])], axis=-1)

    # To be used for matmul.
    self._circulant = linear_operator_circulant.LinearOperatorCirculant(
        fft_ops.fft(_to_complex(circulant_col)),
        input_output_dtype=self._row.dtype)

    if is_square is False:  # pylint:disable=g-bool-id-comparison
      raise ValueError("Only square Toeplitz operators currently supported.")
    is_square = True

    super(LinearOperatorToeplitz, self).__init__(
        dtype=self._row.dtype,
        graph_parents=[self._row, self._col],
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)
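# Hedged usage sketch (assumes TF2's public `tf.linalg.LinearOperatorToeplitz`;
# illustration only): `col` fills the first column, `row` the first row, and
# each diagonal is constant.
import tensorflow as tf

toeplitz = tf.linalg.LinearOperatorToeplitz(col=[1., 2., 3.], row=[1., 4., 5.])
# to_dense() ==> [[1., 4., 5.],
#                 [2., 1., 4.],
#                 [3., 2., 1.]]
print(toeplitz.to_dense())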
def _broadcast_batch_dims(self, x, spectrum):
  """Broadcast batch dims of batch matrix `x` and spectrum."""
  # spectrum.shape = batch_shape + block_shape.
  # First make spectrum a batch matrix with
  #   spectrum.shape = batch_shape + [prod(block_shape), 1]
  spec_mat = array_ops.reshape(
      spectrum,
      array_ops.concat((self.batch_shape_tensor(), [-1, 1]), axis=0))
  # Second, broadcast, possibly requiring an addition of array of zeros.
  x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims(
      (x, spec_mat))
  # Third, put the block shape back into spectrum.
  batch_shape = array_ops.shape(x)[:-2]
  spectrum = array_ops.reshape(
      spec_mat,
      array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))
  return x, spectrum
def reshape_inv(y):
  # Expand the extra dims hanging off the end, "b_extra_sh".
  # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
  # could have different batch dims than a and b, because of broadcasting.
  y_extra_shape = array_ops.concat(
      (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
  y_extra_on_end = array_ops.reshape(y, y_extra_shape)
  inverse_perm = np.argsort(perm)
  return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
def _zeros_diag(self):
  """Returns the diagonal of this operator as all zeros."""
  if _ops.TensorShape(self.shape).is_fully_defined():
    d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
  else:
    d_shape = array_ops.concat(
        [self.batch_shape_tensor(), [self._min_matrix_dim_tensor()]], axis=0)
  return array_ops.zeros(shape=d_shape, dtype=self.dtype)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  split_dim = -1 if adjoint_arg else -2
  # Split input by rows normally, and by columns if adjoint_arg is True.
  split_x = self._split_input_into_blocks(x, axis=split_dim)

  result_list = []
  for index, operator in enumerate(self.operators):
    result_list += [operator.matmul(
        split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
  result_list = linear_operator_util.broadcast_matrix_batch_dims(result_list)
  return array_ops.concat(result_list, axis=-2)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  split_dim = -1 if adjoint_arg else -2
  # Split input by rows normally, and by columns if adjoint_arg is True.
  split_rhs = self._split_input_into_blocks(rhs, axis=split_dim)

  solution_list = []
  for index, operator in enumerate(self.operators):
    solution_list += [operator.solve(
        split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
  solution_list = linear_operator_util.broadcast_matrix_batch_dims(
      solution_list)
  return array_ops.concat(solution_list, axis=-2)
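# Hedged verification sketch for the blockwise _matmul/_solve pattern above
# (assumes TF2's public API; illustration only): applying each block to its
# slice of the input agrees with the dense computation.
import numpy as np
import tensorflow as tf

a = np.diag([2., 4.]).astype(np.float32)
b = np.diag([3., 5., 6.]).astype(np.float32)
op = tf.linalg.LinearOperatorBlockDiag(
    [tf.linalg.LinearOperatorFullMatrix(m, is_non_singular=True)
     for m in (a, b)])
rhs = np.ones([5, 1], dtype=np.float32)
dense = tf.linalg.LinearOperatorFullMatrix(op.to_dense())
np.testing.assert_allclose(op.solve(rhs).numpy(),
                           dense.solve(rhs).numpy(), rtol=1e-5)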
def _diag_part(self):
  diag_part = self.operators[0].diag_part()
  for operator in self.operators[1:]:
    diag_part = diag_part[..., :, array_ops.newaxis]
    op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
    diag_part *= op_diag_part
    diag_part = array_ops.reshape(
        diag_part,
        shape=array_ops.concat(
            [array_ops.shape(diag_part)[:-2], [-1]], axis=0))
  if self.range_dimension > self.domain_dimension:
    diag_dimension = self.domain_dimension
  else:
    diag_dimension = self.range_dimension
  diag_part.set_shape(self.batch_shape.concatenate(diag_dimension))
  return diag_part
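# Hedged NumPy sketch (illustration only): the diagonal of a Kronecker product
# of square factors is the Kronecker product of their diagonals, which is what
# the outer-product-then-flatten loop above computes.
import numpy as np

a = np.array([[1., 2.], [3., 4.]])
b = np.array([[5., 6.], [7., 8.]])
np.testing.assert_allclose(np.diag(np.kron(a, b)),
                           np.kron(np.diag(a), np.diag(b)))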
def _shape_tensor(self):
  domain_dimension = self.operators[0].domain_dimension_tensor()
  for operator in self.operators[1:]:
    domain_dimension *= operator.domain_dimension_tensor()

  range_dimension = self.operators[0].range_dimension_tensor()
  for operator in self.operators[1:]:
    range_dimension *= operator.range_dimension_tensor()

  matrix_shape = [range_dimension, domain_dimension]

  # Get broadcast batch shape.
  # broadcast_shape checks for compatibility.
  batch_shape = self.operators[0].batch_shape_tensor()
  for operator in self.operators[1:]:
    batch_shape = array_ops.broadcast_dynamic_shape(
        batch_shape, operator.batch_shape_tensor())

  return array_ops.concat((batch_shape, matrix_shape), 0)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  # Given a Toeplitz matrix, we can embed it in a Circulant matrix to perform
  # efficient matrix multiplications. Given a Toeplitz matrix with first row
  # [t_0, t_1, ..., t_{n-1}] and first column [t_0, t_{-1}, ..., t_{-(n-1)}],
  # let C be the circulant matrix with first column
  # [t_0, t_{-1}, ..., t_{-(n-1)}, 0, t_{n-1}, ..., t_1]. Also adjoin to our
  # input vector `x` `n` zeros, to make it a vector of length `2n` (call it
  # y). It can be shown that if we take the first n entries of `Cy`, this is
  # equal to the Toeplitz multiplication. See:
  # http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf
  # for more details.
  x = linalg.adjoint(x) if adjoint_arg else x
  expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
  result = self._circulant.matmul(
      expanded_x, adjoint=adjoint, adjoint_arg=False)
  return _ops.cast(
      result[..., :self.domain_dimension_tensor(), :], self.dtype)
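# Hedged NumPy sketch of the circulant embedding (illustration only): multiply
# a 3x3 Toeplitz matrix by a vector using a length-6 FFT-based circulant
# matvec, then keep the first 3 entries.
import numpy as np

col = np.array([1., 2., 3.])  # [t_0, t_{-1}, t_{-2}]
row = np.array([1., 4., 5.])  # [t_0, t_1, t_2]
x = np.array([1., -1., 2.])

# First column of the embedding circulant: [col, 0, reversed(row[1:])].
circ_col = np.concatenate([col, [0.], row[1:][::-1]])
y = np.concatenate([x, np.zeros(3)])
circ_times_y = np.fft.ifft(np.fft.fft(circ_col) * np.fft.fft(y)).real

toeplitz = np.array([[1., 4., 5.],
                     [2., 1., 4.],
                     [3., 2., 1.]])
np.testing.assert_allclose(circ_times_y[:3], toeplitz @ x)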
def _to_dense(self):
  product = self.operators[0].to_dense()
  for operator in self.operators[1:]:
    # Product has shape [B, R1, 1, C1, 1].
    product = product[..., :, array_ops.newaxis, :, array_ops.newaxis]
    # Operator has shape [B, 1, R2, 1, C2].
    op_to_mul = operator.to_dense()[
        ..., array_ops.newaxis, :, array_ops.newaxis, :]
    # This is now [B, R1, R2, C1, C2].
    product *= op_to_mul
    # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
    product = array_ops.reshape(
        product,
        shape=array_ops.concat(
            [array_ops.shape(product)[:-4],
             [array_ops.shape(product)[-4] * array_ops.shape(product)[-3],
              array_ops.shape(product)[-2] * array_ops.shape(product)[-1]]
            ], axis=0))
  product.set_shape(_ops.TensorShape(self.shape))
  return product
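# Hedged NumPy sketch (illustration only): the expand-multiply-reshape loop
# above is the usual Kronecker product construction; compare with np.kron.
import numpy as np

a = np.array([[1., 2.], [3., 4.]])  # [R1, C1]
b = np.array([[0., 5.], [6., 7.]])  # [R2, C2]
product = a[:, None, :, None] * b[None, :, None, :]  # [R1, R2, C1, C2]
product = product.reshape(a.shape[0] * b.shape[0],
                          a.shape[1] * b.shape[1])    # [R1*R2, C1*C2]
np.testing.assert_allclose(product, np.kron(a, b))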
def _shape_tensor(self):
  # Avoid messy broadcasting if possible.
  if _ops.TensorShape(self.shape).is_fully_defined():
    return ops.convert_to_tensor(
        _ops.TensorShape(self.shape).as_list(),
        dtype=dtypes.int32,
        name="shape")

  domain_dimension = self.operators[0].domain_dimension_tensor()
  range_dimension = self.operators[0].range_dimension_tensor()
  for operator in self.operators[1:]:
    domain_dimension += operator.domain_dimension_tensor()
    range_dimension += operator.range_dimension_tensor()

  matrix_shape = array_ops.stack([domain_dimension, range_dimension])

  # Dummy Tensor of zeros. Will never be materialized.
  zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
  for operator in self.operators[1:]:
    zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
  batch_shape = array_ops.shape(zeros)

  return array_ops.concat((batch_shape, matrix_shape), 0)
def _shape_tensor(self):
  # Avoid messy broadcasting if possible.
  if _ops.TensorShape(self.shape).is_fully_defined():
    return ops.convert_to_tensor(
        _ops.TensorShape(self.shape).as_list(),
        dtype=dtypes.int32,
        name="shape")

  # Don't check the matrix dimensions. That would add unnecessary Asserts to
  # the graph. Things will fail at runtime naturally if shapes are
  # incompatible.
  matrix_shape = array_ops.stack([
      self.operators[0].range_dimension_tensor(),
      self.operators[-1].domain_dimension_tensor()
  ])

  # Dummy Tensor of zeros. Will never be materialized.
  zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
  for operator in self.operators[1:]:
    zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
  batch_shape = array_ops.shape(zeros)

  return array_ops.concat((batch_shape, matrix_shape), 0)
def _unblockify_then_matricize(self, vec):
  """Flatten the block dimensions then reshape to a batch matrix."""
  # Suppose
  #   vec.shape = [v0, v1, v2, v3],
  #   self.block_depth = 2.
  # Then
  #   leading shape = [v0, v1]
  #   block shape = [v2, v3].
  # We will reshape vec to
  #   [v1, v2*v3, v0].

  # Un-blockify: Flatten block dimensions. Reshape
  #   [v0, v1, v2, v3] --> [v0, v1, v2*v3].
  if _ops.TensorShape(vec.shape).is_fully_defined():
    # vec_shape = [v0, v1, v2, v3]
    vec_shape = _ops.TensorShape(vec.shape).as_list()
    # vec_leading_shape = [v0, v1]
    vec_leading_shape = vec_shape[:-self.block_depth]
    # vec_block_shape = [v2, v3]
    vec_block_shape = vec_shape[-self.block_depth:]
    # flat_shape = [v0, v1, v2*v3]
    flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
  else:
    vec_shape = array_ops.shape(vec)
    vec_leading_shape = vec_shape[:-self.block_depth]
    vec_block_shape = vec_shape[-self.block_depth:]
    flat_shape = array_ops.concat(
        (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]), 0)
  vec_flat = array_ops.reshape(vec, flat_shape)

  # Matricize: Reshape to batch matrix.
  #   [v0, v1, v2*v3] --> [v1, v2*v3, v0],
  # representing a shape [v1] batch of [v2*v3, v0] matrices.
  matrix = distribution_util.rotate_transpose(vec_flat, shift=-1)
  return matrix
def _shape_tensor(self):
  d_shape = array_ops.shape(self._diag)
  k = d_shape[-1]
  return array_ops.concat((d_shape, [k]), 0)
def _shape_tensor(self):
  batch_shape = array_ops.broadcast_dynamic_shape(
      self.base_operator.batch_shape_tensor(),
      array_ops.shape(self.u)[:-2])
  return array_ops.concat(
      [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)
def _vec(x):
  """Stacks columns of a matrix to form a single column."""
  return array_ops.reshape(
      _linalg.matrix_transpose(x),
      array_ops.concat([array_ops.shape(x)[:-2], [-1]], axis=0))
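# Hedged NumPy sketch (illustration only): vec stacks the columns of the final
# two dimensions under each other (column-major flattening).
import numpy as np

x = np.array([[1, 2], [3, 4]])
vec_x = np.swapaxes(x, -1, -2).reshape(-1)  # Same idea as _vec(x).
np.testing.assert_array_equal(vec_x, [1, 3, 2, 4])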
def _shape_tensor(self):
  v_shape = array_ops.broadcast_dynamic_shape(
      array_ops.shape(self.row), array_ops.shape(self.col))
  k = v_shape[-1]
  return array_ops.concat((v_shape, [k]), 0)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  # Here we follow the same use of Roth's column lemma as in `matmul`, with
  # the key difference that we replace all `matmul` instances with `solve`.
  # This follows from the property that
  #   inv(A x B) = inv(A) x inv(B),
  # where `x` denotes the Kronecker product.

  # Below we document the shape manipulation for adjoint=False,
  # adjoint_arg=False, but the general case of different adjoints is still
  # handled.

  if adjoint_arg:
    rhs = linalg.adjoint(rhs)

  # Always add a batch dimension to enable broadcasting to work.
  batch_shape = array_ops.concat(
      [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
  rhs += array_ops.zeros(batch_shape, dtype=rhs.dtype)

  # rhs has shape [B, R, C], where B represents some number of batch
  # dimensions, R represents the number of rows, and C represents the number
  # of columns.

  # In order to apply Roth's column lemma, we need to operate on a batch of
  # column vectors, so we reshape into a batch of column vectors. We put it
  # at the front to ensure that broadcasting between operators to the batch
  # dimensions B still works.
  output = _rotate_last_dim(rhs, rotate_right=True)

  # Also expand the shape to be [A, C, B, R]. The first dimension will be
  # used to accumulate dimensions from each operator matmul.
  output = output[array_ops.newaxis, ...]

  # In this loop, A is going to refer to the value of the accumulated
  # dimension. A = 1 at the start, and will end up being
  # self.range_dimension. V will refer to the last dimension. V = R at the
  # start, and will end up being 1 in the end.
  for operator in self.operators[:-1]:
    # Reshape output from [A, C, B, V] to be
    # [A, C, B, V / op.domain_dimension, op.domain_dimension]
    if adjoint:
      operator_dimension = operator.range_dimension_tensor()
    else:
      operator_dimension = operator.domain_dimension_tensor()
    output = _unvec_by(output, operator_dimension)

    # We are computing X A^{-T} = (A^{-1} X^T)^T.
    # output has shape [A, C, B, V / op.domain_dimension,
    # op.domain_dimension], which is being converted to:
    # [A, C, B, V / op.domain_dimension, op.range_dimension]
    output = _linalg.matrix_transpose(output)
    output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
    output = _linalg.matrix_transpose(output)
    # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=True)

  # After the loop, we will have
  #   A = self.range_dimension / op[-1].range_dimension
  #   V = op[-1].domain_dimension

  # We convert that using matvec to get:
  # [A, C, B, op[-1].range_dimension]
  output = self.operators[-1].solvevec(output, adjoint=adjoint)

  # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
  output = _rotate_last_dim(output, rotate_right=False)
  output = _vec(output)
  output = _rotate_last_dim(output, rotate_right=False)

  if _ops.TensorShape(rhs.shape).is_fully_defined():
    column_dim = _ops.TensorShape(rhs.shape)[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        _ops.TensorShape(rhs.shape)[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]

    output.set_shape(broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
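# Hedged verification sketch (assumes TF2's public API; illustration only):
# the factored Kronecker solve above agrees with a dense solve on the
# materialized Kronecker product.
import numpy as np
import tensorflow as tf

a = np.array([[2., 0.], [1., 3.]], dtype=np.float32)
b = np.array([[4., 1.], [0., 5.]], dtype=np.float32)
kron_op = tf.linalg.LinearOperatorKronecker(
    [tf.linalg.LinearOperatorFullMatrix(m, is_non_singular=True)
     for m in (a, b)])
rhs = np.arange(4., dtype=np.float32).reshape(4, 1)
expected = np.linalg.solve(np.kron(a, b), rhs)
np.testing.assert_allclose(kron_op.solve(rhs).numpy(), expected, rtol=1e-5)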
def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
    name: A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError: If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    #   x.shape = [2, j, k]     (batch shape = [2])
    #   y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = _ops.TensorShape(batch_matrices[0].shape)[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = _ops.broadcast_static_shape_as_tensorshape(
          bcast_batch_shape,
          _ops.TensorShape(mat.shape)[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        if _ops.TensorShape(mat.shape)[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]],
              axis=0)
          batch_matrices[i] = _ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape, array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = _ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map. For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if _ops.TensorShape(a.shape).ndims is None or _ops.TensorShape(
      b.shape).ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if _ops.TensorShape(a.shape).ndims >= _ops.TensorShape(b.shape).ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape = C + [m, n],
  #   b.shape = S + C + [n, r]
  b_extra_ndims = _ops.TensorShape(b.shape).ndims - _ops.TensorShape(
      a.shape).ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough. Why?
  # Assume adjoint/transpose = False. Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory. But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a. This often happens, e.g. if b was a vector that
  # was expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = _ops.TensorShape(
      a.shape)[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = _ops.TensorShape(
      b.shape)[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(_ops.TensorShape(b.shape)[:b_extra_ndims].as_list())
      if _ops.TensorShape(b.shape)[:b_extra_ndims].is_fully_defined()
      else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in
  # calling code. Why? To avoid having to write separate complex code for
  # each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (np.concatenate(
      (np.arange(b_extra_ndims, _ops.TensorShape(b.shape).ndims),
       np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
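# Hedged NumPy sketch of the flip trick (illustration only): instead of
# broadcasting a square `a` against b's extra leading dims, move those dims
# into b's trailing columns, do one solve, and undo the reshape.
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 3))     # a.shape = [m, n] = [3, 3]
b = rng.standard_normal((4, 3, 2))  # b.shape = S + [n, r] = [4, 3, 2]

naive = np.linalg.solve(np.broadcast_to(a, (4, 3, 3)), b)

b_end = np.transpose(b, (1, 2, 0)).reshape(3, -1)  # [n, r*S]
flat = np.linalg.solve(a, b_end)                   # One solve, no tiling.
undone = np.transpose(flat.reshape(3, 2, 4), (2, 0, 1))
np.testing.assert_allclose(naive, undone, rtol=1e-10)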
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  # Here we heavily rely on Roth's column lemma [1]:
  #   (A x B) * vec X = vec BXA^T,
  # where vec stacks all the columns of the matrix under each other.
  # In our case, x represents a batch of vec X (i.e. we think of x as a batch
  # of column vectors, rather than a matrix). Each member of the batch can be
  # reshaped to a matrix (hence we get a batch of matrices). We can
  # iteratively apply this lemma by noting that if B is a Kronecker product,
  # then we can apply the lemma again.

  # [1] W. E. Roth, "On direct product matrices,"
  #     Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
  #     1934.

  # Efficiency

  # Naively doing the Kronecker product, by calculating the dense matrix and
  # applying it, can take cubic time in the size of domain_dimension
  # (assuming a square matrix). The other issue is that calculating the dense
  # matrix can be prohibitively expensive, in that it can take a large amount
  # of memory.
  #
  # This implementation avoids this memory blow up by only computing matmuls
  # with the factors. In this way, we don't have to realize the dense matrix.
  # In terms of complexity, if we have Kronecker Factors of size:
  # (n1, n1), (n2, n2), (n3, n3), ..., (nJ, nJ), with N = \prod n_i, and we
  # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
  # With this approach (ignoring reshaping of tensors and transposes for
  # now), the time complexity can be O(M * (\sum n_i) * N). There is also the
  # benefit of batched multiplication (in this example, the batch size is
  # roughly M * N), so this can be much faster. However, not factored in are
  # the costs of the several transposes of tensors, which can affect cache
  # behavior.

  # Below we document the shape manipulation for adjoint=False,
  # adjoint_arg=False, but the general case of different adjoints is still
  # handled.

  if adjoint_arg:
    x = linalg.adjoint(x)

  # Always add a batch dimension to enable broadcasting to work.
  batch_shape = array_ops.concat(
      [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
  x += array_ops.zeros(batch_shape, dtype=x.dtype)

  # x has shape [B, R, C], where B represents some number of batch
  # dimensions, R represents the number of rows, and C represents the number
  # of columns.

  # In order to apply Roth's column lemma, we need to operate on a batch of
  # column vectors, so we reshape into a batch of column vectors. We put it
  # at the front to ensure that broadcasting between operators to the batch
  # dimensions B still works.
  output = _rotate_last_dim(x, rotate_right=True)

  # Also expand the shape to be [A, C, B, R]. The first dimension will be
  # used to accumulate dimensions from each operator matmul.
  output = output[array_ops.newaxis, ...]

  # In this loop, A is going to refer to the value of the accumulated
  # dimension. A = 1 at the start, and will end up being
  # self.range_dimension. V will refer to the last dimension. V = R at the
  # start, and will end up being 1 in the end.
  for operator in self.operators[:-1]:
    # Reshape output from [A, C, B, V] to be
    # [A, C, B, V / op.domain_dimension, op.domain_dimension]
    if adjoint:
      operator_dimension = operator.range_dimension_tensor()
    else:
      operator_dimension = operator.domain_dimension_tensor()
    output = _unvec_by(output, operator_dimension)

    # We are computing (XA^T) = (AX^T)^T.
    # output has shape [A, C, B, V / op.domain_dimension,
    # op.domain_dimension], which is being converted to:
    # [A, C, B, V / op.domain_dimension, op.range_dimension]
    output = _linalg.matrix_transpose(output)
    output = operator.matmul(output, adjoint=adjoint, adjoint_arg=False)
    output = _linalg.matrix_transpose(output)
    # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=True)

  # After the loop, we will have
  #   A = self.range_dimension / op[-1].range_dimension
  #   V = op[-1].domain_dimension

  # We convert that using matvec to get:
  # [A, C, B, op[-1].range_dimension]
  output = self.operators[-1].matvec(output, adjoint=adjoint)

  # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
  output = _rotate_last_dim(output, rotate_right=False)
  output = _vec(output)
  output = _rotate_last_dim(output, rotate_right=False)

  if _ops.TensorShape(x.shape).is_fully_defined():
    column_dim = _ops.TensorShape(x.shape)[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        _ops.TensorShape(x.shape)[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]

    output.set_shape(broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
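# Hedged NumPy sketch of Roth's column lemma (illustration only):
#   (A kron B) vec(X) == vec(B X A^T),
# with vec stacking columns, which is the identity the loop above applies.
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((2, 3))
b = rng.standard_normal((4, 5))
x = rng.standard_normal((5, 3))

def vec(m):
  return m.T.reshape(-1)  # Stack columns under each other.

np.testing.assert_allclose(np.kron(a, b) @ vec(x), vec(b @ x @ a.T))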