예제 #1
0
파일: test.py 프로젝트: gcucurull/jax-gcn
def test_rectangular_matrix_3():
    """Compare ``sp_matmul`` with a dense matmul for a 2000x300 left operand.

    Each entry of the left matrix is zeroed when its uniform ``[0, 1)`` draw
    falls below the module-level ``sparsity`` threshold. The matrix is then
    multiplied by a dense 300x64 matrix twice:
    - through ``sp_matmul``, using an ``(indices, values)`` view of the nonzeros;
    - through ``np.matmul``.
    The two products must agree to within the module-level ``tolerance``, as
    measured by ``distance`` (both defined elsewhere in this module).
    """
    n_rows, n_cols = 2000, 300

    # Entry (i, j) is dropped with probability ``sparsity``
    # (assuming 0 <= sparsity <= 1).
    drop = numpy.random.rand(n_rows, n_cols) < sparsity
    dense_lhs = numpy.random.rand(n_rows, n_cols)
    dense_lhs[drop] = 0.0
    rhs = numpy.random.rand(n_cols, 64)

    # Sparse form handed to sp_matmul: ((row_idx, col_idx), nonzero_values).
    nz_coords = dense_lhs.nonzero()
    sparse_lhs = (nz_coords, dense_lhs[nz_coords])

    # Third argument is the number of rows of the left operand.
    sparse_product = sp_matmul(sparse_lhs, rhs, n_rows)
    dense_product = np.matmul(dense_lhs, rhs)

    error = distance(sparse_product, dense_product)
    print(error)
    assert error < tolerance
예제 #2
0
파일: test.py 프로젝트: gcucurull/jax-gcn
def test_square_sparse_matrix_fail():
    """Negative check: ``sp_matmul`` must NOT reproduce the transposed product.

    Builds a random 2000x2000 left matrix, zeroing each entry whose uniform
    draw falls below the module-level ``sparsity`` threshold, and a dense
    2000x1 right operand. It then asserts that ``sp_matmul`` on the sparse form
    differs from ``A.T @ B`` by more than ``tolerance`` under ``distance``.
    This guards against the sparse kernel silently computing ``A^T B``.
    """
    size = 2000

    # Entry (i, j) is dropped with probability ``sparsity``
    # (assuming 0 <= sparsity <= 1).
    drop = numpy.random.rand(size, size) < sparsity
    dense_lhs = numpy.random.rand(size, size)
    dense_lhs[drop] = 0.0
    rhs = numpy.random.rand(size, 1)

    # Sparse form handed to sp_matmul: ((row_idx, col_idx), nonzero_values).
    nz_coords = dense_lhs.nonzero()
    sparse_lhs = (nz_coords, dense_lhs[nz_coords])

    sparse_product = sp_matmul(sparse_lhs, rhs, size)
    # Deliberately the wrong reference: the product with A transposed.
    transposed_product = np.matmul(dense_lhs.T, rhs)

    error = distance(sparse_product, transposed_product)
    print(error)
    assert error > tolerance
예제 #3
0
파일: models.py 프로젝트: gcucurull/jax-gcn
 def matmul(A, B, shape):
     """Multiply ``A`` by ``B`` with either the sparse or the dense kernel.

     The kernel is chosen by ``sparse``, a free variable resolved from the
     enclosing scope (presumably the layer's ``sparse`` flag -- confirm).

     Args:
         A: left operand. On the sparse path it is passed to ``sp_matmul``
             as-is (presumably an ``(indices, values)`` pair -- confirm).
         B: right operand.
         shape: forwarded only to ``sp_matmul``; ignored on the dense path.

     Returns:
         ``sp_matmul(A, B, shape)`` when ``sparse`` is truthy, otherwise
         ``np.matmul(A, B)``.
     """
     return sp_matmul(A, B, shape) if sparse else np.matmul(A, B)