Exemplo n.º 1
0
def test_spmmd(mkl, sparse1, sparse2, dtype):
    """Check that ``mkl.mkl_spmmd`` computes ``sparse1 @ sparse2`` correctly.

    The sparse product is written into two dense outputs:
    - a freshly allocated C-ordered tensor;
    - a transposed allocation (``.T``), i.e. a column-major layout.

    Each output is compared against the product of the dense counterparts.

    Args:
        mkl: MKL bindings object supplied by the ``mkl`` fixture.
        sparse1: Pair ``(sparse, dense)`` for the left-hand operand.
        sparse2: Pair ``(sparse, dense)`` for the right-hand operand.
        dtype: Data type the operands are cast to. The outputs and the
            relative tolerance use the same dtype.
    """
    smat1, dmat1 = sparse1
    smat1 = smat1.to(dtype=dtype)
    dmat1 = dmat1.to(dtype=dtype)
    smat2, dmat2 = sparse2
    smat2 = smat2.to(dtype=dtype)
    dmat2 = dmat2.to(dtype=dtype)

    # Outputs use the parametrized dtype. Previously they were hard-coded to
    # float32, so other dtypes were never properly exercised.
    outC = torch.zeros(smat1.shape[0], smat2.shape[1], dtype=dtype)
    outF = torch.zeros(smat2.shape[1], smat1.shape[0], dtype=dtype).T

    # Use the injected fixture instead of shadowing it with a second mkl_lib().
    mkl_As = mkl.mkl_create_sparse(smat1)
    try:
        mkl_Bs = mkl.mkl_create_sparse(smat2)
        try:
            mkl.mkl_spmmd(mkl_As, mkl_Bs, outC)
            mkl.mkl_spmmd(mkl_As, mkl_Bs, outF)
            expected = dmat1 @ dmat2
            np.testing.assert_allclose(expected, outC, rtol=_RTOL[dtype])
            np.testing.assert_allclose(expected, outF, rtol=_RTOL[dtype])
        finally:
            # Release the handles even when an assertion above fails.
            mkl.mkl_sparse_destroy(mkl_Bs)
    finally:
        mkl.mkl_sparse_destroy(mkl_As)
Exemplo n.º 2
0
def _sparse_matmul_cpu(A, B, out):
    """
    Inputs:
     - A : N x D, CSR matrix
     - B : D x M, CSC matrix
    """
    from falkon.mkl_bindings.mkl_bind import mkl_lib

    if A.nnz() == 0 or B.nnz() == 0:
        return out
    if not A.is_csr:
        raise ValueError("A must be CSR matrix")
    if not B.is_csc:
        raise ValueError("B must be CSC matrix")

    mkl = mkl_lib()
    try:
        # For some reason assigning the 'to_scipy' to their own variables
        # is **absolutely fundamental** for the mkl bindings to work
        A = A.transpose_csc()
        As = A.to_scipy()  # D * N (csc)
        Bs = B.to_scipy()

        mkl_sp_1 = mkl.mkl_create_sparse_from_scipy(As)
        mkl_sp_2 = mkl.mkl_create_sparse_from_scipy(Bs)
        mkl.mkl_spmmd(mkl_sp_1, mkl_sp_2, out, transposeA=True)
        return out
    finally:
        try:
            # noinspection PyUnboundLocalVariable
            mkl.mkl_sparse_destroy(mkl_sp_1)
        except:  # noqa E722
            pass
        try:
            # noinspection PyUnboundLocalVariable
            mkl.mkl_sparse_destroy(mkl_sp_2)
        except:  # noqa E722
            pass
Exemplo n.º 3
0
def mkl():
    """Provide the MKL bindings object produced by ``mkl_lib()``."""
    bindings = mkl_lib()
    return bindings