def load_dataset(name):
    """Load a TUDataset benchmark as train/test DeepSNAP graph datasets.

    NOTE(review): a later ``load_dataset`` definition in this file rebinds
    this name, so this version is shadowed at import time.

    Each graph is reduced to its largest connected component. Graphs with
    fewer than 6 nodes are then dropped, and the remainder is split 80/20.

    Args:
        name: One of ``"enzymes"``, ``"cox2"`` or ``"imdb-binary"``.

    Returns:
        A ``(train, test, task)`` tuple. ``train`` and ``test`` are the two
        parts produced by ``GraphDataset.split``; ``task`` is always
        ``"graph"``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    def add_feats(graph):
        # IMDB-BINARY has no node labels, so give every node a constant
        # 1-dimensional feature.
        for v in graph.G.nodes:
            graph.G.nodes[v]["node_feature"] = torch.ones(1)
        return graph

    # dataset name -> (download root, TUDataset name)
    tu_datasets = {
        "enzymes": ("/tmp/ENZYMES", "ENZYMES"),
        "cox2": ("/tmp/cox2", "COX2"),
        "imdb-binary": ("/tmp/IMDB-BINARY", "IMDB-BINARY"),
    }
    try:
        root, tu_name = tu_datasets[name]
    except KeyError:
        # Fail loudly instead of hitting an unbound `dataset` further down.
        raise ValueError(
            f"unknown dataset {name!r}; expected one of {sorted(tu_datasets)}"
        ) from None

    task = "graph"
    dataset = TUDataset(root=root, name=tu_name)
    dataset = GraphDataset(GraphDataset.pyg_to_graphs(dataset))
    if name == "imdb-binary":
        dataset = dataset.apply_transform(add_feats)
    # Keep only the largest connected component of each graph.
    # NOTE(review): this transform returns a networkx subgraph view rather
    # than the DeepSNAP graph object; confirm the installed DeepSNAP's
    # apply_transform accepts that return type.
    dataset = dataset.apply_transform(
        lambda g: g.G.subgraph(max(nx.connected_components(g.G), key=len)))
    dataset = dataset.filter(lambda g: len(g.G) >= 6)
    train, test = dataset.split(split_ratio=[0.8, 0.2])
    return train, test, task
def load_dataset(name):
    """Load a TUDataset benchmark as train/test DeepSNAP graph datasets.

    Each graph is reduced to its largest connected component. Graphs with
    fewer than 6 nodes are then dropped, and the remainder is split 80/20.

    Args:
        name: One of ``"enzymes"``, ``"cox2"`` or ``"imdb-binary"``.

    Returns:
        A ``(train, test, task)`` tuple. ``train`` and ``test`` are the two
        parts produced by ``GraphDataset.split``; ``task`` is always
        ``"graph"``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    def add_feats(graph):
        # IMDB-BINARY has no node labels, so give every node a constant
        # 1-dimensional feature.
        for v in graph.G.nodes:
            graph.G.nodes[v]["node_feature"] = torch.ones(1)
        return graph

    # dataset name -> (download root, TUDataset name)
    tu_datasets = {
        "enzymes": ("/tmp/ENZYMES", "ENZYMES"),
        "cox2": ("/tmp/cox2", "COX2"),
        "imdb-binary": ("/tmp/IMDB-BINARY", "IMDB-BINARY"),
    }
    try:
        root, tu_name = tu_datasets[name]
    except KeyError:
        # Fail loudly instead of hitting an unbound `dataset` further down.
        raise ValueError(
            f"unknown dataset {name!r}; expected one of {sorted(tu_datasets)}"
        ) from None

    task = "graph"
    dataset = TUDataset(root=root, name=tu_name)
    dataset = GraphDataset(GraphDataset.pyg_to_graphs(dataset))
    # add blank features for imdb-binary, which doesn't have node labels
    if name == "imdb-binary":
        dataset = dataset.apply_transform(add_feats)
    # Keep only the largest connected component of each graph.
    # NOTE(review): this transform returns a networkx subgraph view rather
    # than the DeepSNAP graph object; confirm the installed DeepSNAP's
    # apply_transform accepts that return type.
    dataset = dataset.apply_transform(
        lambda g: g.G.subgraph(max(nx.connected_components(g.G), key=len)))
    dataset = dataset.filter(lambda g: len(g.G) >= 6)
    train, test = dataset.split(split_ratio=[0.8, 0.2])
    return train, test, task