def dfs_order(forest, roots):
    """Yield DFS frontiers over ``forest`` as ``(edge_ids, labels)`` pairs.

    The forest is moved with ``tocpu`` and traversed from ``roots`` by
    ``dfs_labeled_edges_generator`` with ``has_reverse_edge=True``. For each
    frontier, the yielded edge IDs are the frontier's edge IDs XOR-ed with
    that frontier's labels; the labels are yielded unchanged.

    This exploits the molecule-tree convention that a reverse edge's ID is
    its forward edge's ID XOR 1. A general implementation would locate the
    reverse edges with ``find_edges()`` instead.
    """
    cpu_forest = tocpu(forest)
    edge_frontiers, label_frontiers = dfs_labeled_edges_generator(
        cpu_forest, roots, has_reverse_edge=True)
    for frontier_eids, frontier_labels in zip(edge_frontiers, label_frontiers):
        # NOTE(review): presumably the label is 0 for a forward traversal and
        # 1 for a reverse one, so the XOR flips only reverse-edge IDs onto
        # their paired edge — confirm against the generator's label values.
        yield frontier_eids ^ frontier_labels, frontier_labels
def test_dfs_labeled_edges(index_dtype, example=False):
    """Check ``dgl.dfs_labeled_edges_generator`` on a two-component graph.

    The graph has edges 0->1, 1->2, 0->2 (edge IDs 0-2) and 3->4, 3->5
    (edge IDs 3-4), and it is traversed from roots 0 and 3 with reverse and
    non-tree edges enabled. Each generated frontier is compared as a set
    against every combination of the accepted per-component orderings.

    Args:
        index_dtype: ``'int32'`` selects 32-bit graph indices; any other
            value selects 64-bit indices.
        example: Not read by this test body.

    Raises:
        AssertionError: If no combination of accepted orderings matches the
            generated frontiers.
    """
    dgl_g = dgl.DGLGraph()
    dgl_g.add_nodes(6)
    dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
    if index_dtype == 'int32':
        dgl_g = dgl.graph(dgl_g.edges()).int()
    else:
        dgl_g = dgl.graph(dgl_g.edges()).long()
    dgl_edges, dgl_labels = dgl.dfs_labeled_edges_generator(
        dgl_g, [0, 3], has_reverse_edge=True, has_nontree_edge=True)
    # Compare frontiers as sets so intra-frontier ordering does not matter.
    dgl_edges = [toset(t) for t in dgl_edges]
    dgl_labels = [toset(t) for t in dgl_labels]

    # Accepted per-step (edge ID, label) sequences for the component rooted
    # at 0 (edge IDs 0-2); the two entries are alternative valid orderings.
    g1_solutions = [
        # edges            labels
        [[0, 1, 1, 0, 2], [0, 0, 1, 1, 2]],
        [[2, 2, 0, 1, 0], [0, 1, 0, 2, 1]],
    ]
    # Same for the component rooted at 3 (edge IDs 3-4).
    g2_solutions = [
        # edges        labels
        [[3, 3, 4, 4], [0, 1, 0, 1]],
        [[4, 4, 3, 3], [0, 1, 0, 1]],
    ]

    def combine_frontiers(sol):
        """Union the i-th step of each component into the i-th frontier set.

        ``zip_longest`` pads the shorter component's sequence with ``None``,
        and those pads are dropped from the union.
        """
        es, ls = zip(*sol)
        es = [
            set(i for i in t if i is not None)
            for t in itertools.zip_longest(*es)
        ]
        ls = [
            set(i for i in t if i is not None)
            for t in itertools.zip_longest(*ls)
        ]
        return es, ls

    for sol_set in itertools.product(g1_solutions, g2_solutions):
        es, ls = combine_frontiers(sol_set)
        if es == dgl_edges and ls == dgl_labels:
            break
    else:
        # Raise explicitly: a bare ``assert False`` is removed under
        # ``python -O``, which would let this test pass on a mismatch.
        raise AssertionError(
            f"no accepted DFS ordering matched: "
            f"edges={dgl_edges}, labels={dgl_labels}")