def test_sort_with_tag(idtype):
    """Sort a random test graph by node tag and verify both sort directions.

    For ``dgl.sort_csr_by_tag`` and ``dgl.sort_csc_by_tag`` in turn, the
    adjacency matrix of the returned graph must pass ``check_sort`` (using the
    ``_TAG_OFFSET`` node data of the returned graph), while the adjacency
    matrix of the input graph must still fail it.
    """
    # NOTE(review): this file chunk contains a second function with the same
    # name further down; at module level the later definition replaces this one.
    n_nodes, adj_sizes, n_tags = 200, [20, 50], 5
    g = create_test_heterograph(n_nodes, adj_sizes, idtype=idtype)
    tag = F.tensor(np.random.choice(n_tags, g.number_of_nodes()))

    # Each case: (sorting routine, extra kwargs for adjacency_matrix).
    # The second case inspects the transposed adjacency matrix.
    cases = (
        (dgl.sort_csr_by_tag, {}),
        (dgl.sort_csc_by_tag, {'transpose': True}),
    )
    for sort_by_tag, adj_kwargs in cases:
        sorted_g = sort_by_tag(g, tag)
        adj_before = g.adjacency_matrix(scipy_fmt='csr', **adj_kwargs)
        adj_after = sorted_g.adjacency_matrix(scipy_fmt='csr', **adj_kwargs)
        assert check_sort(adj_after, tag, sorted_g.ndata["_TAG_OFFSET"])
        # Check the original matrix is not modified.
        assert not check_sort(adj_before, tag)
def test_sort_with_tag_bipartite(idtype):
    """Verify tag sorting on a graph with separate '_U' and '_V' node types.

    The test graph's edges are re-wrapped as a ``('_U', '_E', '_V')``
    heterograph. ``dgl.sort_csr_by_tag`` is called with tags sized by the '_V'
    nodes and its ``_TAG_OFFSET`` is read from the '_U' nodes;
    ``dgl.sort_csc_by_tag`` is called with tags sized by the '_U' nodes and its
    ``_TAG_OFFSET`` is read from the '_V' nodes. In both cases the input
    graph's matrix must still fail ``check_sort``.
    """
    n_nodes, adj_sizes, n_tags = 200, [20, 50], 5
    base = create_test_heterograph(n_nodes, adj_sizes, idtype=idtype)
    g = dgl.heterograph({('_U', '_E', '_V'): base.edges()})

    # Draw order matters for reproducibility under a seeded RNG: '_U' first.
    utag = F.tensor(np.random.choice(n_tags, g.number_of_nodes('_U')))
    vtag = F.tensor(np.random.choice(n_tags, g.number_of_nodes('_V')))

    # Each case: (sorting routine, tag tensor, node type holding the offsets,
    # extra kwargs for adjacency_matrix).
    cases = (
        (dgl.sort_csr_by_tag, vtag, '_U', {}),
        (dgl.sort_csc_by_tag, utag, '_V', {'transpose': True}),
    )
    for sort_by_tag, tag, offset_ntype, adj_kwargs in cases:
        sorted_g = sort_by_tag(g, tag)
        adj_before = g.adjacency_matrix(scipy_fmt='csr', **adj_kwargs)
        adj_after = sorted_g.adjacency_matrix(scipy_fmt='csr', **adj_kwargs)
        offsets = sorted_g.nodes[offset_ntype].data['_TAG_OFFSET']
        assert check_sort(adj_after, tag, offsets)
        # The input graph must be left unsorted.
        assert not check_sort(adj_before, tag)
def test_sample_neighbors_biased_bipartite():
    """Exercise ``sample_neighbors_biased`` in both edge directions.

    Nodes get a random tag in ``range(4)`` and the graph is sorted by tag.
    Sampling uses the bias tensor ``[0, 0.01, 10, 10]`` (presumably one weight
    per tag -- confirm against the sampler's documentation). Each direction is
    sampled five times without replacement and five times with replacement;
    the tag counts of the sampled neighbors are checked every time, and edge
    uniqueness is checked for the runs without replacement.
    """
    g = create_test_graph(100, 30, True)
    num_dst = g.number_of_dst_nodes()
    bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32)

    def check_num(nodes, tag):
        """Assert the tag histogram of ``nodes`` is consistent with ``bias``."""
        nodes, tag = F.asnumpy(nodes), F.asnumpy(tag)
        sampled_tags = tag[nodes]
        cnt = [sum(sampled_tags == t) for t in range(4)]
        # No tag 0
        assert cnt[0] == 0
        # very rare tag 1
        assert cnt[2] > 2 * cnt[1]
        assert cnt[3] > 2 * cnt[1]

    def check_unique_edges(subg):
        """Assert that no (src, dst) pair occurs twice in ``subg``."""
        u, v = subg.edges()
        pairs = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
        assert len(pairs) == subg.number_of_edges()

    # In-edges: first without replacement, then with replacement.
    tag = F.tensor(np.random.choice(4, 100))
    g_sorted = dgl.sort_csc_by_tag(g, tag)
    for replace in (False, True):
        for _ in range(5):
            subg = dgl.sampling.sample_neighbors_biased(
                g_sorted, g.dstnodes(), 5, bias, replace=replace)
            check_num(subg.edges()[0], tag)
            if not replace:
                check_unique_edges(subg)

    # Out-edges: first without replacement, then with replacement.
    tag = F.tensor(np.random.choice(4, num_dst))
    g_sorted = dgl.sort_csr_by_tag(g, tag)
    for replace in (False, True):
        for _ in range(5):
            subg = dgl.sampling.sample_neighbors_biased(
                g_sorted, g.srcnodes(), 5, bias, edge_dir='out',
                replace=replace)
            check_num(subg.edges()[1], tag)
            if not replace:
                check_unique_edges(subg)
def test_sort_with_tag(idtype):
    """Verify tag sorting for both ``tag_type='node'`` and ``tag_type='edge'``.

    Edge tags are gathered from the node tags: by destination node for
    ``dgl.sort_csr_by_tag`` and by source node for ``dgl.sort_csc_by_tag``.
    For every combination the adjacency matrix of the returned graph must pass
    ``check_sort`` against the node tags, while the adjacency matrix of the
    input graph must still fail it.
    """
    n_nodes, adj_sizes, n_tags = 200, [20, 50], 5
    g = create_test_heterograph(n_nodes, adj_sizes, idtype=idtype)
    tag = F.tensor(np.random.choice(n_tags, g.number_of_nodes()))

    src, dst = g.edges()
    edge_tag_dst = F.gather_row(tag, F.tensor(dst))
    edge_tag_src = F.gather_row(tag, F.tensor(src))

    # Tag tensor handed to the sorter for each tag_type; dicts keep insertion
    # order, so 'node' is exercised before 'edge'.
    csr_tags = {'node': tag, 'edge': edge_tag_dst}
    for tag_type, sort_tag in csr_tags.items():
        sorted_g = dgl.sort_csr_by_tag(g, sort_tag, tag_type=tag_type)
        csr_before = g.adjacency_matrix(scipy_fmt='csr')
        csr_after = sorted_g.adjacency_matrix(scipy_fmt='csr')
        assert check_sort(csr_after, tag, sorted_g.dstdata["_TAG_OFFSET"])
        # Check the original csr is not modified.
        assert not check_sort(csr_before, tag)

    csc_tags = {'node': tag, 'edge': edge_tag_src}
    for tag_type, sort_tag in csc_tags.items():
        sorted_g = dgl.sort_csc_by_tag(g, sort_tag, tag_type=tag_type)
        csc_before = g.adjacency_matrix(transpose=True, scipy_fmt='csr')
        csc_after = sorted_g.adjacency_matrix(transpose=True, scipy_fmt='csr')
        assert check_sort(csc_after, tag, sorted_g.srcdata["_TAG_OFFSET"])
        # The input graph's transposed matrix must stay unsorted as well.
        assert not check_sort(csc_before, tag)