def test_filter(self):
    """Check that ``GraphDataset.filter`` drops exactly the large graphs.

    The ENZYMES dataset is loaded through ``TUDataset`` and wrapped as a
    graph-level ``GraphDataset``. The number of graphs whose node count
    (read here as ``len(graph.G)``) is at least ``size_limit`` is recorded.
    The dataset is then filtered down to the graphs below the limit. The
    test asserts that the number of removed graphs equals that recorded
    count.
    """
    size_limit = 90
    source = TUDataset('./enzymes', 'ENZYMES')
    dataset = GraphDataset(
        GraphDataset.pyg_to_graphs(source), task="graph")

    size_before = len(dataset)
    # Graphs at or above the limit are the ones the filter should discard.
    expected_removed = sum(
        1 for g in dataset if len(g.G) >= size_limit)

    kept = dataset.filter(
        lambda g: len(g.G) < size_limit, deep_copy=False)
    self.assertEqual(size_before - len(kept), expected_removed)
def test_filter(self):
    """Check that ``GraphDataset.filter`` drops exactly the large graphs.

    The ENZYMES dataset is loaded through ``TUDataset`` and converted with
    ``pyg_to_dicts``. Each resulting dict is expanded into a ``Graph``, and
    the graphs are wrapped as a graph-level ``GraphDataset``. The number of
    graphs with ``num_nodes`` at or above ``size_limit`` is recorded. The
    dataset is then filtered down to the graphs below the limit. The test
    asserts that the shrink in length equals that recorded count.
    """
    size_limit = 90
    records = pyg_to_dicts(TUDataset("./enzymes", "ENZYMES"))
    dataset = GraphDataset(
        [Graph(**record) for record in records], task="graph")

    size_before = len(dataset)
    # Graphs at or above the limit are the ones the filter should discard.
    expected_removed = sum(
        1 for g in dataset if g.num_nodes >= size_limit)

    kept = dataset.filter(
        lambda g: g.num_nodes < size_limit, deep_copy=False)
    self.assertEqual(
        size_before - len(kept),
        expected_removed,
    )