Esempio n. 1
0
def sddmm(lhs_matrix,
          rhs_matrix,
          sparse_topology,
          transpose_lhs=False,
          transpose_rhs=False):
    """Sampled dense-dense matrix multiplication (SDDMM).

    Computes the entries of the product of two dense operands only at the
    positions selected by ``sparse_topology``.

    Args:
        lhs_matrix: Tensor, the left-hand dense operand of the product.
        rhs_matrix: Tensor, the right-hand dense operand of the product.
        sparse_topology: SparseMatrix whose sparsity pattern selects which
            outputs are computed. Only its metadata is read; its values are
            not used.
        transpose_lhs: bool, whether to transpose the lhs operand.
        transpose_rhs: bool, whether to transpose the rhs operand.

    Returns:
        A SparseMatrix that shares the shape and sparsity metadata of
        ``sparse_topology`` and holds the computed output values.
    """
    # Read the topology metadata once: it is fed to the kernel and then
    # reused unchanged to describe the result.
    rows = sparse_topology._rows
    columns = sparse_topology._columns
    row_indices = sparse_topology.row_indices
    row_offsets = sparse_topology.row_offsets
    column_indices = sparse_topology.column_indices

    values = kernels.sddmm(rows, columns, row_indices, row_offsets,
                           column_indices, lhs_matrix, rhs_matrix,
                           transpose_lhs, transpose_rhs)

    # NOTE(review): ``columns`` is passed ahead of ``rows`` here; this is
    # presumably the parameter order of ``SparseMatrix._wrap_existing`` --
    # confirm against its signature.
    return SparseMatrix._wrap_existing(sparse_topology.shape, columns, rows,
                                       values, row_indices, row_offsets,
                                       column_indices)
Esempio n. 2
0
def replicated_sddmm(lhs_matrix,
                     rhs_matrix,
                     sparse_topology,
                     transpose_lhs=False,
                     transpose_rhs=False):
    """Convenience API for replicated SDDMM.

    Invokes ``kernels.sddmm`` with the metadata of ``sparse_topology``. It
    passes the same arguments that ``sddmm`` passes, but returns the raw
    output values instead of wrapping them in a SparseMatrix.

    NOTE(review): the meaning of "replicated" (e.g. a topology shared across
    a batch of dense operands) is not visible here -- confirm against the
    kernel documentation.

    Args:
        lhs_matrix: Tensor, the left-hand dense operand of the product.
        rhs_matrix: Tensor, the right-hand dense operand of the product.
        sparse_topology: SparseMatrix whose sparsity pattern selects which
            outputs are computed.
        transpose_lhs: bool, whether to transpose the lhs operand.
        transpose_rhs: bool, whether to transpose the rhs operand.

    Returns:
        The output values tensor returned by ``kernels.sddmm``.
    """
    topology = sparse_topology
    output_values = kernels.sddmm(
        topology._rows,
        topology._columns,
        topology.row_indices,
        topology.row_offsets,
        topology.column_indices,
        lhs_matrix,
        rhs_matrix,
        transpose_lhs,
        transpose_rhs,
    )
    return output_values
Esempio n. 3
0
def _spmm_grad(op, grad):
    """Gradient operation for sparse matrix matrix multiplication (SpMM).

    Args:
        op: the forward SpMM operation. Its inputs are read positionally as
            (m, k, values, row_indices, row_offsets, column_indices,
            dense_matrix).
        grad: the gradient flowing back into the forward op's output.

    Returns:
        A 7-element list with one entry per forward input. Index 2 holds the
        gradient for the sparse ``values`` and index 6 holds the gradient for
        ``dense_matrix``. Every other entry (the sparse metadata inputs) is
        None.
    """
    inputs = op.inputs
    m, k = inputs[0], inputs[1]
    values = inputs[2]
    row_indices, row_offsets, column_indices = inputs[3], inputs[4], inputs[5]
    dense_matrix = inputs[6]

    # Gradient w.r.t. the sparse values: the incoming gradient times the
    # transposed dense matrix, computed only at the sparse matrix's nonzeros.
    values_grad = kernels.sddmm(m,
                                k,
                                row_indices,
                                row_offsets,
                                column_indices,
                                grad,
                                dense_matrix,
                                transpose_rhs=True)

    # Gradient w.r.t. the dense matrix: transpose the sparse operand (a k x m
    # matrix afterwards), recompute its row indices from the new offsets,
    # and multiply it by the incoming gradient.
    values_t, row_offsets_t, column_indices_t = kernels.csr_transpose(
        m, k, values, row_offsets, column_indices)
    dense_grad = kernels.spmm(k, m, values_t, diffsort(row_offsets_t),
                              row_offsets_t, column_indices_t, grad)

    # The sparse metadata is exposed as explicit op inputs, so each of those
    # inputs must receive a None gradient.
    #
    # TODO(tgale): Make sure there are no performance implications for this.
    grads = [None] * 7
    grads[2] = values_grad
    grads[6] = dense_grad
    return grads