def test_nccl_sparse_push_single_remainder():
    """Sparse push over a one-rank communicator with a 'remainder' partition.

    The test asserts that the indices and values returned by
    ``sparse_all_to_all_push`` equal the ones that were sent.
    """
    num_indices = 10000
    array_size = 10000
    feat_dim = 100

    # Communicator built with size 1 and rank 0.
    comm = nccl.Communicator(1, 0, nccl.UniqueId())

    sent_index = F.randint([num_indices], F.int32, F.ctx(), 0, array_size)
    sent_value = F.uniform(
        [num_indices, feat_dim], F.float32, F.ctx(), -1.0, 1.0
    )

    partition = NDArrayPartition(array_size, 1, 'remainder')

    received = comm.sparse_all_to_all_push(sent_index, sent_value, partition)
    recv_index, recv_value = received

    assert F.array_equal(recv_index, sent_index)
    assert F.array_equal(recv_value, sent_value)
def test_nccl_sparse_pull_single_remainder():
    """Sparse pull over a one-rank communicator with a 'remainder' partition.

    The test asserts that ``sparse_all_to_all_pull`` returns the same rows as
    gathering the requested indices directly from the value tensor.
    """
    num_requests = 10000
    array_size = 100000
    feat_dim = 100

    # Communicator built with size 1 and rank 0.
    comm = nccl.Communicator(1, 0, nccl.UniqueId())

    requested = F.randint([num_requests], F.int64, F.ctx(), 0, array_size)
    table = F.uniform([array_size, feat_dim], F.float32, F.ctx(), -1.0, 1.0)

    partition = NDArrayPartition(array_size, 1, 'remainder')

    pulled = comm.sparse_all_to_all_pull(requested, table, partition)
    expected = F.gather_row(table, requested)

    assert F.array_equal(pulled, expected)
def test_nccl_sparse_pull_single_range():
    """Sparse pull over a one-rank communicator with a 'range' partition.

    The partition is given a single range ``[0, value.shape[0]]``. The test
    asserts that ``sparse_all_to_all_pull`` returns the same rows as gathering
    the requested indices directly from the value tensor.
    """
    num_requests = 10000
    array_size = 100000
    feat_dim = 100

    # Communicator built with size 1 and rank 0.
    comm = nccl.Communicator(1, 0, nccl.UniqueId())

    requested = F.randint([num_requests], F.int64, F.ctx(), 0, array_size)
    table = F.uniform([array_size, feat_dim], F.float32, F.ctx(), -1.0, 1.0)

    # Range boundaries are built as an int64 tensor and copied to the
    # test context before being handed to the partition.
    boundaries = F.tensor([0, table.shape[0]], dtype=F.int64)
    boundaries = F.copy_to(boundaries, F.ctx())
    partition = NDArrayPartition(
        array_size, 1, 'range', part_ranges=boundaries
    )

    pulled = comm.sparse_all_to_all_pull(requested, table, partition)
    expected = F.gather_row(table, requested)

    assert F.array_equal(pulled, expected)
def test_rgcn():
    """Shape checks for ``nn.RelGraphConv`` in its regular and low_mem forms.

    Five scenarios are run in order: "basis" and "bdd" regularizers with
    dense random features, the same two with an extra all-zero edge norm,
    and finally "basis" with integer id input. In each scenario both the
    default layer and the ``low_mem=True`` layer must return an output of
    shape ``[100, O]``.
    """
    num_nodes = 100
    num_rels = 5  # 5 etypes
    num_bases = 2
    in_feats = 10
    out_feats = 8

    graph = dgl.DGLGraph(
        sp.sparse.random(num_nodes, num_nodes, density=0.1), readonly=True
    )
    # Edge types cycle through 0..4 in edge order.
    etype = [eid % 5 for eid in range(graph.number_of_edges())]

    def dense_feat():
        return F.randn((num_nodes, in_feats))

    def id_feat():
        return F.randint(
            low=0, high=in_feats, shape=(num_nodes, ), dtype=jnp.int64
        )

    def check_shapes(regularizer, make_feat, *extra):
        # Both layers are constructed before the inputs are generated,
        # mirroring the statement order of every scenario.
        layer = nn.RelGraphConv(
            in_feats, out_feats, num_rels, regularizer, num_bases
        )
        layer_low = nn.RelGraphConv(
            in_feats, out_feats, num_rels, regularizer, num_bases,
            low_mem=True,
        )
        feat = make_feat()
        rel = F.tensor(etype)

        params = layer.init(
            jax.random.PRNGKey(2666), graph, feat, rel, *extra
        )
        out = layer.apply(params, graph, feat, rel, *extra)

        params = layer_low.init(
            jax.random.PRNGKey(2666), graph, feat, rel, *extra
        )
        out_low = layer_low.apply(params, graph, feat, rel, *extra)

        assert list(out.shape) == [num_nodes, out_feats]
        assert list(out_low.shape) == [num_nodes, out_feats]

    # without norm
    check_shapes("basis", dense_feat)
    check_shapes("bdd", dense_feat)

    # with norm
    norm = F.zeros((graph.number_of_edges(), 1))
    check_shapes("basis", dense_feat, norm)
    check_shapes("bdd", dense_feat, norm)

    # id input
    check_shapes("basis", id_feat)