def reshape_inv(y): # Expand the extra dims hanging off the end, "b_extra_sh". # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y # Could have different batch dims than a and b, because of broadcasting. y_extra_shape = array_ops.concat( (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0) y_extra_on_end = array_ops.reshape(y, y_extra_shape) inverse_perm = np.argsort(perm) return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
def _rotate_last_dim(x, rotate_right=False): """Rotate the last dimension either left or right.""" ndims = array_ops.rank(x) if rotate_right: transpose_perm = array_ops.concat( [[ndims - 1], array_ops.range(0, ndims - 1)], axis=0) else: transpose_perm = array_ops.concat([array_ops.range(1, ndims), [0]], axis=0) return array_ops.transpose(x, transpose_perm)
def _reshape_for_efficiency(a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False): """Maybe reshape a, b, and return an inverse map. For matmul/solve.""" def identity(x): return x # At this point, we have not taken transpose/adjoint of a/b. still_need_to_transpose = True if tensor_shape.TensorShape(a.shape).ndims is None or tensor_shape.TensorShape(b.shape).ndims is None: return a, b, identity, still_need_to_transpose # This could be handled in the future, but seems less common. if tensor_shape.TensorShape(a.shape).ndims >= tensor_shape.TensorShape(b.shape).ndims: return a, b, identity, still_need_to_transpose # From now on, we might modify b, but will not modify a. # Suppose: # tensor_shape.TensorShape(a.shape) = C + [m, n], tensor_shape.TensorShape(b.shape) = # tensor_shape.TensorShape(b.shape) = S + C + [n, r] b_extra_ndims = tensor_shape.TensorShape(b.shape).ndims - tensor_shape.TensorShape(a.shape).ndims # b_extra_sh = S, b_main_sh = C + [n, r] b_extra_sh = array_ops.shape(b)[:b_extra_ndims] b_main_sh = array_ops.shape(b)[b_extra_ndims:] # No reason to flip unless the extra dims of b are big enough. Why? # Assume adjoint/transpose = False. Then... # By not flipping, we have to replicate a to shape # b_extra_sh + tensor_shape.TensorShape(a.shape), # which could use extra memory. But in all cases, the final output has shape # b_extra_sh + tensor_shape.TensorShape(a.shape)[:-1] + tensor_shape.TensorShape([b.shape)[-1]] # So we only end up creating a larger object if the end dim of b is smaller # than the end dim of a. This often happens, e.g. if b was a vector that was # expanded to a matrix (by appending a singleton). # Since adjoint/transpose may not be False, we must make adjustments here. # The dim of b that holds the multiple equations. a_domain_sz_ = tensor_shape.TensorShape(a.shape)[-2 if adjoint_a or transpose_a else -1] b_eq_sz_ = tensor_shape.TensorShape(b.shape)[-2 if adjoint_b or transpose_b else -1] b_extra_sz_ = ( np.prod(tensor_shape.TensorShape(b.shape)[:b_extra_ndims].as_list()) if tensor_shape.TensorShape(b.shape)[:b_extra_ndims].is_fully_defined() else None) if (a_domain_sz_ is not None and b_eq_sz_ is not None and b_extra_sz_ is not None): if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_: return a, b, identity, still_need_to_transpose # At this point, we're flipping for sure! # Any transposes/adjoints will happen here explicitly, rather than in calling # code. Why? To avoid having to write separate complex code for each case. if adjoint_a: a = _linalg.matrix_transpose(a, conjugate=True) elif transpose_a: a = _linalg.matrix_transpose(a, conjugate=False) if adjoint_b: b = _linalg.matrix_transpose(b, conjugate=True) elif transpose_a: b = _linalg.matrix_transpose(b, conjugate=False) still_need_to_transpose = False # Recompute shapes, since the transpose/adjoint may have changed them. b_extra_sh = array_ops.shape(b)[:b_extra_ndims] b_main_sh = array_ops.shape(b)[b_extra_ndims:] # Permutation to put the extra dims at the end. perm = ( np.concatenate( (np.arange(b_extra_ndims, tensor_shape.TensorShape(b.shape).ndims), np.arange(0, b_extra_ndims)), 0)) b_extra_on_end = array_ops.transpose(b, perm=perm) # Now squash this end into one long dim. b_squashed_end = array_ops.reshape( b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0)) def reshape_inv(y): # Expand the extra dims hanging off the end, "b_extra_sh". # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y # Could have different batch dims than a and b, because of broadcasting. y_extra_shape = array_ops.concat( (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0) y_extra_on_end = array_ops.reshape(y, y_extra_shape) inverse_perm = np.argsort(perm) return array_ops.transpose(y_extra_on_end, perm=inverse_perm) return a, b_squashed_end, reshape_inv, still_need_to_transpose