def sddmm(lhs_matrix,
          rhs_matrix,
          sparse_topology,
          transpose_lhs=False,
          transpose_rhs=False):
  """Sampled dense-dense matrix multiplication.

  Evaluates only the entries of the dense product of `lhs_matrix` and
  `rhs_matrix` (after any requested operand transposes) whose positions are
  present in `sparse_topology`.

  Args:
    lhs_matrix: Tensor, the left-hand, dense matrix for the product.
    rhs_matrix: Tensor, the right-hand, dense matrix for the product.
    sparse_topology: SparseMatrix whose sparsity pattern selects which outputs
      are computed. Only its structure (shape, row/column counts, row_indices,
      row_offsets, column_indices) is read; its values are not used.
    transpose_lhs: bool, whether to transpose the lhs operand.
    transpose_rhs: bool, whether to transpose the rhs operand.

  Returns:
    A SparseMatrix that shares the topology of `sparse_topology` and holds the
    sampled output values.
  """
  topo = sparse_topology
  # pylint: disable=protected-access
  sampled_values = kernels.sddmm(
      topo._rows,
      topo._columns,
      topo.row_indices,
      topo.row_offsets,
      topo.column_indices,
      lhs_matrix,
      rhs_matrix,
      transpose_lhs,
      transpose_rhs,
  )
  # The result reuses the sampling topology; only the values are new.
  # NOTE(review): `_columns` is passed ahead of `_rows` here, the reverse of
  # the kernel call above. This presumably matches the parameter order of
  # `SparseMatrix._wrap_existing`; confirm before reordering.
  result = SparseMatrix._wrap_existing(
      topo.shape,
      topo._columns,
      topo._rows,
      sampled_values,
      topo.row_indices,
      topo.row_offsets,
      topo.column_indices,
  )
  # pylint: enable=protected-access
  return result
def replicated_sddmm(lhs_matrix,
                     rhs_matrix,
                     sparse_topology,
                     transpose_lhs=False,
                     transpose_rhs=False):
  """Convenience API for replicated sddmm.

  Issues the same `kernels.sddmm` call as `sddmm`, but returns the kernel's
  output tensor directly instead of wrapping it in a SparseMatrix.

  Args:
    lhs_matrix: Tensor, the left-hand, dense matrix for the product.
    rhs_matrix: Tensor, the right-hand, dense matrix for the product.
    sparse_topology: SparseMatrix whose structure selects which outputs are
      computed; its values are not read.
    transpose_lhs: bool, whether to transpose the lhs operand.
    transpose_rhs: bool, whether to transpose the rhs operand.

  Returns:
    The raw tensor of sampled output values produced by `kernels.sddmm`.
  """
  topo = sparse_topology
  # Structural metadata, in the order the kernel expects it.
  # pylint: disable=protected-access
  topology_args = (
      topo._rows,
      topo._columns,
      topo.row_indices,
      topo.row_offsets,
      topo.column_indices,
  )
  # pylint: enable=protected-access
  return kernels.sddmm(*topology_args, lhs_matrix, rhs_matrix, transpose_lhs,
                       transpose_rhs)
def _spmm_grad(op, grad):
  """Gradient operation for sparse matrix matrix multiplication.

  The op's inputs are, in order: the row count `m`, the column count `k`, the
  sparse operand in CSR form (values, row_indices, row_offsets,
  column_indices) and the dense operand. Given the upstream gradient `grad`
  of the product, this computes:
    * d(values) = grad @ dense_matrix^T, sampled at the sparse operand's
      nonzero positions (via sddmm).
    * d(dense_matrix) = sparse^T @ grad (via a CSR transpose and spmm).

  Args:
    op: The forward op whose gradient is being computed.
    grad: Gradient flowing into the op's output.

  Returns:
    A list with one entry per op input. Only the sparse values (index 2) and
    the dense matrix (index 6) receive gradients; every other entry is None.
  """
  inputs = op.inputs
  m, k = inputs[0], inputs[1]
  values = inputs[2]
  row_indices, row_offsets, column_indices = inputs[3], inputs[4], inputs[5]
  dense_matrix = inputs[6]

  # Sparse-values gradient: multiply grad by the transposed dense operand,
  # evaluated only where the sparse operand has nonzeros.
  values_grad = kernels.sddmm(
      m, k, row_indices, row_offsets, column_indices,
      grad, dense_matrix, transpose_rhs=True)

  # Dense gradient: transpose the sparse operand (now k x m), compute row
  # indices for the transposed matrix with `diffsort`, then multiply it with
  # grad.
  t_values, t_row_offsets, t_column_indices = kernels.csr_transpose(
      m, k, values, row_offsets, column_indices)
  t_row_indices = diffsort(t_row_offsets)
  dense_grad = kernels.spmm(
      k, m, t_values, t_row_indices, t_row_offsets, t_column_indices, grad)

  # The sparse-matrix metadata and shape scalars are exposed as op arguments,
  # so each of them still needs a (None) gradient slot.
  #
  # TODO(tgale): Make sure there are no performance implications for this.
  grads = [None] * 7
  grads[2] = values_grad
  grads[6] = dense_grad
  return grads