def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""

  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use shape(y)[:-1] + [b_main_sh[-1]] rather than b_main_sh, because
    # y could have different batch dims than a and b, due to broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
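
# Illustrative caller sketch (not part of the original source): how the values
# returned by `_reshape_for_efficiency` would typically be threaded through a
# batched matmul.  `_matmul_with_reshape` is a hypothetical helper; it assumes
# `math_ops` is imported in this module (it is used elsewhere in this file).
def _matmul_with_reshape(a, b,
                         transpose_a=False, transpose_b=False,
                         adjoint_a=False, adjoint_b=False):
  a, b, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
      a, b,
      transpose_a=transpose_a, transpose_b=transpose_b,
      adjoint_a=adjoint_a, adjoint_b=adjoint_b)
  # If `_reshape_for_efficiency` already applied the transposes/adjoints,
  # `still_need_to_transpose` is False and they must not be applied again.
  y = math_ops.matmul(
      a, b,
      transpose_a=transpose_a and still_need_to_transpose,
      transpose_b=transpose_b and still_need_to_transpose,
      adjoint_a=adjoint_a and still_need_to_transpose,
      adjoint_b=adjoint_b and still_need_to_transpose)
  # Map the result back to the shape the caller expects.
  return reshape_inv(y)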

def _solve_matmul_internal(self,
                           x,
                           solve_matmul_fn,
                           adjoint=False,
                           adjoint_arg=False):
  # We heavily rely on Roth's column lemma [1]:
  #   (A x B) * vec X = vec BXA^T,
  # where vec stacks all the columns of the matrix under each other.
  # In our case, we use a variant of the lemma that is row-major friendly:
  #   (A x B) * vec' X = vec' AXB^T,
  # where vec' reshapes a matrix into a vector.  We can repeatedly apply this
  # for a collection of kronecker products.
  # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
  # use the above to compute multiplications and solves with any composition
  # of transposes.
  output = x

  if adjoint_arg:
    if self.dtype.is_complex:
      output = math_ops.conj(output)
  else:
    output = linalg.transpose(output)

  for o in reversed(self.operators):
    # Statically compute the reshape.
    if adjoint:
      operator_dimension = o.range_dimension_tensor()
    else:
      operator_dimension = o.domain_dimension_tensor()
    output_shape = _prefer_static_shape(output)

    if (tensor_util.constant_value(operator_dimension) is not None and
        output.shape[-2] is not None and output.shape[-1] is not None):
      # Everything needed is statically known; compute `dim` as a Python int.
      operator_dimension = tensor_util.constant_value(operator_dimension)
      dim = int(output.shape[-2] * output.shape[-1] // operator_dimension)
    else:
      # Otherwise fall back to a dynamic computation of `dim`.
      dim = math_ops.cast(
          output_shape[-2] * output_shape[-1] // operator_dimension,
          dtype=dtypes.int32)

    output_shape = _prefer_static_concat_shape(
        output_shape[:-2], [dim, operator_dimension])
    output = array_ops.reshape(output, shape=output_shape)

    # Conjugate because we are trying to compute A @ B^T, but
    # `LinearOperator` only supports `adjoint_arg`.
    if self.dtype.is_complex:
      output = math_ops.conj(output)

    output = solve_matmul_fn(o, output, adjoint=adjoint, adjoint_arg=True)

  if adjoint_arg:
    col_dim = _prefer_static_shape(x)[-2]
  else:
    col_dim = _prefer_static_shape(x)[-1]

  if adjoint:
    row_dim = self.domain_dimension_tensor()
  else:
    row_dim = self.range_dimension_tensor()

  matrix_shape = [row_dim, col_dim]

  output = array_ops.reshape(
      output,
      _prefer_static_concat_shape(
          _prefer_static_shape(output)[:-2], matrix_shape))

  if x.shape.is_fully_defined():
    if adjoint_arg:
      column_dim = x.shape[-2]
    else:
      column_dim = x.shape[-1]
    broadcast_batch_shape = common_shapes.broadcast_shape(
        x.shape[:-2], self.batch_shape)
    if adjoint:
      matrix_dimensions = [self.domain_dimension, column_dim]
    else:
      matrix_dimensions = [self.range_dimension, column_dim]
    output.set_shape(
        broadcast_batch_shape.concatenate(matrix_dimensions))

  return output
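
# Standalone numerical sanity check (illustrative, not part of the original
# source) of the row-major variant of Roth's column lemma relied on above:
#   (A x B) * vec' X = vec' AXB^T,
# where vec' flattens a matrix in row-major order.  Only NumPy is assumed,
# imported as `np` as elsewhere in this module.
def _check_row_major_roth_lemma():
  rng = np.random.RandomState(0)
  a = rng.randn(3, 4)   # A: maps R^4 -> R^3
  b = rng.randn(2, 5)   # B: maps R^5 -> R^2
  x = rng.randn(4, 5)   # X: the matricized right-hand side
  lhs = np.kron(a, b).dot(x.reshape(-1))   # (A x B) * vec' X
  rhs = a.dot(x).dot(b.T).reshape(-1)      # vec' AXB^T
  np.testing.assert_allclose(lhs, rhs, atol=1e-10)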

def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""

  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      array_ops.concat(
          (math_ops.range(b_extra_ndims, b.shape.ndims),
           math_ops.range(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use shape(y)[:-1] + [b_main_sh[-1]] rather than b_main_sh, because
    # y could have different batch dims than a and b, due to broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    return array_ops.transpose(
        y_extra_on_end, perm=array_ops.invert_permutation(perm))

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
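
# Illustrative shape walk-through (not part of the original source), assuming
# TF2 eager execution and the public `tf.random` / `math_ops.matmul` APIs.
# With a.shape = C + [m, n] and b.shape = S + C + [n, r], the helper folds b's
# extra leading dims S into b's last dim so the matmul never has to tile `a`
# over S, and `reshape_inv` restores the broadcast output shape S + C + [m, r].
def _example_reshape_for_efficiency():
  import tensorflow as tf  # assumed available for this sketch only
  s, c, m, n, r = [2, 3], [4], 5, 6, 2   # n > r, so the flip is worthwhile
  a = tf.random.normal(c + [m, n])        # shape [4, 5, 6]
  b = tf.random.normal(s + c + [n, r])    # shape [2, 3, 4, 6, 2]
  a2, b2, reshape_inv, _ = _reshape_for_efficiency(a, b)
  # b was squashed to C + [n, r * prod(S)] = [4, 6, 12]; no tiling of `a`.
  y = reshape_inv(math_ops.matmul(a2, b2))  # shape [2, 3, 4, 5, 2]
  # Same result as a broadcasting matmul, which would tile `a` over S.
  np.testing.assert_allclose(
      y.numpy(), math_ops.matmul(a, b).numpy(), rtol=1e-5, atol=1e-5)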