def test_validation_split(self):
    """Check that a train/validation split of SHREC2016 covers every category.

    The dataset (all categories, one version per shape) is split by taking the
    first 90% of the samples for training and the remaining ones for
    validation. For each category index in ``range(30)``, both splits must
    contain at least one sample; the per-category counts are printed.
    """
    dataset = Shrec2016DualPrimal(root=osp.join(
        current_dir, '../../datasets_no_augmentation/'),
                                  categories=[],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  num_augmentations=1)
    validation_set_fraction = 0.1
    num_samples_train = int(len(dataset) * (1. - validation_set_fraction))
    dataset_train = dataset[:num_samples_train]
    dataset_validation = dataset[num_samples_train:]
    # Read the class label of each sample only once; scanning both splits
    # once per category would load every sample 30 times.
    train_labels = [sample[0].y.item() for sample in dataset_train]
    validation_labels = [sample[0].y.item() for sample in dataset_validation]
    print("\n")
    # The loop assumes category indices 0-29.
    for category_idx in range(30):
        print(f"* Category {category_idx}:")
        num_elements_in_training_set = train_labels.count(category_idx)
        num_elements_in_validation_set = validation_labels.count(
            category_idx)
        self.assertGreater(num_elements_in_training_set, 0)
        self.assertGreater(num_elements_in_validation_set, 0)
        print("\t- Num elements in training set: "
              f"{num_elements_in_training_set}")
        print("\t- Num elements in validation set: "
              f"{num_elements_in_validation_set}")
def test_batch_formation(self):
    """Check the number and size of the batches built by the data loader.

    The augmented 'shark' category (64 samples) is loaded in batches of 4;
    every batch must contain the expected number of primal edges / dual
    nodes, and the loader must yield exactly ``64 // 4`` batches.
    """
    dataset = Shrec2016DualPrimal(root=osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016/')),
                                  categories=['shark'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  vertices_scale_mean=1.,
                                  vertices_scale_var=0.1,
                                  edges_flip_fraction=0.5,
                                  slide_vertices_fraction=0.2,
                                  num_augmentations=4)
    batch_size = 4
    data_loader = DualPrimalDataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
    dataset_len = len(dataset)
    self.assertEqual(dataset_len, 64)
    # Counted explicitly so that an empty loader fails the final assertion
    # instead of raising a NameError on an unbound loop variable.
    num_batches = 0
    for primal_graph, dual_graph, _ in data_loader:
        # Each mesh has 750 edges, hence the number of dual nodes/primal
        # edges, considering the directness of the graphs, must be
        # 750 * 2 * batch_size.
        self.assertEqual(primal_graph.num_edges, 750 * 2 * batch_size)
        self.assertEqual(dual_graph.x.shape[0], 750 * 2 * batch_size)
        self.assertEqual(dual_graph.batch.shape[0], 750 * 2 * batch_size)
        num_batches += 1
    # 64 samples in batches of 4 -> 16 (full) batches.
    self.assertEqual(num_batches, dataset_len // batch_size)
def test_sample_indices_shuffle(self):
    """Check that a shuffling data loader returns every sample index once.

    Over one pass of the loader, the returned sample indices, once sorted,
    must be exactly ``0, 1, ..., len(dataset) - 1``. This rules out both
    missing and duplicated samples.
    """
    dataset = Shrec2016DualPrimal(root=osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016/')),
                                  categories=['shark'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  vertices_scale_mean=1.,
                                  vertices_scale_var=0.1,
                                  edges_flip_fraction=0.5,
                                  slide_vertices_fraction=0.2,
                                  num_augmentations=4,
                                  return_sample_indices=True)
    self.assertEqual(len(dataset), 64)
    sample_indices_found = []
    data_loader = DualPrimalDataLoader(dataset=dataset,
                                       batch_size=8,
                                       shuffle=True,
                                       return_sample_indices=True)
    for _, _, _, sample_indices in data_loader:
        sample_indices_found.extend(sample_indices)
    # Sort before comparing: the original check converted a set to a list,
    # whose iteration order is not guaranteed to be ascending.
    self.assertEqual(sorted(sample_indices_found), list(range(len(dataset))))
def create_dataset(dataset_name,
                   compute_node_feature_stats=True,
                   node_feature_stats_filename=None,
                   **dataset_params):
    r"""Instantiates the dataset identified by `dataset_name`.

    Args:
        dataset_name (str): Identifier of the dataset to create. One of:
            `shrec_16` (SHREC2016 dataset), `cubes` (Cube engraving dataset
            from MeshCNN), `coseg` (COSEG dataset), `human_seg` (Human Body
            Segmentation dataset).
        compute_node_feature_stats (bool, optional): Whether to also compute
            the mean and standard deviation of the primal-graph and
            dual-graph node features, via :obj:`compute_mean_and_std`.
            (default: :obj:`True`)
        node_feature_stats_filename (str, optional): Forwarded as the
            `filename` argument of :obj:`compute_mean_and_std`. Only used
            when `compute_node_feature_stats` is True. (default: :obj:`None`)
        **dataset_params: Keyword arguments forwarded unchanged to the
            constructor of the selected dataset class.

    Returns:
        dataset (torch_geometric.data.Dataset): The dataset instance.
        node_statistics (tuple or None): The 4-tuple
            `(primal_graph_mean, primal_graph_std, dual_graph_mean,
            dual_graph_std)` returned by :obj:`compute_mean_and_std` if
            `compute_node_feature_stats` is True, otherwise None.

    Raises:
        KeyError: If `dataset_name` is not one of the identifiers above.
    """
    # The lambdas defer the lookup of each class name until it is selected.
    known_datasets = (
        ('shrec_16', lambda: Shrec2016DualPrimal),
        ('cubes', lambda: CubesDualPrimal),
        ('coseg', lambda: CosegDualPrimal),
        ('human_seg', lambda: HumanSegDualPrimal),
    )
    for known_name, get_dataset_class in known_datasets:
        if dataset_name == known_name:
            dataset = get_dataset_class()(**dataset_params)
            break
    else:
        raise KeyError(
            f"No known dataset can be generated with the name '{dataset_name}'."
        )

    if not compute_node_feature_stats:
        return dataset, None

    # The statistics depend on the parameters the dataset was actually
    # created with, which the dataset exposes as `input_parameters`.
    (primal_graph_mean, primal_graph_std, dual_graph_mean,
     dual_graph_std) = compute_mean_and_std(
         dataset=dataset,
         dataset_params=dataset.input_parameters,
         filename=node_feature_stats_filename)
    return dataset, (primal_graph_mean, primal_graph_std, dual_graph_mean,
                     dual_graph_std)
def test_download_process_and_get(self):
    """Instantiate the augmented 'shark' SHREC2016 dataset and print its first
    sample."""
    shrec_root = osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016/'))
    # Data-augmentation settings shared with the other SHREC2016 tests.
    augmentation_kwargs = dict(vertices_scale_mean=1.,
                               vertices_scale_var=0.1,
                               edges_flip_fraction=0.5,
                               slide_vertices_fraction=0.2,
                               num_augmentations=4)
    dataset = Shrec2016DualPrimal(root=shrec_root,
                                  categories=['shark'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  **augmentation_kwargs)
    first_sample = dataset[0]
    print(first_sample)
def test_feature_standardization(self):
    """Check that batching with the dataset statistics standardizes features.

    The node-feature statistics are computed with :obj:`compute_mean_and_std`.
    The whole dataset is then collated into a single batch by
    :obj:`create_dual_primal_batch` using those statistics. Afterwards, every
    feature column of both the primal and the dual graph should have
    (approximately) zero mean and unit standard deviation.
    """
    dataset = Shrec2016DualPrimal(root=osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016/')),
                                  categories=['shark'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  vertices_scale_mean=1.,
                                  vertices_scale_var=0.1,
                                  edges_flip_fraction=0.5,
                                  slide_vertices_fraction=0.2,
                                  num_augmentations=4)
    dataset_params = dataset.input_parameters
    primal_mean, primal_std, dual_mean, dual_std = compute_mean_and_std(
        dataset=dataset,
        dataset_params=dataset_params,
        filename=osp.join(current_dir, '../output_data/dataset_params.pkl'))
    # Create batch of the size of the entire dataset, with standardized
    # features. The mean of the primal-graph-/dual-graph- node features
    # should now be 0, while the standard deviation should be 1.
    primal_graph_list = []
    dual_graph_list = []
    primal_edge_to_dual_node_idx_list = []
    for (primal_graph, dual_graph, primal_edge_to_dual_node_idx,
         _) in dataset:
        primal_graph_list.append(primal_graph)
        dual_graph_list.append(dual_graph)
        primal_edge_to_dual_node_idx_list.append(
            primal_edge_to_dual_node_idx)
    (primal_graph_batch, dual_graph_batch,
     _) = create_dual_primal_batch(primal_graph_list,
                                   dual_graph_list,
                                   primal_edge_to_dual_node_idx_list,
                                   primal_mean=primal_mean,
                                   primal_std=primal_std,
                                   dual_mean=dual_mean,
                                   dual_std=dual_std)
    # Per-feature statistics of the standardized batch, computed once. The
    # `reshape(-1)` also handles a 1-D feature tensor, whose statistics are
    # 0-dim tensors.
    for feature_mean, feature_std in zip(
            primal_graph_batch.x.mean(dim=0).reshape(-1).tolist(),
            primal_graph_batch.x.std(dim=0).reshape(-1).tolist()):
        self.assertAlmostEqual(feature_mean, 0., 5)
        self.assertAlmostEqual(feature_std, 1., 5)
    # Every dual feature column is checked, rather than a hard-coded number
    # of columns.
    for feature_mean, feature_std in zip(
            dual_graph_batch.x.mean(dim=0).reshape(-1).tolist(),
            dual_graph_batch.x.std(dim=0).reshape(-1).tolist()):
        self.assertAlmostEqual(feature_mean, 0., 4)
        self.assertAlmostEqual(feature_std, 1., 4)
def test_sample_indices_no_shuffle(self):
    """Without shuffling, batch ``k`` must carry the sample indices
    ``8 * k, ..., 8 * k + 7`` (batches of 8)."""
    samples_per_batch = 8
    shrec_root = osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016/'))
    dataset = Shrec2016DualPrimal(root=shrec_root,
                                  categories=['shark'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  vertices_scale_mean=1.,
                                  vertices_scale_var=0.1,
                                  edges_flip_fraction=0.5,
                                  slide_vertices_fraction=0.2,
                                  num_augmentations=4,
                                  return_sample_indices=True)
    self.assertEqual(len(dataset), 64)
    ordered_loader = DualPrimalDataLoader(dataset=dataset,
                                          batch_size=samples_per_batch,
                                          shuffle=False,
                                          return_sample_indices=True)
    for k, (_, _, _, batch_sample_indices) in enumerate(ordered_loader):
        first_index = k * samples_per_batch
        expected_indices = list(
            range(first_index, first_index + samples_per_batch))
        self.assertEqual(batch_sample_indices, expected_indices)
def test_shrec_multiple_classes(self):
    """Check the labels of a SHREC2016 dataset built from two categories.

    The 'train' split of the 'shark' and 'gorilla' categories, with 2
    versions per shape, must contain 64 samples. Every batch produced by the
    data loader must have one label per sample, equal to the index of either
    'shark' or 'gorilla' in :obj:`Shrec2016DualPrimal.valid_categories`.
    """
    root_shrec = osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016_shark_gorilla/'))
    dataset = Shrec2016DualPrimal(root=root_shrec,
                                  train=True,
                                  categories=['shark', 'gorilla'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  vertices_scale_mean=1.,
                                  vertices_scale_var=0.1,
                                  edges_flip_fraction=0.5,
                                  slide_vertices_fraction=0.2,
                                  num_augmentations=2)
    batch_size = 4
    data_loader = DualPrimalDataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
    dataset_size = len(dataset)
    # There are 16 'gorilla' shapes and 16 'shark' shapes in the 'train'
    # split of SHREC2016. Counting 2 versions per shape due to data
    # augmentation, one has (16 + 16) * 2 = 64 shapes in total.
    self.assertEqual(dataset_size, 64)
    # Check the ground-truth class index of the samples in the batch.
    shark_class_index = Shrec2016DualPrimal.valid_categories.index('shark')
    gorilla_class_index = Shrec2016DualPrimal.valid_categories.index(
        'gorilla')
    allowed_class_indices = [shark_class_index, gorilla_class_index]
    for (primal_graph_batch, _, _) in data_loader:
        self.assertEqual(primal_graph_batch.y.size(), (batch_size, ))
        # The size check above guarantees exactly `batch_size` labels here.
        for sample_class_index in primal_graph_batch.y.tolist():
            self.assertIn(sample_class_index, allowed_class_indices)
def _test_compute_mean_and_std(self):
    """Compute the node-feature statistics of the augmented 'shark' dataset
    and store them on the test class.

    The leading underscore keeps the test runner from collecting this method.
    The stored statistics are meant to be compared with those computed by
    other tests that use the same dataset input parameters.
    """
    shrec_root = osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016/'))
    dataset = Shrec2016DualPrimal(root=shrec_root,
                                  categories=['shark'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  vertices_scale_mean=1.,
                                  vertices_scale_var=0.1,
                                  edges_flip_fraction=0.5,
                                  slide_vertices_fraction=0.2,
                                  num_augmentations=4)
    stats_filename = osp.join(current_dir,
                              '../output_data/dataset_params.pkl')
    # Attribute names keep their double-underscore prefix, so inside the
    # class body they are name-mangled exactly as before.
    test_case_class = self.__class__
    (test_case_class.__primal_mean, test_case_class.__primal_std,
     test_case_class.__dual_mean,
     test_case_class.__dual_std) = compute_mean_and_std(
         dataset=dataset,
         dataset_params=dataset.input_parameters,
         filename=stats_filename)
def test_forward_pass(self):
    # NOTE: this is not an actual unit test, but it is just used to verify
    # that the forward pass goes through and the output features and graphs
    # have the expected format.
    #
    # Each batch goes through a down-convolution with pooling, a dual-primal
    # convolution on the pooled graphs, and an up-convolution. The
    # up-convolution receives the pooling log and the graphs returned by the
    # down-convolution before pooling.
    single_dual_nodes = False
    undirected_dual_edges = True
    dataset = Shrec2016DualPrimal(
        root=osp.join(current_dir, '../common_data/shrec2016'),
        categories=['shark'],
        single_dual_nodes=single_dual_nodes,
        undirected_dual_edges=undirected_dual_edges,
        vertices_scale_mean=1.,
        vertices_scale_var=0.1,
        edges_flip_fraction=0.5,
        slide_vertices_fraction=0.2,
        num_augmentations=1)
    batch_size = 4
    data_loader = DualPrimalDataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
    # Down-convolutional layer.
    out_channels_primal = 5
    out_channels_dual = 5
    # Channels produced by the convolution applied to the pooled graphs.
    out_channels_primal_after_encoder = 7
    out_channels_dual_after_encoder = 7
    heads = 1
    concat_primal = True
    concat_dual = True
    negative_slope_primal = 0.2
    negative_slope_dual = 0.2
    dropout_primal = 0
    dropout_dual = 0
    bias_primal = True
    bias_dual = True
    add_self_loops_to_dual_graph = False
    # Input channels are 1 (primal) and 4 (dual). `num_primal_edges_to_keep`
    # enables pooling (log_info is asserted to be non-None below). The two
    # `return_*` flags make the layer also return what the up-convolution
    # needs to un-pool.
    down_conv = DualPrimalDownConv(
        in_channels_primal=1,
        in_channels_dual=4,
        out_channels_primal=out_channels_primal,
        out_channels_dual=out_channels_dual,
        heads=heads,
        concat_primal=concat_primal,
        concat_dual=concat_dual,
        negative_slope_primal=negative_slope_primal,
        negative_slope_dual=negative_slope_dual,
        dropout_primal=dropout_primal,
        dropout_dual=dropout_dual,
        bias_primal=bias_primal,
        bias_dual=bias_dual,
        single_dual_nodes=single_dual_nodes,
        undirected_dual_edges=undirected_dual_edges,
        add_self_loops_to_dual_graph=add_self_loops_to_dual_graph,
        num_primal_edges_to_keep=600,
        num_skips=3,
        return_old_dual_node_to_new_dual_node=True,
        return_graphs_before_pooling=True)
    # Dual-primal convolutional layer.
    conv = DualPrimalConv(
        in_channels_primal=out_channels_primal,
        in_channels_dual=out_channels_dual,
        out_channels_primal=out_channels_primal_after_encoder,
        out_channels_dual=out_channels_dual_after_encoder,
        single_dual_nodes=single_dual_nodes,
        undirected_dual_edges=undirected_dual_edges,
        heads=heads,
        concat_primal=concat_primal,
        concat_dual=concat_dual,
        negative_slope_primal=negative_slope_primal,
        negative_slope_dual=negative_slope_dual,
        dropout_primal=dropout_primal,
        dropout_dual=dropout_dual,
        bias_primal=bias_primal,
        bias_dual=bias_dual,
        add_self_loops_to_dual_graph=add_self_loops_to_dual_graph)
    # Up-convolutional layer.
    up_conv = DualPrimalUpConv(
        in_channels_primal=out_channels_primal_after_encoder,
        in_channels_dual=out_channels_dual_after_encoder,
        out_channels_primal=out_channels_primal,
        out_channels_dual=out_channels_dual,
        single_dual_nodes=single_dual_nodes,
        undirected_dual_edges=undirected_dual_edges,
        concat_data_from_before_pooling=True)
    # Forward pass.
    for primal_graph, dual_graph, petdni in data_loader:
        (primal_graph_after_pooling, dual_graph_after_pooling,
         petdni_after_pooling, log_info, primal_graph_before_pooling,
         dual_graph_before_pooling) = down_conv(
             primal_graph_batch=primal_graph,
             dual_graph_batch=dual_graph,
             primal_edge_to_dual_node_idx_batch=petdni)
        self.assertNotEqual(log_info, None)
        # The convolution output overwrites the features of the pooled graphs.
        (primal_graph_after_pooling.x,
         dual_graph_after_pooling.x) = conv(
             x_primal=primal_graph_after_pooling.x,
             x_dual=dual_graph_after_pooling.x,
             edge_index_primal=primal_graph_after_pooling.edge_index,
             edge_index_dual=dual_graph_after_pooling.edge_index,
             primal_edge_to_dual_node_idx=petdni_after_pooling)
        (primal_graph_batch_out, dual_graph_batch_out,
         primal_edge_to_dual_node_idx_batch_out) = up_conv(
             primal_graph_batch=primal_graph_after_pooling,
             dual_graph_batch=dual_graph_after_pooling,
             primal_edge_to_dual_node_idx_batch=petdni_after_pooling,
             pooling_log=log_info,
             primal_graph_batch_before_pooling=primal_graph_before_pooling,
             dual_graph_batch_before_pooling=dual_graph_before_pooling)
        # Check that the original graphs and the 'unpooled' ones match in
        # connectivity.
        self.assertTrue(
            torch.equal(primal_graph.edge_index,
                        primal_graph_batch_out.edge_index))
        self.assertTrue(
            torch.equal(dual_graph.edge_index,
                        dual_graph_batch_out.edge_index))
        self.assertEqual(petdni, primal_edge_to_dual_node_idx_batch_out)
        # Check that the number of output channels of the new features match the
        # one of the features in the original graphs.
        # NOTE(review): the loader provides 1 primal / 4 dual input channels,
        # while `conv` writes 7-channel features into the `*_after_pooling`
        # graphs. These equalities can only hold if `down_conv`/`up_conv`
        # update the graph objects in place. Confirm this, or whether the
        # comparison was meant to be against `primal_graph_batch_out` /
        # `dual_graph_batch_out`.
        self.assertEqual(primal_graph.num_node_features,
                         primal_graph_after_pooling.num_node_features)
        self.assertEqual(dual_graph.num_node_features,
                         dual_graph_after_pooling.num_node_features)
def test_forward_pass(self):
    """Run :obj:`DualPrimalDownConv` under several configurations.

    NOTE: this is not an actual unit test, but it is just used to verify
    that the forward pass goes through. Each configuration is run on every
    batch of the 'shark' SHREC2016 category. The returned pooling log must be
    `None` when no pooling criterion is given and non-`None` otherwise.
    """
    single_dual_nodes = False
    undirected_dual_edges = True
    dataset = Shrec2016DualPrimal(
        root=osp.join(current_dir, '../common_data/shrec2016'),
        categories=['shark'],
        single_dual_nodes=single_dual_nodes,
        undirected_dual_edges=undirected_dual_edges,
        vertices_scale_mean=1.,
        vertices_scale_var=0.1,
        edges_flip_fraction=0.5,
        slide_vertices_fraction=0.2,
        num_augmentations=1)
    batch_size = 4
    data_loader = DualPrimalDataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
    # Constructor arguments shared by every tested configuration; each
    # configuration below only lists what it adds or overrides.
    common_down_conv_kwargs = dict(
        in_channels_primal=1,
        in_channels_dual=4,
        out_channels_primal=5,
        out_channels_dual=5,
        heads=1,
        concat_primal=True,
        concat_dual=True,
        negative_slope_primal=0.2,
        negative_slope_dual=0.2,
        dropout_primal=0,
        dropout_dual=0,
        bias_primal=True,
        bias_dual=True,
        single_dual_nodes=single_dual_nodes,
        undirected_dual_edges=undirected_dual_edges,
        add_self_loops_to_dual_graph=False,
        num_skips=3)
    # Pairs of (pooling expected, configuration-specific arguments).
    configurations = [
        # Test without pooling.
        (False, {}),
        # Tests with pooling.
        (True, dict(num_primal_edges_to_keep=600)),
        (True, dict(fraction_primal_edges_to_keep=0.7)),
        (True, dict(primal_attention_coeff_threshold=0.5)),
        (True,
         dict(num_primal_edges_to_keep=600,
              allow_pooling_consecutive_edges=False)),
        (True,
         dict(heads=3,
              num_primal_edges_to_keep=600,
              use_decreasing_attention_coefficients=False)),
    ]
    for expects_pooling, specific_kwargs in configurations:
        # subTest reports which configuration failed and keeps testing the
        # remaining ones.
        with self.subTest(**specific_kwargs):
            down_conv = DualPrimalDownConv(**{
                **common_down_conv_kwargs,
                **specific_kwargs
            })
            for primal_graph, dual_graph, petdni in data_loader:
                (_, _, _, log_info, _,
                 _) = down_conv(primal_graph_batch=primal_graph,
                                dual_graph_batch=dual_graph,
                                primal_edge_to_dual_node_idx_batch=petdni)
                if expects_pooling:
                    self.assertIsNotNone(log_info)
                else:
                    self.assertIsNone(log_info)
def test_right_graph_connectivity(self):
    """Check that SHREC primal/dual graphs are 3-regular/4-regular.

    The shapes in SHREC are closed, manifold meshes. Therefore the associated
    'primal graph' (simplex mesh) and 'dual graph' (medial graph) should be
    respectively 3-regular and 4-regular (when directness is not considered).
    For every node that appears in at least one edge, both the number of
    edges it is the source of and the number of edges it is the target of
    are checked.

    The existing 'processed/' folder of the dataset is deleted so that the
    dataset is processed again, but only after the user confirms on stdin.
    Otherwise the test is reported as skipped.
    """
    from collections import Counter

    root_shrec = osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016/'))
    processed_shrec_fold = osp.join(root_shrec, 'processed/')
    if osp.exists(processed_shrec_fold):
        sys.stdout.write(
            "\nWarning: running the following test will cause the folder "
            f"'{processed_shrec_fold}' to be deleted! ")
        while True:
            sys.stdout.write("Do you want to continue? [y/n] ")
            try:
                user_input = input().lower()
            except EOFError:
                # No interactive stdin (e.g. CI): never delete data unasked.
                self.skipTest("No user confirmation available to delete "
                              f"'{processed_shrec_fold}'.")
            if user_input == 'y':
                print("Removing folder...")
                shutil.rmtree(processed_shrec_fold)
                break
            if user_input == 'n':
                print("Skipping test.")
                # Report the test as skipped rather than silently passed.
                self.skipTest(
                    f"User chose not to delete '{processed_shrec_fold}'.")
            sys.stdout.write("Please respond with 'y'/'Y' or 'n'/'N'.\n")
    print("Running test...")
    dataset = Shrec2016DualPrimal(root=root_shrec,
                                  categories=['shark'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  vertices_scale_mean=1.,
                                  vertices_scale_var=0.1,
                                  edges_flip_fraction=0.5,
                                  slide_vertices_fraction=0.2,
                                  num_augmentations=4)

    def assert_regular(edge_index, expected_degree):
        # Row 0 of `edge_index` holds the source nodes, row 1 the targets.
        for node_degrees in (Counter(edge_index[0].tolist()),
                             Counter(edge_index[1].tolist())):
            irregular_nodes = [
                node for node, degree in node_degrees.items()
                if degree != expected_degree
            ]
            self.assertEqual(irregular_nodes, [])

    for primal_graph, dual_graph, _, _ in dataset:
        assert_regular(primal_graph.edge_index, 3)
        assert_regular(dual_graph.edge_index, 4)
def test_slicing(self):
    """Check that slicing the dataset preserves its samples.

    For a contiguous slice (`[:10]`) and a strided slice (`[2:60:5]`), each
    sample of the sliced dataset must match the corresponding sample of the
    original dataset. For both graphs this covers their structural
    predicates, scalar attributes and tensor attributes; the
    primal-edge-to-dual-node index dictionary must match too.
    """
    dataset = Shrec2016DualPrimal(root=osp.abspath(
        osp.join(current_dir, '../common_data/shrec2016/')),
                                  categories=['shark'],
                                  single_dual_nodes=False,
                                  undirected_dual_edges=True,
                                  vertices_scale_mean=1.,
                                  vertices_scale_var=0.1,
                                  edges_flip_fraction=0.5,
                                  slide_vertices_fraction=0.2,
                                  num_augmentations=4)
    predicate_methods = [
        'contains_isolated_nodes', 'contains_self_loops', 'is_coalesced',
        'is_directed', 'is_undirected'
    ]
    scalar_attrs = [
        'keys', 'norm', 'num_edge_features', 'num_edges', 'num_faces',
        'num_node_features', 'num_nodes'
    ]
    tensor_attrs = ['edge_attr', 'edge_index', 'face', 'pos', 'x', 'y']

    def assert_slice_matches(sliced_dataset, original_indices):
        # `original_indices[i]` is the index in `dataset` of the i-th sample
        # of `sliced_dataset`.
        for sliced_idx, original_idx in enumerate(original_indices):
            # Load each sample once instead of once per compared attribute.
            original_sample = dataset[original_idx]
            sliced_sample = sliced_dataset[sliced_idx]
            # Indices 0 and 1 are the primal and the dual graph.
            for graph_idx in range(2):
                original_graph = original_sample[graph_idx]
                sliced_graph = sliced_sample[graph_idx]
                for method_name in predicate_methods:
                    self.assertEqual(
                        getattr(original_graph, method_name)(),
                        getattr(sliced_graph, method_name)())
                for scalar_attr in scalar_attrs:
                    original_value = getattr(original_graph, scalar_attr)
                    sliced_value = getattr(sliced_graph, scalar_attr)
                    if original_value is None:
                        self.assertIsNone(sliced_value)
                    else:
                        self.assertEqual(original_value, sliced_value)
                for tensor_attr in tensor_attrs:
                    original_value = getattr(original_graph, tensor_attr)
                    sliced_value = getattr(sliced_graph, tensor_attr)
                    if original_value is None:
                        self.assertIsNone(sliced_value)
                    else:
                        # `array_equal` also fails on shape mismatches that
                        # an element-wise `==` would silently broadcast.
                        self.assertTrue(
                            np.array_equal(original_value.numpy(),
                                           sliced_value.numpy()))
            # Check primal-edge-to-dual-node index dictionary.
            self.assertEqual(original_sample[2], sliced_sample[2])

    assert_slice_matches(dataset[:10], range(10))
    assert_slice_matches(dataset[2:60:5], range(2, 60, 5))