def test_coo_matmat(self, shape, dtype, transpose):
  """Check sparse_ops.coo_matmat against scipy, under jit, and under jvp.

  A random COO matrix ``M`` (scipy) and a dense ``B`` with 4 columns are
  multiplied as ``op(M) @ B``, where ``op`` transposes ``M`` when
  ``transpose`` is set. The forward result is compared eagerly and under
  ``jit``. Forward-mode derivatives of ``sum(op(M) @ B)`` are then checked
  with respect to ``B`` and with respect to ``M.data``, for both the primal
  and the tangent outputs.
  """
  op = lambda M: M.T if transpose else M

  B_rng = jtu.rand_default(self.rng())
  rng = rand_sparse(self.rng(), post=sparse.coo_matrix)
  M = rng(shape, dtype)
  B = B_rng((op(M).shape[1], 4), dtype)

  args = (M.data, M.row, M.col, B)
  matmat = lambda *args: sparse_ops.coo_matmat(
      *args, shape=shape, transpose=transpose)

  self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
  self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

  # JVP with respect to the dense operand B.
  y, dy = jvp(
      lambda x: sparse_ops.coo_matmat(
          M.data, M.row, M.col, x, shape=shape, transpose=transpose).sum(),
      (B,), (jnp.ones_like(B),))
  self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
  # The product is linear in B, so the tangent equals sum(op(M) @ ones).
  # Each of B's columns contributes sum(M) (transposing does not change the
  # total), giving B.shape[1] * sum(M.data).
  self.assertAllClose(B.shape[1] * M.data.sum(), dy, rtol=MATMUL_TOL)

  # JVP with respect to the sparse values M.data.
  y, dy = jvp(
      lambda x: sparse_ops.coo_matmat(
          x, M.row, M.col, B, shape=shape, transpose=transpose).sum(),
      (M.data,), (jnp.ones_like(M.data),))
  self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
  # The product is linear in M.data, so the tangent is sum(op(M1) @ B), where
  # M1 has ones at M's nonzero positions. Each stored entry (r, c) therefore
  # contributes the row-sum of B at the index that op(M1) contracts against:
  # c for M1 @ B, or r for M1.T @ B.
  contracted = M.row if transpose else M.col
  self.assertAllClose(B.sum(axis=1)[contracted].sum(), dy, rtol=MATMUL_TOL)
def test_coo_matmat(self, shape, dtype, transpose):
  """Compare sparse_ops.coo_matmat with scipy's COO product, eagerly and jitted.

  The left operand is a random scipy COO matrix, transposed when
  ``transpose`` is set. The right operand is a dense matrix with 4 columns
  whose row count matches the left operand's column count.
  """
  dense_sampler = jtu.rand_default(self.rng())
  sparse_sampler = rand_sparse(self.rng(), post=sparse.coo_matrix)

  # Draw the sparse matrix first, then the dense one (same order as before).
  coo = sparse_sampler(shape, dtype)
  lhs = coo.T if transpose else coo
  rhs = dense_sampler((lhs.shape[1], 4), dtype)

  def coo_product(data, row, col, dense):
    return sparse_ops.coo_matmat(
        data, row, col, dense, shape=shape, transpose=transpose)

  operands = (coo.data, coo.row, coo.col, rhs)
  expected = lhs @ rhs
  for product_fn in (coo_product, jit(coo_product)):
    self.assertAllClose(expected, product_fn(*operands))