def test_bfs(index_dtype, n=100):
    """Compare DGL's BFS generators with layers derived from NetworkX.

    Two checks are performed:

    1. On a random sparse graph, the node layers yielded by
       ``dgl.bfs_nodes_generator`` must match the layers reconstructed from
       ``nx.bfs_edges``.
    2. On a random tree (seed 42), the edge layers yielded by
       ``dgl.bfs_edges_generator`` must match the reference edge layers.

    Parameters
    ----------
    index_dtype : str
        ``'int32'`` casts the graph with ``.int()``; any other value uses
        ``.long()``.
    n : int
        Number of nodes in both generated graphs.
    """

    def _bfs_nx(g_nx, src):
        """Group ``nx.bfs_edges`` output into per-level node and edge sets.

        Edge ids are looked up on ``graph`` from the enclosing scope at call
        time, so the caller must have bound the matching DGL graph first.
        """
        node_layers = [{src}]
        edge_layers = []
        pending_nodes = set()
        pending_edges = set()
        for parent, child in nx.bfs_edges(g_nx, src):
            if parent not in node_layers[-1]:
                # The parent is not in the last committed layer, so the
                # pending layer is complete: commit it and start a new one.
                node_layers.append(pending_nodes)
                edge_layers.append(pending_edges)
                pending_nodes = set()
                pending_edges = set()
            pending_nodes.add(child)
            pending_edges.add(graph.edge_id(parent, child))
        # avoids empty successors
        if pending_nodes and pending_edges:
            node_layers.append(pending_nodes)
            edge_layers.append(pending_edges)
        return node_layers, edge_layers

    def _rebuild_with_dtype(source_graph):
        """Rebuild the graph from its edges and cast the index dtype."""
        rebuilt = dgl.graph(source_graph.edges())
        if index_dtype == 'int32':
            return rebuilt.int()
        return rebuilt.long()

    # --- Node layers on a random sparse graph ---
    graph = dgl.DGLGraph()
    adjacency = sp.random(n, n, 3 / n, data_rvs=lambda count: np.ones(count))
    graph.from_scipy_sparse_matrix(adjacency)
    graph = _rebuild_with_dtype(graph)

    nx_graph = graph.to_networkx()
    source = random.choice(range(n))
    expected_layers, _ = _bfs_nx(nx_graph, source)
    actual_layers = dgl.bfs_nodes_generator(graph, source)
    assert len(actual_layers) == len(expected_layers)
    for actual, expected in zip(actual_layers, expected_layers):
        assert toset(actual) == expected

    # --- Edge layers on a random tree ---
    nx_graph = nx.random_tree(n, seed=42)
    graph = dgl.DGLGraph()
    graph.from_networkx(nx_graph)
    graph = _rebuild_with_dtype(graph)

    source = 0
    _, expected_edges = _bfs_nx(nx_graph, source)
    actual_edges = dgl.bfs_edges_generator(graph, source)
    assert len(actual_edges) == len(expected_edges)
    for actual, expected in zip(actual_edges, expected_edges):
        assert toset(actual) == expected
def get_perturbed_inputs(i):
    """Build a stack of copies of sample ``i``, each with one node corrupted.

    A DGL graph is built from the non-zero entries of ``init_edges[i]`` and
    traversed breadth-first from ``corruption_indices[i]``. Every per-sample
    tensor is then repeated once per visited node. Copy ``k`` (for ``k >= 1``)
    has the ``k``-th visited node's entry in ``init_nodes`` replaced by a
    randomly drawn different value; copy 0 is left untouched, so the source
    node is never corrupted. When ``replace_hs is True`` the matching entry of
    ``init_hydrogens`` is replaced in the same way.

    All data (``init_nodes``, ``init_edges``, ``corruption_indices``,
    ``replace_hs``, ``max_hs``, ...) is read from the enclosing scope.

    Parameters
    ----------
    i : int
        Index of the sample to perturb.

    Returns
    -------
    tuple
        ``(d_init_nodes, d_node_target_types, d_node_target_inds_vector,
        d_init_edges, d_node_masks, d_edge_masks, d_init_hydrogens,
        d_original_node_inds, num_nodes_per_radius)``, where the last item
        lists the number of nodes in each BFS layer.
    """
    # Build the graph for this sample from its non-zero edge entries.
    mol_graph = dgl.DGLGraph()
    mol_graph.add_nodes(len(init_nodes[i]))
    endpoints = np.nonzero(init_edges[i]).T
    mol_graph.add_edges(*endpoints)

    # BFS layers around the corruption source, flattened in visiting order.
    bfs_layers = dgl.bfs_nodes_generator(mol_graph, corruption_indices[i])
    visit_order = torch.cat(bfs_layers)
    num_nodes_per_radius = [len(layer) for layer in bfs_layers]
    num_copies = len(visit_order)

    def _replicate(sample_tensor):
        """Stack ``num_copies`` references of the tensor along a new dim 0."""
        return torch.stack([sample_tensor] * num_copies)

    # get 3d tensor for one graph where each slice has one neighbouring node corrupted
    d_init_nodes = _replicate(init_nodes[i])
    d_node_target_types = _replicate(node_target_types[i])
    d_node_target_inds_vector = _replicate(node_target_inds_vector[i])
    d_init_edges = _replicate(init_edges[i])
    d_node_masks = _replicate(node_masks[i])
    d_edge_masks = _replicate(edge_masks[i])
    d_init_hydrogens = _replicate(init_hydrogens[i])
    d_original_node_inds = _replicate(original_node_inds[i])

    # Don't corrupt source node: start at position 1 of the BFS order.
    for copy_idx, node_idx in enumerate(visit_order[1:], 1):
        current_node = init_nodes[i][node_idx]
        # Candidates are the first five symbol indices minus the current one.
        node_candidates = np.setdiff1d(
            np.arange(len(QM9_SYMBOL_LIST[:5])), current_node.cpu())
        new_node = np.random.choice(node_candidates)
        d_init_nodes[copy_idx][node_idx] = float(new_node)

        if replace_hs is True:
            current_h = init_hydrogens[i][node_idx]
            h_candidates = np.setdiff1d(
                np.arange(max_hs + 2), current_h.cpu())
            new_h = np.random.choice(h_candidates)
            d_init_hydrogens[copy_idx][node_idx] = float(new_h)

    return (
        d_init_nodes,
        d_node_target_types,
        d_node_target_inds_vector,
        d_init_edges,
        d_node_masks,
        d_edge_masks,
        d_init_hydrogens,
        d_original_node_inds,
        num_nodes_per_radius,
    )