コード例 #1
0
ファイル: sparse_ops_test.py プロジェクト: zhangqiaorjc/jax
    def test_coo_matmat(self, shape, dtype, transpose):
        op = lambda M: M.T if transpose else M

        B_rng = jtu.rand_default(self.rng())
        rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
        M = rng(shape, dtype)
        B = B_rng((op(M).shape[1], 4), dtype)

        args = (M.data, M.row, M.col, B)
        matmat = lambda *args: sparse_ops.coo_matmat(
            *args, shape=shape, transpose=transpose)

        self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
        self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

        y, dy = jvp(
            lambda x: sparse_ops.coo_matmat(
                M.data, M.row, M.col, x, shape=shape, transpose=transpose).sum(
                ), (B, ), (jnp.ones_like(B), ))
        self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

        y, dy = jvp(
            lambda x: sparse_ops.coo_matmat(
                x, M.row, M.col, B, shape=shape, transpose=transpose).sum(),
            (M.data, ), (jnp.ones_like(M.data), ))
        self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
コード例 #2
0
ファイル: sparse_ops_test.py プロジェクト: yajinwuzl/jax
  def test_coo_matmat(self, shape, dtype, transpose):
    op = lambda M: M.T if transpose else M

    B_rng = jtu.rand_default(self.rng())
    rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
    M = rng(shape, dtype)
    B = B_rng((op(M).shape[1], 4), dtype)

    args = (M.data, M.row, M.col, B)
    matmat = lambda *args: sparse_ops.coo_matmat(*args, shape=shape, transpose=transpose)

    self.assertAllClose(op(M) @ B, matmat(*args))
    self.assertAllClose(op(M) @ B, jit(matmat)(*args))