예제 #1
0
def sddmm(lhs_matrix,
          rhs_matrix,
          sparse_topology,
          transpose_lhs=False,
          transpose_rhs=False):
    """Sampled dense-dense matrix multiplication (SDDMM).

    Evaluates the product of two dense matrices only at the output
    positions described by `sparse_topology`.

    Args:
        lhs_matrix: Tensor, dense left-hand operand of the product.
        rhs_matrix: Tensor, dense right-hand operand of the product.
        sparse_topology: SparseMatrix whose index tensors select the outputs
            to compute. Its values are not read here.
        transpose_lhs: bool, whether to transpose the lhs operand.
        transpose_rhs: bool, whether to transpose the rhs operand.

    Returns:
        A SparseMatrix holding the computed values, wrapped together with
        the shape and index tensors of `sparse_topology`.
    """
    topo = sparse_topology
    row_indices = topo.row_indices
    row_offsets = topo.row_offsets
    column_indices = topo.column_indices

    sampled_values = kernels.sddmm(
        topo._rows,
        topo._columns,
        row_indices,
        row_offsets,
        column_indices,
        lhs_matrix,
        rhs_matrix,
        transpose_lhs,
        transpose_rhs,
    )

    # NOTE(review): `_columns` is passed before `_rows` here even though the
    # shape is unchanged; `transpose` in this file uses the same order with a
    # reversed shape. Confirm against the `_wrap_existing` signature that
    # both call sites are intended.
    return SparseMatrix._wrap_existing(
        topo.shape,
        topo._columns,
        topo._rows,
        sampled_values,
        row_indices,
        row_offsets,
        column_indices,
    )
예제 #2
0
def transpose(sparse_matrix):
    """Return the transpose of a sparse matrix.

    Args:
        sparse_matrix: SparseMatrix, the matrix to transpose.

    Returns:
        SparseMatrix built from the output of `kernels.csr_transpose`, with
        the input's shape reversed and its row/column counts swapped.
    """
    source = sparse_matrix

    transposed = kernels.csr_transpose(
        source._rows,
        source._columns,
        source.values,
        source.row_offsets,
        source.column_indices,
    )
    new_values, new_row_offsets, new_column_indices = transposed

    # The row ordering is rebuilt from the transposed offsets, because the
    # input's row indices no longer apply to the transposed layout.
    new_row_indices = diffsort(new_row_offsets)

    # Shape is reversed and the counts are passed as (_columns, _rows).
    flipped_shape = list(reversed(source.shape))
    return SparseMatrix._wrap_existing(
        flipped_shape,
        source._columns,
        source._rows,
        new_values,
        new_row_indices,
        new_row_offsets,
        new_column_indices,
    )
예제 #3
0
def sparse_softmax(x):
    """Apply `kernels.csr_softmax` to the stored values of a sparse matrix.

    Args:
        x: SparseMatrix providing the values and index tensors for the kernel.

    Returns:
        A SparseMatrix with the kernel's output values, wrapped together with
        the shape and index tensors of `x`.
    """
    # NOTE(review): presumably a softmax over the nonzeros of each row;
    # confirm against the kernel implementation.
    row_indices = x.row_indices
    row_offsets = x.row_offsets
    column_indices = x.column_indices

    softmax_values = kernels.csr_softmax(
        x.values, row_indices, row_offsets, column_indices
    )

    # NOTE(review): `_columns` precedes `_rows` with an unchanged shape, the
    # same order `transpose` uses with a reversed shape. Verify against the
    # `_wrap_existing` signature.
    return SparseMatrix._wrap_existing(
        x.shape,
        x._columns,
        x._rows,
        softmax_values,
        row_indices,
        row_offsets,
        column_indices,
    )