# Imports assumed by the snippets below: `tfsp` is the CSR sparse matrix
# module (the same module imported as `sparse_lib` further down) and `K` is
# the Keras backend.
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.python.ops.linalg.sparse import sparse as tfsp


def dot(a, b, transpose_a=False, transpose_b=False):
    """
    Dot product between a and b along innermost dimensions, for a and b with
    same rank. Supports both dense and sparse multiplication (including
    sparse-sparse).
    :param a: Tensor or SparseTensor with rank 2 or 3.
    :param b: Tensor or SparseTensor with same rank as a.
    :param transpose_a: bool, transpose innermost two dimensions of a.
    :param transpose_b: bool, transpose innermost two dimensions of b.
    :return: Tensor or SparseTensor with rank 2 or 3.
    """
    # Fast path: rank-2 sparse @ dense with no transposition can use
    # tf.sparse.sparse_dense_matmul directly.
    if (
        not transpose_a
        and not transpose_b
        and isinstance(a, tf.SparseTensor)
        and isinstance(b, tf.Tensor)
        and tf.keras.backend.ndim(b) == 2
    ):
        return tf.sparse.sparse_dense_matmul(a, b)

    a_is_sparse_tensor = isinstance(a, tf.SparseTensor)
    b_is_sparse_tensor = isinstance(b, tf.SparseTensor)
    if a_is_sparse_tensor:
        a = tfsp.CSRSparseMatrix(a)
    if b_is_sparse_tensor:
        b = tfsp.CSRSparseMatrix(b)
    # Fall through to the CSRSparseMatrix-based matmul.
    out = tfsp.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b)
    if hasattr(out, 'to_sparse_tensor'):
        return out.to_sparse_tensor()

    return out
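# Usage sketch (not part of the original source): the rank-2 sparse @ dense
# case with no transposition, which takes the tf.sparse.sparse_dense_matmul
# fast path above. Shapes and values are illustrative.
a = tf.SparseTensor(indices=[[0, 1], [1, 0]], values=[1.0, 2.0], dense_shape=[2, 2])
b = tf.ones((2, 3))
out = dot(a, b)  # dense Tensor of shape (2, 3), equal to tf.sparse.to_dense(a) @ b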
def dot(a, b): """ Computes a @ b, for a, b of the same rank (both 2 or both 3). If the rank is 2, then the innermost dimension of `a` must match the outermost dimension of `b`. If the rank is 3, the first dimension of `a` and `b` must be equal and the function computes a batch matmul. Supports both dense and sparse multiplication (including sparse-sparse). :param a: Tensor or SparseTensor with rank 2 or 3. :param b: Tensor or SparseTensor with same rank as b. :return: Tensor or SparseTensor with rank 2 or 3. """ a_ndim = K.ndim(a) b_ndim = K.ndim(b) assert a_ndim == b_ndim, "Expected equal ranks, got {} and {}" "".format( a_ndim, b_ndim ) a_is_sparse = K.is_sparse(a) b_is_sparse = K.is_sparse(b) # Handle cases: rank 2 sparse-dense, rank 2 dense-sparse # In these cases we can use the faster sparse-dense matmul of tf.sparse if a_ndim == 2: if a_is_sparse and not b_is_sparse: return tf.sparse.sparse_dense_matmul(a, b) if not a_is_sparse and b_is_sparse: return ops.transpose( tf.sparse.sparse_dense_matmul(ops.transpose(b), ops.transpose(a)) ) # Handle cases: rank 2 sparse-sparse, rank 3 sparse-dense, # rank 3 dense-sparse, rank 3 sparse-sparse # In these cases we can use the tfsp.CSRSparseMatrix implementation (slower, # but saves memory) if a_is_sparse: a = tfsp.CSRSparseMatrix(a) if b_is_sparse: b = tfsp.CSRSparseMatrix(b) if a_is_sparse or b_is_sparse: out = tfsp.matmul(a, b) if hasattr(out, "to_sparse_tensor"): return out.to_sparse_tensor() else: return out # Handle case: rank 2 dense-dense, rank 3 dense-dense # Here we use the standard dense operation return tf.matmul(a, b)
def dot(a, b, transpose_a=False, transpose_b=False):
    """
    Dot product between `a` and `b`, with automatic handling of batch
    dimensions. Supports both dense and sparse multiplication (including
    sparse-sparse).
    The innermost dimension of `a` must match the outermost dimension of `b`,
    unless there is a shared batch dimension.
    Note that doing sparse-sparse multiplication of any rank and sparse-dense
    multiplication with rank higher than 2 may result in slower computations.
    :param a: Tensor or SparseTensor with rank 2 or 3.
    :param b: Tensor or SparseTensor with rank 2 or 3.
    :param transpose_a: bool, transpose innermost two dimensions of a.
    :param transpose_b: bool, transpose innermost two dimensions of b.
    :return: Tensor or SparseTensor with rank 2 or 3.
    """
    a_is_sparse_tensor = isinstance(a, tf.SparseTensor)
    b_is_sparse_tensor = isinstance(b, tf.SparseTensor)

    # Handle case where we can use the faster sparse-dense matmul of tf.sparse
    if K.ndim(a) == 2 and K.ndim(b) == 2:
        if transpose_a:
            a = ops.transpose(a)
            transpose_a = False  # Already applied; don't transpose again below
        if transpose_b:
            b = ops.transpose(b)
            transpose_b = False  # Already applied; don't transpose again below
        if a_is_sparse_tensor and not b_is_sparse_tensor:
            return tf.sparse.sparse_dense_matmul(a, b)
        elif not a_is_sparse_tensor and b_is_sparse_tensor:
            return ops.transpose(
                tf.sparse.sparse_dense_matmul(ops.transpose(b), ops.transpose(a))
            )

    # Fall through to the sparse-sparse and dense-dense implementations
    if a_is_sparse_tensor:
        a = tfsp.CSRSparseMatrix(a)
    if b_is_sparse_tensor:
        b = tfsp.CSRSparseMatrix(b)
    if a_is_sparse_tensor or b_is_sparse_tensor:
        out = tfsp.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b)
        if hasattr(out, 'to_sparse_tensor'):
            return out.to_sparse_tensor()
    else:
        out = tf.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b)

    return out
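# Usage sketch (not part of the original source): transpose handling in the
# rank-2 fast path. With transpose_b=True, `b` is transposed before the
# sparse-dense product, so a (2, 2) sparse input times a (3, 2) dense input
# gives a (2, 3) output. Shapes are illustrative.
a = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[2, 2])
b = tf.random.normal((3, 2))
out = dot(a, b, transpose_b=True)  # dense Tensor of shape (2, 3)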
def unstacked_csr_matmul(adj_values, features, kernel, indices, dense_shape):
    # Assumes `sparse_lib` is the CSR sparse matrix module, e.g.
    # `from tensorflow.python.ops.linalg.sparse import sparse as sparse_lib`
    # (the same import used by `_csr_matmul` below).
    convs = []
    for av, k in zip(tf.unstack(adj_values, axis=0), tf.unstack(kernel, axis=0)):
        # Rebuild the sparse adjacency for this channel from the shared
        # sparsity pattern and the channel-specific values.
        adj = tf.SparseTensor(indices, av, dense_shape)
        adj = sparse_lib.CSRSparseMatrix(adj)
        # Propagate the features through the adjacency, then apply the
        # channel's weight matrix.
        f = sparse_lib.matmul(adj, features)
        convs.append(tf.matmul(f, k))
    return tf.add_n(convs)
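# Usage sketch (not part of the original source): a multi-channel propagation
# where every channel c applies its own edge weights adj_values[c] and weight
# matrix kernel[c] over a shared sparsity pattern. The node, edge, channel and
# feature sizes below are illustrative.
N, F_in, F_out, C = 4, 3, 5, 2
indices = tf.constant([[0, 1], [1, 2], [2, 3]], dtype=tf.int64)  # (nnz, 2)
dense_shape = tf.constant([N, N], dtype=tf.int64)
adj_values = tf.random.normal((C, 3))                            # (C, nnz)
features = tf.random.normal((N, F_in))                           # (N, F_in)
kernel = tf.random.normal((C, F_in, F_out))                      # (C, F_in, F_out)
out = unstacked_csr_matmul(adj_values, features, kernel, indices, dense_shape)
# out has shape (N, F_out)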
def _csr_matmul(indices: tf.Tensor, values: tf.Tensor, dense_shape, b: tf.Tensor):
    try:
        # pylint: disable=import-outside-toplevel
        from tensorflow.python.ops.linalg.sparse import sparse as sparse_lib

        # pylint: enable=import-outside-toplevel
    except ImportError as e:
        raise ImportError("use_csr requires tensorflow >= 2.3") from e

    st = tf.SparseTensor(indices, values, dense_shape)
    csr_m = sparse_lib.CSRSparseMatrix(st)
    out = sparse_lib.matmul(csr_m, b)

    def grad(dy):
        rows, cols = tf.unstack(indices, axis=-1)
        # Gradient w.r.t. the sparse values: for each nonzero (i, j),
        # d(loss)/d(values at (i, j)) = sum_k dy[i, k] * b[j, k].
        parts_a = tf.gather(dy, rows, axis=0)
        parts_b = tf.gather(b, cols, axis=0)
        a_values_grad = tf.math.reduce_sum(parts_a * parts_b, axis=1)
        # Gradient w.r.t. the dense factor: A^T @ dy.
        b_grad = sparse_lib.matmul(csr_m, dy, adjoint_a=True)
        # No gradients for the integer-valued indices and dense_shape.
        return (None, a_values_grad, None, b_grad)

    return out, grad
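# Usage sketch (not part of the original source): `_csr_matmul` returns an
# (output, grad_fn) pair, which is the contract of tf.custom_gradient, so it
# is assumed here to be wrapped with that decorator before use. Gradients flow
# to the sparse values and the dense factor, but not to indices/dense_shape.
csr_matmul = tf.custom_gradient(_csr_matmul)

indices = tf.constant([[0, 0], [1, 2]], dtype=tf.int64)
values = tf.constant([1.0, 2.0])
dense_shape = tf.constant([2, 3], dtype=tf.int64)
b = tf.random.normal((3, 4))

with tf.GradientTape() as tape:
    tape.watch([values, b])
    out = csr_matmul(indices, values, dense_shape, b)  # (2, 4)
    loss = tf.reduce_sum(out)

values_grad, b_grad = tape.gradient(loss, [values, b])  # shapes (2,) and (3, 4)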