def _sddmm_grad(op, grad):
  """Gradient operation for sampled dense dense matrix multiplication.

  Args:
    op: The forward sddmm op. Its inputs are the sparsity meta-data
      (m, n, row_indices, row_offsets, column_indices) followed by the
      two dense operands (lhs_matrix, rhs_matrix).
    grad: Gradient w.r.t. the sparse output values of the forward op.

  Returns:
    A list with one entry per op input: None for the five sparsity
    meta-data inputs, then the gradients for lhs_matrix and rhs_matrix.
  """
  # Unpack the forward op's inputs in one shot.
  (m, n, row_indices, row_offsets, column_indices,
   lhs_matrix, rhs_matrix) = op.inputs

  # lhs gradient: multiply the sparse output gradient by the rhs matrix.
  lhs_matrix_grad = kernels.spmm(m, n, grad, row_indices, row_offsets,
                                 column_indices, rhs_matrix)

  # rhs gradient: transpose the sparse output gradient, recompute the row
  # indices for the transposed topology, and multiply by the lhs matrix.
  values_t, offsets_t, indices_t = kernels.csr_transpose(
      m, n, grad, row_offsets, column_indices)
  rows_t = diffsort(offsets_t)
  rhs_matrix_grad = kernels.spmm(n, m, values_t, rows_t, offsets_t,
                                 indices_t, lhs_matrix)

  # NOTE: The sparse matrix meta-data was exposed as op arguments, so the
  # corresponding gradient slots must be None.
  #
  # TODO(tgale): Make sure there are no performance implications for this.
  return [None] * 5 + [lhs_matrix_grad, rhs_matrix_grad]
def _spmm_grad(op, grad):
  """Gradient operation for sparse matrix matrix multiplication.

  Args:
    op: The forward spmm op. Its inputs are (m, k, values, row_indices,
      row_offsets, column_indices, dense_matrix).
    grad: Gradient w.r.t. the dense output of the forward op.

  Returns:
    A list with one entry per op input: gradients for the sparse values
    (index 2) and the dense matrix (index 6), None everywhere else.
  """
  # Unpack the forward op's inputs in one shot.
  (m, k, values, row_indices, row_offsets, column_indices,
   dense_matrix) = op.inputs

  # Sparse values gradient: sample grad @ dense_matrix^T at the nonzero
  # locations of the sparse operand.
  sparse_matrix_grad = kernels.sddmm(m, k, row_indices, row_offsets,
                                     column_indices, grad, dense_matrix,
                                     transpose_rhs=True)

  # Dense matrix gradient: transpose the sparse weights, recompute the row
  # indices for the transposed topology, and multiply by the output grad.
  values_t, offsets_t, indices_t = kernels.csr_transpose(
      m, k, values, row_offsets, column_indices)
  rows_t = diffsort(offsets_t)
  dense_matrix_grad = kernels.spmm(k, m, values_t, rows_t, offsets_t,
                                   indices_t, grad)

  # NOTE: The sparse matrix meta-data was exposed as op arguments, so the
  # corresponding gradient slots must be None.
  #
  # TODO(tgale): Make sure there are no performance implications for this.
  grads = [None] * 7
  grads[2] = sparse_matrix_grad
  grads[6] = dense_matrix_grad
  return grads
def transpose(sparse_matrix):
  """Transpose a sparse matrix.

  Args:
    sparse_matrix: SparseMatrix, the sparse matrix to be transposed.

  Returns:
    SparseMatrix, a sparse matrix that is the transpose of the input
    sparse matrix.
  """
  # Transpose the underlying CSR representation.
  values_t, offsets_t, indices_t = kernels.csr_transpose(
      sparse_matrix._rows, sparse_matrix._columns, sparse_matrix.values,
      sparse_matrix.row_offsets, sparse_matrix.column_indices)

  # The transposed topology needs freshly sorted row indices.
  rows_t = diffsort(offsets_t)

  # Re-wrap as a SparseMatrix with the shape reversed and the row/column
  # counts swapped.
  shape_t = list(reversed(sparse_matrix.shape))
  return SparseMatrix._wrap_existing(shape_t, sparse_matrix._columns,
                                     sparse_matrix._rows, values_t, rows_t,
                                     offsets_t, indices_t)