def check_rpc_sampling(tmpdir, num_server):
    """Partition cora into `num_server` parts, launch one server per part,
    sample from a single client, and verify every sampled edge (with its
    edge id) exists in the original graph."""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    graph = CitationGraphDataset("cora")[0]
    graph.readonly()
    print(graph.idtype)
    partition_graph(graph, 'test_sampling', num_server, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=False)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_sampling'))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    sampled_graph = start_sample_client(0, tmpdir, num_server > 1)
    print("Done sampling")
    for proc in server_procs:
        proc.join()

    # Sampled edges must all be present in the original graph and carry
    # matching global edge ids.
    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == graph.number_of_nodes()
    assert np.all(F.asnumpy(graph.has_edges_between(src, dst)))
    expected_eids = graph.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(expected_eids))
def test_kv_multi_role():
    """Exercise the KVStore with mixed roles: 2 servers, 2 trainers, and
    2 sampler sub-processes per trainer."""
    reset_envs()
    num_servers = 2
    num_trainers = 2
    num_samplers = 2
    generate_ip_config("kv_ip_mul_config.txt", 1, num_servers)
    # Each trainer process spawns `num_samplers` sampler processes, so the
    # servers must expect trainers plus all their samplers as clients.
    num_clients = num_trainers * (1 + num_samplers)

    os.environ['DGL_NUM_SAMPLER'] = str(num_samplers)
    os.environ['DGL_NUM_SERVER'] = str(num_servers)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_servers):
        proc = ctx.Process(target=start_server_mul_role,
                           args=(server_id, num_clients, num_servers))
        proc.start()
        server_procs.append(proc)

    trainer_procs = []
    for trainer_id in range(num_trainers):
        proc = ctx.Process(target=start_client_mul_role, args=(trainer_id,))
        proc.start()
        trainer_procs.append(proc)

    for proc in trainer_procs:
        proc.join()
    for proc in server_procs:
        proc.join()
def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server):
    """Per-etype neighbor sampling on a reshuffled heterogeneous graph:
    sample a block from servers, then map shuffled node/edge ids back to
    the original ids via the partition book and verify each sampled edge
    against the original graph."""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = create_random_hetero(dense=True)
    partition_graph(g, 'test_sampling', num_server, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_sampling'))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    fanout = 3
    block, gpb = start_hetero_etype_sample_client(
        0, tmpdir, num_server > 1, fanout,
        nodes={'n3': [0, 10, 99, 66, 124, 208]})
    print("Done sampling")
    for proc in server_procs:
        proc.join()

    # With fanout 3 and 6 seed nodes, each etype into 'n3' yields 18 edges.
    src, dst = block.edges(etype=('n1', 'r2', 'n3'))
    assert len(src) == 18
    src, dst = block.edges(etype=('n2', 'r3', 'n3'))
    assert len(src) == 18

    # Build shuffled-id -> original-id lookup tables from every partition.
    orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64)
                    for ntype in g.ntypes}
    orig_eid_map = {etype: F.zeros((g.number_of_edges(etype),), dtype=F.int64)
                    for etype in g.etypes}
    for part_id in range(num_server):
        part, _, _, _, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', part_id)
        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
        for ntype_id, ntype in enumerate(g.ntypes):
            idx = ntype_ids == ntype_id
            F.scatter_row_inplace(orig_nid_map[ntype],
                                  F.boolean_mask(type_nids, idx),
                                  F.boolean_mask(part.ndata['orig_id'], idx))
        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
        for etype_id, etype in enumerate(g.etypes):
            idx = etype_ids == etype_id
            F.scatter_row_inplace(orig_eid_map[etype],
                                  F.boolean_mask(type_eids, idx),
                                  F.boolean_mask(part.edata['orig_id'], idx))

    for src_type, etype, dst_type in block.canonical_etypes:
        src, dst = block.edges(etype=etype)
        # These are global ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[etype], shuffled_eid))
        # The original edge ids must reconnect the same original endpoints.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)
def test_multi_client_connect(net_type):
    """Verify client connection retry behavior: with a tiny retry budget and
    no server running, connecting raises DistConnectError; with a large
    budget, the client keeps retrying until the server comes up."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context('spawn')
    num_clients = 1
    pserver = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, False, 1, net_type))

    # Small retry budget, server not yet started: connection must fail.
    os.environ['DGL_DIST_MAX_TRY_TIMES'] = '1'
    expect_except = False
    try:
        start_client(ip_config, 0, 1, net_type)
    except dgl.distributed.DistConnectError as err:
        print("Expected error: {}".format(err))
        expect_except = True
    assert expect_except

    # Large retry budget: the client retries until the server is up.
    os.environ['DGL_DIST_MAX_TRY_TIMES'] = '1024'
    pclient = ctx.Process(target=start_client,
                          args=(ip_config, 0, 1, net_type))
    pclient.start()
    pserver.start()
    pclient.join()
    pserver.join()
    reset_envs()
def check_rpc_get_degree_shuffle(tmpdir, num_server):
    """Check distributed in/out-degree queries on a reshuffled partitioned
    graph against degrees computed on the original cora graph."""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    graph = CitationGraphDataset("cora")[0]
    graph.readonly()
    partition_graph(graph, 'test_get_degrees', num_server, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_get_degrees'))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    # Recover shuffled-id -> original-id mapping from the partitions.
    orig_nid = F.zeros((graph.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
    for part_id in range(num_server):
        part, _, _, _, _, _, _ = load_partition(
            tmpdir / 'test_get_degrees.json', part_id)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']

    nids = F.tensor(np.random.randint(graph.number_of_nodes(), size=100))
    in_degs, out_degs, all_in_degs, all_out_degs = start_get_degrees_client(
        0, tmpdir, num_server > 1, nids)
    print("Done get_degree")
    for proc in server_procs:
        proc.join()

    print('check results')
    assert F.array_equal(graph.in_degrees(orig_nid[nids]), in_degs)
    assert F.array_equal(graph.in_degrees(orig_nid), all_in_degs)
    assert F.array_equal(graph.out_degrees(orig_nid[nids]), out_degs)
    assert F.array_equal(graph.out_degrees(orig_nid), all_out_degs)
def test_multi_client_groups():
    """Pressure-test multiple client groups against long-lived servers;
    servers must stay alive after all clients exit until an explicit
    shutdown is issued."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client_groups.txt"
    # Should test with a larger number, but kept small due to possible
    # port-in-use issues.
    num_machines = 5
    num_servers = 1
    generate_ip_config(ip_config, num_machines, num_servers)
    num_clients = 2
    num_groups = 2

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_servers * num_machines):
        proc = ctx.Process(
            target=start_server,
            args=(num_clients, ip_config, server_id, True, num_servers))
        proc.start()
        server_procs.append(proc)

    client_procs = []
    for _ in range(num_clients):
        for group_id in range(num_groups):
            proc = ctx.Process(target=start_client,
                               args=(ip_config, group_id, num_servers))
            proc.start()
            client_procs.append(proc)
    for proc in client_procs:
        proc.join()

    # keep_alive servers must survive their clients...
    for proc in server_procs:
        assert proc.is_alive()
    # ...until we force them down.
    dgl.distributed.shutdown_servers(ip_config, num_servers)
    for proc in server_procs:
        proc.join()
def test_standalone(tmpdir):
    """Run the distributed dataloader in 'standalone' mode (no separate
    server processes) on a single cora partition."""
    reset_envs()
    generate_ip_config("mp_ip_config.txt", 1, 1)
    graph = CitationGraphDataset("cora")[0]
    print(graph.idtype)
    orig_nid, orig_eid = partition_graph(
        graph, 'test_sampling', 1, tmpdir, num_hops=1,
        part_method='metis', reshuffle=True, return_mapping=True)

    os.environ['DGL_DIST_MODE'] = 'standalone'
    try:
        start_dist_dataloader(0, tmpdir, 1, True, orig_nid, orig_eid)
    except Exception as e:
        # NOTE(review): this swallows any failure from the dataloader run and
        # only prints it — confirm whether errors here should fail the test.
        print(e)
    dgl.distributed.exit_client()  # needed since there are two tests in one process
def check_rpc_find_edges_shuffle(tmpdir, num_server):
    """Check distributed find_edges() on a reshuffled partitioned graph:
    endpoints returned for random edge ids must match the original graph
    after mapping shuffled ids back."""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    graph = CitationGraphDataset("cora")[0]
    graph.readonly()
    orig_nid, orig_eid = partition_graph(
        graph, 'test_find_edges', num_server, tmpdir, num_hops=1,
        part_method='metis', reshuffle=True, return_mapping=True)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_find_edges',
                  ['csr', 'coo']))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    eids = F.tensor(np.random.randint(graph.number_of_edges(), size=100))
    u, v = graph.find_edges(orig_eid[eids])
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)
    du = orig_nid[du]
    dv = orig_nid[dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)
def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
    """Per-etype sampling from zero-degree seed nodes on a reshuffled
    heterogeneous graph must return an empty block that still declares
    every edge type."""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = create_random_hetero(dense=True, empty=True)
    orig_nids, _ = partition_graph(
        g, 'test_sampling', num_server, tmpdir, num_hops=1,
        part_method='metis', reshuffle=True, return_mapping=True)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_sampling'))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    fanout = 3
    # Seed only with nodes that have no in-edges.
    deg = get_degrees(g, orig_nids['n3'], 'n3')
    empty_nids = F.nonzero_1d(deg == 0)
    block, gpb = start_hetero_etype_sample_client(
        0, tmpdir, num_server > 1, fanout, nodes={'n3': empty_nids})
    print("Done sampling")
    for proc in server_procs:
        proc.join()

    assert block.number_of_edges() == 0
    assert len(block.etypes) == len(g.etypes)
def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
    """sample on bipartite via sample_etype_neighbors() which yields empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = create_random_bipartite()
    orig_nids, _ = partition_graph(
        g, 'test_sampling', num_server, tmpdir, num_hops=1,
        part_method='metis', reshuffle=True, return_mapping=True)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_sampling'))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    # Seed with zero-degree 'game' nodes plus one 'user' node.
    deg = get_degrees(g, orig_nids['game'], 'game')
    empty_nids = F.nonzero_1d(deg == 0)
    block, gpb = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1,
        nodes={'game': empty_nids, 'user': [1]})
    print("Done sampling")
    for proc in server_procs:
        proc.join()

    assert block is not None
    assert block.number_of_edges() == 0
    assert len(block.etypes) == len(g.etypes)
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle, num_groups):
    """Run the distributed dataloader with configurable server count,
    sampler workers, and client groups; keep-alive servers are shut down
    explicitly when multiple groups are used.

    NOTE(review): another `test_dist_dataloader` definition appears later in
    this file with a different signature and would shadow this one — confirm
    both belong in the same module.
    """
    reset_envs()
    # No multiple partitions on a single machine for multiple client groups,
    # in case of race conditions.
    if num_groups > 1:
        num_server = 1
    generate_ip_config("mp_ip_config.txt", num_server, num_server)
    graph = CitationGraphDataset("cora")[0]
    print(graph.idtype)
    orig_nid, orig_eid = partition_graph(
        graph, 'test_sampling', num_server, tmpdir, num_hops=1,
        part_method='metis', reshuffle=reshuffle, return_mapping=True)

    ctx = mp.get_context('spawn')
    keep_alive = num_groups > 1
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, num_workers + 1,
                  keep_alive))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    os.environ['DGL_DIST_MODE'] = 'distributed'
    os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
    trainer_procs = []
    num_trainers = 1
    for trainer_id in range(num_trainers):
        for group_id in range(num_groups):
            proc = ctx.Process(
                target=start_dist_dataloader,
                args=(trainer_id, tmpdir, num_server, drop_last, orig_nid,
                      orig_eid, group_id))
            proc.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            trainer_procs.append(proc)
    for proc in trainer_procs:
        proc.join()

    if keep_alive:
        for proc in server_procs:
            assert proc.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("mp_ip_config.txt", 1)
    for proc in server_procs:
        proc.join()
def test_rpc():
    """Smoke test: one RPC server and one client exchange messages."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    generate_ip_config("rpc_ip_config.txt", 1, 1)
    ctx = mp.get_context('spawn')
    server_proc = ctx.Process(target=start_server,
                              args=(1, "rpc_ip_config.txt"))
    client_proc = ctx.Process(target=start_client,
                              args=("rpc_ip_config.txt", ))
    server_proc.start()
    client_proc.start()
    server_proc.join()
    client_proc.join()
def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
    """Check distributed in_subgraph() on a reshuffled partitioned graph:
    the client's subgraph, mapped back to original ids, must equal the
    in-subgraph computed locally on the original graph."""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    graph = CitationGraphDataset("cora")[0]
    graph.readonly()
    partition_graph(graph, 'test_in_subgraph', num_server, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_in_subgraph'))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    nodes = [0, 10, 99, 66, 1024, 2008]
    sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
    for proc in server_procs:
        proc.join()

    # Recover shuffled-id -> original-id lookup tables from the partitions.
    orig_nid = F.zeros((graph.number_of_nodes(), ), dtype=F.int64, ctx=F.cpu())
    orig_eid = F.zeros((graph.number_of_edges(), ), dtype=F.int64, ctx=F.cpu())
    for part_id in range(num_server):
        part, _, _, _, _, _, _ = load_partition(
            tmpdir / 'test_in_subgraph.json', part_id)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == graph.number_of_nodes()
    assert np.all(F.asnumpy(graph.has_edges_between(src, dst)))

    # Compare against the locally computed in-subgraph (order-insensitive).
    subg1 = dgl.in_subgraph(graph, orig_nid[nodes])
    src1, dst1 = subg1.edges()
    assert np.all(np.sort(F.asnumpy(src)) == np.sort(F.asnumpy(src1)))
    assert np.all(np.sort(F.asnumpy(dst)) == np.sort(F.asnumpy(dst1)))
    eids = graph.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
def test_rpc_timeout(net_type):
    """Run one server against a client that exercises RPC timeout handling."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context('spawn')
    server_proc = ctx.Process(target=start_server,
                              args=(1, ip_config, 0, False, 1, net_type))
    client_proc = ctx.Process(target=start_client_timeout,
                              args=(ip_config, 0, 1, net_type))
    server_proc.start()
    client_proc.start()
    server_proc.join()
    client_proc.join()
def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1):
    """Sampling on a reshuffled partitioned cora graph, optionally with
    multiple client groups against keep-alive servers (which are then shut
    down explicitly)."""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    graph = CitationGraphDataset("cora")[0]
    graph.readonly()
    partition_graph(graph, 'test_sampling', num_server, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=True)

    ctx = mp.get_context('spawn')
    keep_alive = num_groups > 1
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_sampling',
                  ['csc', 'coo'], keep_alive))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    client_procs = []
    num_clients = 1
    for client_id in range(num_clients):
        for group_id in range(num_groups):
            proc = ctx.Process(
                target=start_sample_client_shuffle,
                args=(client_id, tmpdir, num_server > 1, graph, num_server,
                      group_id))
            proc.start()
            time.sleep(1)  # avoid race condition when instantiating DistGraph
            client_procs.append(proc)
    for proc in client_procs:
        proc.join()

    if keep_alive:
        for proc in server_procs:
            assert proc.is_alive()
        # force shutdown server
        dgl.distributed.shutdown_servers("rpc_ip_config.txt", 1)
    for proc in server_procs:
        proc.join()
def check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type):
    """Partition `g`, launch servers, and run either the node or the edge
    distributed dataloader (per `dataloader_type`) in a trainer process."""
    generate_ip_config("mp_ip_config.txt", num_server, num_server)
    orig_nid, orig_eid = partition_graph(
        g, 'test_sampling', num_server, tmpdir, num_hops=1,
        part_method='metis', reshuffle=True, return_mapping=True)
    # Homogeneous graphs return plain tensors; normalize to per-type dicts.
    if not isinstance(orig_nid, dict):
        orig_nid = {g.ntypes[0]: orig_nid}
    if not isinstance(orig_eid, dict):
        orig_eid = {g.etypes[0]: orig_eid}

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, num_workers + 1))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    os.environ['DGL_DIST_MODE'] = 'distributed'
    os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
    trainer_procs = []
    if dataloader_type == 'node':
        proc = ctx.Process(
            target=start_node_dataloader,
            args=(0, tmpdir, num_server, num_workers, orig_nid, orig_eid, g))
        proc.start()
        trainer_procs.append(proc)
    elif dataloader_type == 'edge':
        proc = ctx.Process(
            target=start_edge_dataloader,
            args=(0, tmpdir, num_server, num_workers, orig_nid, orig_eid, g))
        proc.start()
        trainer_procs.append(proc)

    for proc in server_procs:
        proc.join()
    for proc in trainer_procs:
        proc.join()
def test_multi_client():
    """Ten clients connect to a single RPC server concurrently.

    NOTE(review): a later `test_multi_client(net_type)` definition in this
    file shadows this one — confirm both belong in the same module.
    """
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    generate_ip_config("rpc_ip_config_mul_client.txt", 1, 1)
    ctx = mp.get_context('spawn')
    server_proc = ctx.Process(target=start_server,
                              args=(10, "rpc_ip_config_mul_client.txt"))
    client_procs = [
        ctx.Process(target=start_client,
                    args=("rpc_ip_config_mul_client.txt", ))
        for _ in range(10)
    ]
    server_proc.start()
    for proc in client_procs:
        proc.start()
    for proc in client_procs:
        proc.join()
    server_proc.join()
def test_multi_thread_rpc(net_type):
    """Send RPC requests to two servers from the main thread and from a
    sub-thread of the same client process, and verify both responses.

    Bug fix: the original joined only the last-created server process
    (`pserver.join()`), leaving the first server process unjoined; now every
    process in `pserver_list` is joined.
    """
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    num_servers = 2
    ip_config = "rpc_ip_config_multithread.txt"
    generate_ip_config(ip_config, num_servers, num_servers)
    ctx = mp.get_context('spawn')
    pserver_list = []
    for i in range(num_servers):
        pserver = ctx.Process(target=start_server,
                              args=(1, ip_config, i, False, 1, net_type))
        pserver.start()
        pserver_list.append(pserver)

    def start_client_multithread(ip_config):
        import threading
        dgl.distributed.connect_to_server(
            ip_config=ip_config, num_servers=1, net_type=net_type)
        dgl.distributed.register_service(
            HELLO_SERVICE_ID, HelloRequest, HelloResponse)

        # Request to server 0 from the main thread.
        req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
        dgl.distributed.send_request(0, req)

        # Request to server 1 from a sub-thread.
        def subthread_call(server_id):
            req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
            dgl.distributed.send_request(server_id, req)

        subthread = threading.Thread(target=subthread_call, args=(1, ))
        subthread.start()
        subthread.join()

        res0 = dgl.distributed.recv_response()
        res1 = dgl.distributed.recv_response()
        # Order is not guaranteed
        assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR))
        assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR))
        dgl.distributed.exit_client()

    start_client_multithread(ip_config)
    # Join ALL server processes, not just the last one.
    for pserver in pserver_list:
        pserver.join()
def test_multi_client(net_type):
    """Twenty clients connect to a single RPC server over `net_type`."""
    reset_envs()
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ip_config = "rpc_ip_config_mul_client.txt"
    generate_ip_config(ip_config, 1, 1)
    ctx = mp.get_context('spawn')
    num_clients = 20
    server_proc = ctx.Process(
        target=start_server,
        args=(num_clients, ip_config, 0, False, 1, net_type))
    client_procs = [
        ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
        for _ in range(num_clients)
    ]
    server_proc.start()
    for proc in client_procs:
        proc.start()
    for proc in client_procs:
        proc.join()
    server_proc.join()
def test_kv_store():
    """Basic KVStore test: two servers serving two client processes."""
    reset_envs()
    num_servers = 2
    num_clients = 2
    generate_ip_config("kv_ip_config.txt", 1, num_servers)
    os.environ['DGL_NUM_SERVER'] = str(num_servers)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_servers):
        proc = ctx.Process(target=start_server,
                           args=(server_id, num_clients, num_servers))
        proc.start()
        server_procs.append(proc)

    client_procs = []
    for _ in range(num_clients):
        proc = ctx.Process(target=start_client,
                           args=(num_clients, num_servers))
        proc.start()
        client_procs.append(proc)

    for proc in client_procs:
        proc.join()
    for proc in server_procs:
        proc.join()
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
    """Run the distributed dataloader with a single trainer against
    `num_server` servers on a partitioned cora graph."""
    reset_envs()
    generate_ip_config("mp_ip_config.txt", num_server, num_server)
    graph = CitationGraphDataset("cora")[0]
    print(graph.idtype)
    orig_nid, orig_eid = partition_graph(
        graph, 'test_sampling', num_server, tmpdir, num_hops=1,
        part_method='metis', reshuffle=reshuffle, return_mapping=True)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, num_workers + 1))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    os.environ['DGL_DIST_MODE'] = 'distributed'
    os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
    trainer_proc = ctx.Process(
        target=start_dist_dataloader,
        args=(0, tmpdir, num_server, drop_last, orig_nid, orig_eid))
    trainer_proc.start()
    for proc in server_procs:
        proc.join()
    trainer_proc.join()
def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
    """sample on bipartite via sample_etype_neighbors() which yields non-empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = create_random_bipartite()
    orig_nids, _ = partition_graph(
        g, 'test_sampling', num_server, tmpdir, num_hops=1,
        part_method='metis', reshuffle=True, return_mapping=True)

    ctx = mp.get_context('spawn')
    server_procs = []
    for server_id in range(num_server):
        proc = ctx.Process(
            target=start_server,
            args=(server_id, tmpdir, num_server > 1, 'test_sampling'))
        proc.start()
        time.sleep(1)  # stagger server start-up
        server_procs.append(proc)

    fanout = 3
    # Seed only with 'game' nodes that actually have neighbors.
    deg = get_degrees(g, orig_nids['game'], 'game')
    nids = F.nonzero_1d(deg > 0)
    block, gpb = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1, fanout,
        nodes={'game': nids, 'user': [0]})
    print("Done sampling")
    for proc in server_procs:
        proc.join()

    # Build shuffled-id -> original-id lookup tables from every partition.
    orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype), ), dtype=F.int64)
                    for ntype in g.ntypes}
    orig_eid_map = {etype: F.zeros((g.number_of_edges(etype), ), dtype=F.int64)
                    for etype in g.etypes}
    for part_id in range(num_server):
        part, _, _, _, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', part_id)
        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
        for ntype_id, ntype in enumerate(g.ntypes):
            idx = ntype_ids == ntype_id
            F.scatter_row_inplace(orig_nid_map[ntype],
                                  F.boolean_mask(type_nids, idx),
                                  F.boolean_mask(part.ndata['orig_id'], idx))
        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
        for etype_id, etype in enumerate(g.etypes):
            idx = etype_ids == etype_id
            F.scatter_row_inplace(orig_eid_map[etype],
                                  F.boolean_mask(type_eids, idx),
                                  F.boolean_mask(part.edata['orig_id'], idx))

    for src_type, etype, dst_type in block.canonical_etypes:
        src, dst = block.edges(etype=etype)
        # These are global ids after shuffling.
        shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[etype], shuffled_eid))
        # The original edge ids must reconnect the same original endpoints.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)
def prepare_dist():
    """Write a one-machine, one-server IP config for KVStore tests.

    NOTE(review): a later `prepare_dist(num_servers=1)` definition in this
    file shadows this one — confirm both belong in the same module.
    """
    generate_ip_config("kv_ip_config.txt", 1, 1)
def prepare_dist(num_servers=1):
    """Write a one-machine IP config with `num_servers` servers for
    KVStore tests."""
    generate_ip_config("kv_ip_config.txt", 1, num_servers=num_servers)