Example #1
0
    def test_validation_split(self):
        """Checks that a 90/10 train/validation split of the full dataset
        leaves samples of every one of the 30 categories in both splits."""
        dataset = Shrec2016DualPrimal(root=osp.join(
            current_dir, '../../datasets_no_augmentation/'),
                                      categories=[],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      num_augmentations=1)
        validation_fraction = 0.1
        num_train_samples = int(len(dataset) * (1. - validation_fraction))

        train_split = dataset[:num_train_samples]
        validation_split = dataset[num_train_samples:]

        print("\n")
        for category in range(30):
            print(f"* Category {category}:")
            # Indices of the samples of the current category in each split.
            train_matches = [
                sample_idx for sample_idx, sample in enumerate(train_split)
                if sample[0].y.item() == category
            ]
            validation_matches = [
                sample_idx
                for sample_idx, sample in enumerate(validation_split)
                if sample[0].y.item() == category
            ]
            # Both splits must contain at least one sample of the category.
            self.assertGreater(len(train_matches), 0)
            self.assertGreater(len(validation_matches), 0)
            print("\t- Num elements in training set: "
                  f"{len(train_matches)}")
            print("\t- Num elements in validation set: "
                  f"{len(validation_matches)}")
Example #2
0
    def test_batch_formation(self):
        """Checks that batches built by the data loader contain the expected
        number of primal edges / dual nodes, and that the number of batches
        matches the dataset size."""
        dataset = Shrec2016DualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/')),
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4)
        batch_size = 4
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
        self.assertEqual(len(dataset), 64)

        # Each mesh has 750 edges; since the graphs are directed, each batch
        # must contain 750 * 2 * batch_size dual nodes / primal edges.
        expected_count = 750 * 2 * batch_size
        for batch_idx, (primal_graph, dual_graph, _) in enumerate(data_loader):
            self.assertEqual(primal_graph.num_edges, expected_count)
            self.assertEqual(dual_graph.x.shape[0], expected_count)
            self.assertEqual(dual_graph.batch.shape[0], expected_count)

        # 64 samples / 4 per batch = 16 batches.
        self.assertEqual(batch_idx + 1, 16)
Example #3
0
    def test_sample_indices_shuffle(self):
        """Checks that, with shuffling enabled, the data loader still returns
        every sample index of the dataset exactly once over a full epoch."""
        dataset = Shrec2016DualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/')),
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4,
                                      return_sample_indices=True)

        self.assertEqual(len(dataset), 64)
        sample_indices_found = set()
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=8,
                                           shuffle=True,
                                           return_sample_indices=True)
        for _, _, _, sample_indices in data_loader:
            sample_indices_found.update(sample_indices)

        self.assertEqual(len(sample_indices_found), len(dataset))
        # Sort explicitly before comparing: iteration order of a set is an
        # implementation detail and must not be relied upon (the original
        # `list(sample_indices_found)` only passed by accident for small
        # ints in CPython).
        self.assertEqual(sorted(sample_indices_found), [*range(0, 64)])
Example #4
0
def create_dataset(dataset_name,
                   compute_node_feature_stats=True,
                   node_feature_stats_filename=None,
                   **dataset_params):
    r"""Creates an instance of the input dataset with the input parameters.

    Args:
        dataset_name (str): Name that identifies the dataset. Valid values are:
            `shrec_16` (SHREC2016 dataset), `cubes` (Cube engraving dataset from
            MeshCNN), `coseg` (COSEG dataset), `human_seg` (Human Body
            Segmentation dataset).
        compute_node_feature_stats (bool, optional): If True, the mean and
            standard deviation of the node features in the dataset are computed
            (cf. function :obj:`compute_mean_and_std`). (default: :obj:`True`)
        node_feature_stats_filename (str, optional): If not None, filename of
            the file containing the mean and standard deviation of the node
            features in the dataset (cf. function :obj:`compute_mean_and_std`).
            The argument is considered only if the argument
            `compute_node_feature_stats` is True. (default: :obj:`None`)
        ...
        Optional parameters of the datasets.

    Returns:
        dataset (torch_geometric.data.Dataset): The instance of the dataset with
            the input parameters.
        primal_graph_mean, primal_graph_std, dual_graph_mean, dual_graph_std
            (tuple of numpy array or None): If the argument
            :obj:`compute_node_feature_stats` is True, statistics about the node
            features in the dataset (cf. docs above); otherwise, None.

    Raises:
        KeyError: If `dataset_name` does not match any known dataset.
    """
    if (dataset_name == 'shrec_16'):
        dataset = Shrec2016DualPrimal(**dataset_params)
    elif (dataset_name == 'cubes'):
        dataset = CubesDualPrimal(**dataset_params)
    elif (dataset_name == 'coseg'):
        dataset = CosegDualPrimal(**dataset_params)
    elif (dataset_name == 'human_seg'):
        dataset = HumanSegDualPrimal(**dataset_params)
    else:
        raise KeyError(
            f"No known dataset can be generated with the name '{dataset_name}'."
        )

    node_statistics = None

    if (compute_node_feature_stats):
        # Use the parameters stored in the dataset instance (which include
        # any defaults filled in by the dataset class) rather than the raw
        # keyword arguments. Use a fresh name here: the original code
        # reassigned `dataset_params`, shadowing the **kwargs above.
        input_params = dataset.input_parameters
        (primal_graph_mean, primal_graph_std, dual_graph_mean,
         dual_graph_std) = compute_mean_and_std(
             dataset=dataset,
             dataset_params=input_params,
             filename=node_feature_stats_filename)
        node_statistics = (primal_graph_mean, primal_graph_std,
                           dual_graph_mean, dual_graph_std)
    return dataset, node_statistics
Example #5
0
 def test_download_process_and_get(self):
     """Smoke test: builds the dataset (triggering download/processing if
     needed) and retrieves a single sample from it."""
     dataset = Shrec2016DualPrimal(
         root=osp.abspath(osp.join(current_dir,
                                   '../common_data/shrec2016/')),
         categories=['shark'],
         single_dual_nodes=False,
         undirected_dual_edges=True,
         vertices_scale_mean=1.,
         vertices_scale_var=0.1,
         edges_flip_fraction=0.5,
         slide_vertices_fraction=0.2,
         num_augmentations=4)
     # Get one element of the dataset, to verify that indexing works.
     print(dataset[0])
Example #6
0
    def test_feature_standardization(self):
        """Checks that, after standardization with the dataset-wide
        statistics, the node features of a batch spanning the whole dataset
        have (approximately) mean 0 and standard deviation 1."""
        dataset = Shrec2016DualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/')),
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4)
        primal_mean, primal_std, dual_mean, dual_std = compute_mean_and_std(
            dataset=dataset,
            dataset_params=dataset.input_parameters,
            filename=osp.join(current_dir,
                              '../output_data/dataset_params.pkl'))
        # Collect every sample in the dataset, so that a single batch
        # covering the whole dataset can be formed.
        primal_graphs = []
        dual_graphs = []
        petdni_list = []
        for sample in dataset:
            primal_graphs.append(sample[0])
            dual_graphs.append(sample[1])
            petdni_list.append(sample[2])
        # Create the batch with standardized features. The mean of the
        # primal-graph-/dual-graph- node features should now be 0, while the
        # standard deviation should be 1.
        (primal_graph_batch, dual_graph_batch,
         _) = create_dual_primal_batch(primal_graphs,
                                       dual_graphs,
                                       petdni_list,
                                       primal_mean=primal_mean,
                                       primal_std=primal_std,
                                       dual_mean=dual_mean,
                                       dual_std=dual_std)

        # Primal features: single channel.
        self.assertAlmostEqual(primal_graph_batch.x.mean(axis=0).item(), 0., 5)
        self.assertAlmostEqual(primal_graph_batch.x.std(axis=0).item(), 1., 5)

        # Dual features: four channels, checked one by one.
        for channel in range(4):
            self.assertAlmostEqual(
                dual_graph_batch.x.mean(axis=0)[channel].item(), 0., 4)
            self.assertAlmostEqual(
                dual_graph_batch.x.std(axis=0)[channel].item(), 1., 4)
Example #7
0
    def test_sample_indices_no_shuffle(self):
        """Checks that, without shuffling, the sample indices returned by the
        data loader follow the natural dataset order, batch after batch."""
        dataset = Shrec2016DualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/')),
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4,
                                      return_sample_indices=True)

        self.assertEqual(len(dataset), 64)
        batch_size = 8
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           return_sample_indices=True)
        next_start = 0
        for _, _, _, sample_indices in data_loader:
            # Each batch must contain the next `batch_size` consecutive
            # sample indices.
            self.assertEqual(sample_indices,
                             [*range(next_start, next_start + batch_size)])
            next_start += batch_size
Example #8
0
 def test_shrec_multiple_classes(self):
     """Checks that batches drawn from a two-category dataset only contain
     samples labeled with one of the two selected categories."""
     root_shrec = osp.abspath(
         osp.join(current_dir, '../common_data/shrec2016_shark_gorilla/'))
     dataset = Shrec2016DualPrimal(root=root_shrec,
                                   train=True,
                                   categories=['shark', 'gorilla'],
                                   single_dual_nodes=False,
                                   undirected_dual_edges=True,
                                   vertices_scale_mean=1.,
                                   vertices_scale_var=0.1,
                                   edges_flip_fraction=0.5,
                                   slide_vertices_fraction=0.2,
                                   num_augmentations=2)
     batch_size = 4
     data_loader = DualPrimalDataLoader(dataset=dataset,
                                        batch_size=batch_size,
                                        shuffle=True)
     # There are 16 'gorilla' shapes and 16 'shark' shapes in the 'train'
     # split of SHREC2016. Counting 2 versions per shape due to data
     # augmentation, one has (16 + 16) * 2 = 64 shapes in total.
     self.assertEqual(len(dataset), 64)
     # Check the ground-truth class index of the samples in the batch:
     # every sample must be labeled either 'shark' or 'gorilla'.
     valid_class_indices = [
         Shrec2016DualPrimal.valid_categories.index('shark'),
         Shrec2016DualPrimal.valid_categories.index('gorilla')
     ]
     for (primal_graph_batch, _, _) in data_loader:
         self.assertEqual(primal_graph_batch.y.size(), (batch_size, ))
         for sample_idx in range(batch_size):
             sample_class_index = primal_graph_batch.y[sample_idx].item()
             self.assertTrue(sample_class_index in valid_class_indices)
Example #9
0
    def _test_compute_mean_and_std(self):
        """Computes the node-feature statistics of the dataset and caches
        them on the test class for later comparison by other tests.

        Note: the double-underscore attribute names assigned below undergo
        Python name mangling (they become `_<ClassName>__primal_mean`,
        etc.), since they are written inside the class body; any reader must
        use the same double-underscore spelling from within the same class.
        """
        dataset = Shrec2016DualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/')),
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4)
        statistics = compute_mean_and_std(
            dataset=dataset,
            dataset_params=dataset.input_parameters,
            filename=osp.join(current_dir,
                              '../output_data/dataset_params.pkl'))

        # Store statistics, so that they can be compared with those computed
        # for other tests using the same dataset input parameters.
        (self.__class__.__primal_mean, self.__class__.__primal_std,
         self.__class__.__dual_mean, self.__class__.__dual_std) = statistics
 def test_forward_pass(self):
     """Runs a down-convolution (with pooling), a dual-primal convolution
     and an up-convolution on batches from the SHREC2016 'shark' dataset,
     then checks that the 'unpooled' graphs match the original ones in
     connectivity and that the feature dimensionalities are as expected.
     """
     # NOTE: this is not an actual unit test, but it is just used to verify
     # that the forward pass goes through and the output features and graphs
     # have the expected format.
     single_dual_nodes = False
     undirected_dual_edges = True
     dataset = Shrec2016DualPrimal(
         root=osp.join(current_dir, '../common_data/shrec2016'),
         categories=['shark'],
         single_dual_nodes=single_dual_nodes,
         undirected_dual_edges=undirected_dual_edges,
         vertices_scale_mean=1.,
         vertices_scale_var=0.1,
         edges_flip_fraction=0.5,
         slide_vertices_fraction=0.2,
         num_augmentations=1)
     batch_size = 4
     data_loader = DualPrimalDataLoader(dataset=dataset,
                                        batch_size=batch_size,
                                        shuffle=True)
     # Down-convolutional layer.
     out_channels_primal = 5
     out_channels_dual = 5
     out_channels_primal_after_encoder = 7
     out_channels_dual_after_encoder = 7
     heads = 1
     concat_primal = True
     concat_dual = True
     negative_slope_primal = 0.2
     negative_slope_dual = 0.2
     dropout_primal = 0
     dropout_dual = 0
     bias_primal = True
     bias_dual = True
     add_self_loops_to_dual_graph = False
     # The down-convolutional layer also performs pooling
     # (`num_primal_edges_to_keep=600`) and is asked to return the graphs
     # before pooling, which are required by the up-convolutional layer
     # below.
     down_conv = DualPrimalDownConv(
         in_channels_primal=1,
         in_channels_dual=4,
         out_channels_primal=out_channels_primal,
         out_channels_dual=out_channels_dual,
         heads=heads,
         concat_primal=concat_primal,
         concat_dual=concat_dual,
         negative_slope_primal=negative_slope_primal,
         negative_slope_dual=negative_slope_dual,
         dropout_primal=dropout_primal,
         dropout_dual=dropout_dual,
         bias_primal=bias_primal,
         bias_dual=bias_dual,
         single_dual_nodes=single_dual_nodes,
         undirected_dual_edges=undirected_dual_edges,
         add_self_loops_to_dual_graph=add_self_loops_to_dual_graph,
         num_primal_edges_to_keep=600,
         num_skips=3,
         return_old_dual_node_to_new_dual_node=True,
         return_graphs_before_pooling=True)
     # Dual-primal convolutional layer.
     conv = DualPrimalConv(
         in_channels_primal=out_channels_primal,
         in_channels_dual=out_channels_dual,
         out_channels_primal=out_channels_primal_after_encoder,
         out_channels_dual=out_channels_dual_after_encoder,
         single_dual_nodes=single_dual_nodes,
         undirected_dual_edges=undirected_dual_edges,
         heads=heads,
         concat_primal=concat_primal,
         concat_dual=concat_dual,
         negative_slope_primal=negative_slope_primal,
         negative_slope_dual=negative_slope_dual,
         dropout_primal=dropout_primal,
         dropout_dual=dropout_dual,
         bias_primal=bias_primal,
         bias_dual=bias_dual,
         add_self_loops_to_dual_graph=add_self_loops_to_dual_graph)
     # Up-convolutional layer.
     up_conv = DualPrimalUpConv(
         in_channels_primal=out_channels_primal_after_encoder,
         in_channels_dual=out_channels_dual_after_encoder,
         out_channels_primal=out_channels_primal,
         out_channels_dual=out_channels_dual,
         single_dual_nodes=single_dual_nodes,
         undirected_dual_edges=undirected_dual_edges,
         concat_data_from_before_pooling=True)
     # Forward pass. ('petdni' = primal-edge-to-dual-node-index mapping, as
     # suggested by the keyword arguments below.)
     for primal_graph, dual_graph, petdni in data_loader:
         # Down-convolution with pooling; a pooling log must be produced.
         (primal_graph_after_pooling, dual_graph_after_pooling,
          petdni_after_pooling, log_info, primal_graph_before_pooling,
          dual_graph_before_pooling) = down_conv(
              primal_graph_batch=primal_graph,
              dual_graph_batch=dual_graph,
              primal_edge_to_dual_node_idx_batch=petdni)
         self.assertNotEqual(log_info, None)
         # Dual-primal convolution on the pooled graphs; the node features
         # of the pooled graphs are replaced in place with its outputs.
         (primal_graph_after_pooling.x, dual_graph_after_pooling.x) = conv(
             x_primal=primal_graph_after_pooling.x,
             x_dual=dual_graph_after_pooling.x,
             edge_index_primal=primal_graph_after_pooling.edge_index,
             edge_index_dual=dual_graph_after_pooling.edge_index,
             primal_edge_to_dual_node_idx=petdni_after_pooling)
         # Up-convolution ('unpooling'), using the pooling log and the
         # graphs saved before pooling.
         (primal_graph_batch_out, dual_graph_batch_out,
          primal_edge_to_dual_node_idx_batch_out) = up_conv(
              primal_graph_batch=primal_graph_after_pooling,
              dual_graph_batch=dual_graph_after_pooling,
              primal_edge_to_dual_node_idx_batch=petdni_after_pooling,
              pooling_log=log_info,
              primal_graph_batch_before_pooling=primal_graph_before_pooling,
              dual_graph_batch_before_pooling=dual_graph_before_pooling)
     # Check that the original graphs and the 'unpooled' ones match in
     # connectivity. (These checks use the variables from the last loop
     # iteration.)
     self.assertTrue(
         torch.equal(primal_graph.edge_index,
                     primal_graph_batch_out.edge_index))
     self.assertTrue(
         torch.equal(dual_graph.edge_index, dual_graph_batch_out.edge_index))
     self.assertEqual(petdni, primal_edge_to_dual_node_idx_batch_out)
     # Check that the number of output channels of the new features match the
     # one of the features in the original graphs.
     # NOTE(review): the two assertions below compare against
     # `primal_graph_after_pooling`/`dual_graph_after_pooling`, whose `x`
     # was overwritten with the encoder outputs above, rather than against
     # `primal_graph_batch_out`/`dual_graph_batch_out`; verify that this is
     # the intended comparison.
     self.assertEqual(primal_graph.num_node_features,
                      primal_graph_after_pooling.num_node_features)
     self.assertEqual(dual_graph.num_node_features,
                      dual_graph_after_pooling.num_node_features)
    def test_forward_pass(self):
        """Runs forward passes of `DualPrimalDownConv` under several pooling
        configurations and checks that a pooling log is returned exactly when
        pooling is enabled.

        NOTE: this is not an actual unit test, but it is just used to verify
        that the forward pass goes through.
        """
        single_dual_nodes = False
        undirected_dual_edges = True
        dataset = Shrec2016DualPrimal(
            root=osp.join(current_dir, '../common_data/shrec2016'),
            categories=['shark'],
            single_dual_nodes=single_dual_nodes,
            undirected_dual_edges=undirected_dual_edges,
            vertices_scale_mean=1.,
            vertices_scale_var=0.1,
            edges_flip_fraction=0.5,
            slide_vertices_fraction=0.2,
            num_augmentations=1)
        batch_size = 4
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
        # Parameters shared by every down-convolutional layer tested below;
        # each configuration only overrides or adds the pooling-related
        # arguments. (The original code repeated this construction six
        # times verbatim.)
        base_params = dict(in_channels_primal=1,
                           in_channels_dual=4,
                           out_channels_primal=5,
                           out_channels_dual=5,
                           heads=1,
                           concat_primal=True,
                           concat_dual=True,
                           negative_slope_primal=0.2,
                           negative_slope_dual=0.2,
                           dropout_primal=0,
                           dropout_dual=0,
                           bias_primal=True,
                           bias_dual=True,
                           single_dual_nodes=single_dual_nodes,
                           undirected_dual_edges=undirected_dual_edges,
                           add_self_loops_to_dual_graph=False,
                           num_skips=3)

        def _run_down_conv(expect_pooling_log, **overrides):
            # Builds a `DualPrimalDownConv` with the base parameters plus the
            # given overrides, runs it on every batch, and checks the
            # presence/absence of the pooling log.
            down_conv = DualPrimalDownConv(**{**base_params, **overrides})
            for primal_graph, dual_graph, petdni in data_loader:
                (_, _, _, log_info, _,
                 _) = down_conv(primal_graph_batch=primal_graph,
                                dual_graph_batch=dual_graph,
                                primal_edge_to_dual_node_idx_batch=petdni)
                if (expect_pooling_log):
                    self.assertNotEqual(log_info, None)
                else:
                    self.assertEqual(log_info, None)

        # Test without pooling: no pooling log should be returned.
        _run_down_conv(expect_pooling_log=False)
        # Tests with pooling: a pooling log should be returned in all cases.
        # - Pooling with a fixed number of primal edges to keep.
        _run_down_conv(expect_pooling_log=True, num_primal_edges_to_keep=600)
        # - Pooling with a fraction of primal edges to keep.
        _run_down_conv(expect_pooling_log=True,
                       fraction_primal_edges_to_keep=0.7)
        # - Pooling with an attention-coefficient threshold.
        _run_down_conv(expect_pooling_log=True,
                       primal_attention_coeff_threshold=0.5)
        # - Pooling with pooling of consecutive edges disallowed.
        _run_down_conv(expect_pooling_log=True,
                       num_primal_edges_to_keep=600,
                       allow_pooling_consecutive_edges=False)
        # - Multi-head pooling with non-decreasing attention coefficients.
        _run_down_conv(expect_pooling_log=True,
                       heads=3,
                       num_primal_edges_to_keep=600,
                       use_decreasing_attention_coefficients=False)
Example #12
0
    def test_right_graph_connectivity(self):
        """Checks the regularity of the graphs generated from SHREC meshes.

        The shapes in SHREC are closed, manifold meshes. Therefore the
        associated 'primal graph' (simplex mesh) and 'dual graph' (medial
        graph) should be respectively 3-regular and 4-regular (when edge
        direction is not considered).
        """
        root_shrec = osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/'))
        processed_shrec_fold = osp.join(root_shrec, 'processed/')
        # Re-running this test requires deleting previously-processed data;
        # ask the user for confirmation before doing so.
        if (osp.exists(processed_shrec_fold)):
            sys.stdout.write(
                "\nWarning: running the following test will cause the folder "
                f"'{processed_shrec_fold}' to be deleted! ")
            valid_choice = False
            while (not valid_choice):
                sys.stdout.write("Do you want to continue? [y/n] ")
                user_input = input().lower()
                if (user_input == 'y'):
                    print("Removing folder...")
                    shutil.rmtree(processed_shrec_fold)
                    valid_choice = True
                elif (user_input == 'n'):
                    print("Skipping test.")
                    valid_choice = True
                    return
                else:
                    sys.stdout.write(
                        "Please respond with 'y'/'Y' or 'n'/'N'.\n")

        def _assert_regular(edge_index, expected_degree):
            # Checks that every node in the directed graph described by the
            # given edge-index tensor has exactly `expected_degree` incoming
            # and `expected_degree` outgoing edges. (The original code
            # duplicated this logic verbatim for the primal and dual graphs.)
            neighbors_incoming_edges = dict()
            neighbors_outgoing_edges = dict()
            for a, b in edge_index.t():
                neighbors_incoming_edges.setdefault(a.item(),
                                                    []).append(b.item())
                neighbors_outgoing_edges.setdefault(b.item(),
                                                    []).append(a.item())
            for neighbors in (neighbors_incoming_edges,
                              neighbors_outgoing_edges):
                self.assertEqual(
                    len([
                        n for n in neighbors.keys()
                        if len(neighbors[n]) != expected_degree
                    ]), 0)

        print("Running test...")
        dataset = Shrec2016DualPrimal(root=root_shrec,
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4)

        for primal_graph, dual_graph, _, _ in dataset:
            # Primal graphs come from triangular meshes -> 3-regular.
            _assert_regular(primal_graph.edge_index, 3)
            # Dual graphs are medial graphs of manifold meshes -> 4-regular.
            _assert_regular(dual_graph.edge_index, 4)
Example #13
0
    def test_slicing(self):
        """Verify that slicing the dataset yields samples identical to the
        corresponding samples in the original dataset.

        Two slicings are checked:
        - a plain prefix slice (``dataset[:10]``);
        - a strided slice (``dataset[2:60:5]``).

        For every sliced sample, the primal graph, the dual graph and the
        primal-edge-to-dual-node index dictionary are compared attribute by
        attribute against the matching sample of the full dataset.
        """
        dataset = Shrec2016DualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/')),
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4)

        def assert_samples_match(sample, reduced_sample):
            # Compare one sample of the original dataset with the
            # corresponding sample of the "sliced" dataset. Each sample is a
            # tuple (primal graph, dual graph, primal-edge-to-dual-node index
            # dictionary).
            for graph_idx in range(2):
                graph = sample[graph_idx]
                reduced_graph = reduced_sample[graph_idx]
                # Structural properties of the graphs must match.
                self.assertEqual(graph.contains_isolated_nodes(),
                                 reduced_graph.contains_isolated_nodes())
                self.assertEqual(graph.contains_self_loops(),
                                 reduced_graph.contains_self_loops())
                self.assertEqual(graph.is_coalesced(),
                                 reduced_graph.is_coalesced())
                self.assertEqual(graph.is_directed(),
                                 reduced_graph.is_directed())
                self.assertEqual(graph.is_undirected(),
                                 reduced_graph.is_undirected())
                # Scalar attributes must match (or both be None).
                for scalar_attr in [
                        'keys', 'norm', 'num_edge_features', 'num_edges',
                        'num_faces', 'num_node_features', 'num_nodes'
                ]:
                    attr = getattr(graph, scalar_attr)
                    reduced_attr = getattr(reduced_graph, scalar_attr)
                    if (attr is None):
                        self.assertIsNone(reduced_attr)
                    else:
                        self.assertEqual(attr, reduced_attr)
                # Tensor attributes must match element-wise (or both be
                # None).
                for tensor_attr in [
                        'edge_attr', 'edge_index', 'face', 'pos', 'x', 'y'
                ]:
                    attr = getattr(graph, tensor_attr)
                    reduced_attr = getattr(reduced_graph, tensor_attr)
                    if (attr is None):
                        self.assertIsNone(reduced_attr)
                    else:
                        self.assertTrue(
                            np.all(attr.numpy() == reduced_attr.numpy()))
            # Check primal-edge-to-dual-node index dictionary.
            self.assertEqual(sample[2], reduced_sample[2])

        # Prefix slice: the first 10 samples.
        reduced_dataset = dataset[:10]
        for idx in range(10):
            assert_samples_match(dataset[idx], reduced_dataset[idx])

        # Strided slice: every 5th sample in the interval [2, 60).
        reduced_dataset = dataset[2:60:5]
        for reduced_dataset_idx, dataset_idx in enumerate(range(2, 60, 5)):
            assert_samples_match(dataset[dataset_idx],
                                 reduced_dataset[reduced_dataset_idx])