Пример #1
0
def test_bfs(index_dtype, n=100):
    """Check DGL's BFS node/edge generators against a networkx reference.

    Two checks run in sequence:

    1. BFS node layers on a random sparse graph, from a randomly chosen
       source node.
    2. BFS edge layers on a seeded random tree, from node 0.

    Parameters
    ----------
    index_dtype : str
        ``'int32'`` casts the graph with ``.int()``; any other value casts
        it with ``.long()``.
    n : int, optional
        Number of nodes used for both generated graphs.
    """

    def _bfs_nx(g_nx, src):
        """Group ``nx.bfs_edges`` output into per-level node and edge-id sets.

        Edge ids are looked up on the enclosing ``g``, using whichever graph
        that name is bound to when this helper is called.

        Returns ``(node_layers, edge_layers)``; ``node_layers[0]`` is
        ``{src}``.
        """
        node_layers = [{src}]
        edge_layers = []
        level_nodes, level_edges = set(), set()
        for parent, child in nx.bfs_edges(g_nx, src):
            if parent not in node_layers[-1]:
                # The parent is not in the last closed layer, so the layer
                # being accumulated is complete: close it and start anew.
                node_layers.append(level_nodes)
                edge_layers.append(level_edges)
                level_nodes, level_edges = set(), set()
            level_nodes.add(child)
            level_edges.add(g.edge_id(parent, child))
        # Only keep the trailing layer when it is non-empty (avoids an empty
        # layer when there are no successors).
        if level_nodes and level_edges:
            node_layers.append(level_nodes)
            edge_layers.append(level_edges)
        return node_layers, edge_layers

    def _with_index_dtype(graph):
        """Rebuild ``graph`` from its edges and cast it per ``index_dtype``."""
        rebuilt = dgl.graph(graph.edges())
        return rebuilt.int() if index_dtype == 'int32' else rebuilt.long()

    # --- Check 1: node layers on a random sparse graph -------------------
    g = dgl.DGLGraph()
    adjacency = sp.random(n, n, 3 / n, data_rvs=lambda n: np.ones(n))
    g.from_scipy_sparse_matrix(adjacency)
    g = _with_index_dtype(g)

    g_nx = g.to_networkx()
    src = random.choice(range(n))
    expected_layers, _ = _bfs_nx(g_nx, src)
    actual_layers = dgl.bfs_nodes_generator(g, src)
    assert len(actual_layers) == len(expected_layers)
    assert all(
        toset(got) == want
        for got, want in zip(actual_layers, expected_layers)
    )

    # --- Check 2: edge layers on a seeded random tree --------------------
    g_nx = nx.random_tree(n, seed=42)
    g = dgl.DGLGraph()
    g.from_networkx(g_nx)
    g = _with_index_dtype(g)

    src = 0
    _, expected_edges = _bfs_nx(g_nx, src)
    actual_edges = dgl.bfs_edges_generator(g, src)
    assert len(actual_edges) == len(expected_edges)
    assert all(
        toset(got) == want
        for got, want in zip(actual_edges, expected_edges)
    )
        def get_perturbed_inputs(i):
            """Build perturbed copies of sample ``i``, one per BFS-visited node.

            A graph is built from the non-zero entries of ``init_edges[i]``
            and traversed breadth-first from ``corruption_indices[i]``. Every
            per-sample input is then stacked once per visited node. In copy
            ``k`` (for ``k >= 1``) the ``k``-th visited node gets a randomly
            chosen different value in the node tensor and, when
            ``replace_hs is True``, in the hydrogen tensor as well. Copy 0
            (the BFS source) is left unmodified.

            NOTE(review): node values are presumably indices into
            ``QM9_SYMBOL_LIST`` and hydrogen values counts in
            ``[0, max_hs + 1]`` -- confirm against the data pipeline.

            Parameters
            ----------
            i : int
                Index of the sample within the enclosing batch tensors.

            Returns
            -------
            tuple
                ``(d_init_nodes, d_node_target_types,
                d_node_target_inds_vector, d_init_edges, d_node_masks,
                d_edge_masks, d_init_hydrogens, d_original_node_inds,
                num_nodes_per_radius)`` where the first eight items are the
                stacked copies and the last is a list with the number of
                nodes in each BFS level.
            """
            g = dgl.DGLGraph()
            g.add_nodes(len(init_nodes[i]))
            g.add_edges(*np.nonzero(init_edges[i]).T)

            bfs_levels = dgl.bfs_nodes_generator(g, corruption_indices[i])
            visit_order = torch.cat(bfs_levels)
            num_nodes_per_radius = [len(level) for level in bfs_levels]
            num_copies = len(visit_order)

            def _replicate(sample):
                # One copy of `sample` per visited node, stacked on a new dim 0.
                return torch.stack([sample] * num_copies)

            # Each slice along dim 0 will have (at most) one node corrupted.
            d_init_nodes = _replicate(init_nodes[i])
            d_node_target_types = _replicate(node_target_types[i])
            d_node_target_inds_vector = _replicate(node_target_inds_vector[i])
            d_init_edges = _replicate(init_edges[i])
            d_node_masks = _replicate(node_masks[i])
            d_edge_masks = _replicate(edge_masks[i])
            d_init_hydrogens = _replicate(init_hydrogens[i])
            d_original_node_inds = _replicate(original_node_inds[i])

            # Start at position 1 so the source node is never corrupted.
            for copy_idx, node_idx in enumerate(visit_order[1:], 1):
                current_node = init_nodes[i][node_idx]
                # Candidates: every index of the first five symbols except
                # the node's current value.
                node_candidates = np.setdiff1d(
                    np.arange(len(QM9_SYMBOL_LIST[:5])), current_node.cpu())
                new_node = np.random.choice(node_candidates)
                d_init_nodes[copy_idx][node_idx] = float(new_node)
                if replace_hs is True:
                    current_h = init_hydrogens[i][node_idx]
                    h_candidates = np.setdiff1d(np.arange(max_hs + 2),
                                                current_h.cpu())
                    new_h = np.random.choice(h_candidates)
                    d_init_hydrogens[copy_idx][node_idx] = float(new_h)

            return (
                d_init_nodes,
                d_node_target_types,
                d_node_target_inds_vector,
                d_init_edges,
                d_node_masks,
                d_edge_masks,
                d_init_hydrogens,
                d_original_node_inds,
                num_nodes_per_radius,
            )