コード例 #1
0
    def test_sample_indices_shuffle(self):
        """Shuffled loading must return every sample index exactly once.

        With ``shuffle=True`` the order in which the sample indices are
        returned is arbitrary, but over one pass of the data loader each index
        in ``[0, len(dataset))`` must appear exactly once.
        """
        dataset = Shrec2016DualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/')),
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4,
                                      return_sample_indices=True)

        self.assertEqual(len(dataset), 64)
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=8,
                                           shuffle=True,
                                           return_sample_indices=True)
        # Collect the indices in a list rather than a set, so that an index
        # returned more than once is detected as well.
        sample_indices_found = []
        for _, _, _, sample_indices in data_loader:
            sample_indices_found.extend(sample_indices)

        self.assertEqual(len(sample_indices_found), len(dataset))
        # Sort before comparing: the iteration order of a shuffled loader (or
        # of a set) is not something the test should depend on.
        self.assertEqual(sorted(sample_indices_found), [*range(0, 64)])
コード例 #2
0
    def test_batch_formation(self):
        """Check the size of every batch and the total number of batches."""
        dataset = Shrec2016DualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/')),
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4)
        batch_size = 4
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
        dataset_len = len(dataset)
        self.assertEqual(dataset_len, 64)

        # Count the batches explicitly instead of relying on the loop variable
        # leaking out of the loop (which would raise NameError if the data
        # loader yielded no batch).
        num_batches = 0
        for primal_graph, dual_graph, _ in data_loader:
            # Each mesh has 750 edges, hence the number of dual nodes/primal
            # edges, considering the directness of the graphs, must be
            # 750 * 2 * batch_size.
            expected_num_elements = 750 * 2 * batch_size
            self.assertEqual(primal_graph.num_edges, expected_num_elements)
            self.assertEqual(dual_graph.x.shape[0], expected_num_elements)
            self.assertEqual(dual_graph.batch.shape[0], expected_num_elements)
            num_batches += 1

        # 64 samples split into batches of 4 samples each.
        self.assertEqual(num_batches, 16)
コード例 #3
0
    def test_forward_pass(self):
        # NOTE: this is not an actual unit test, but it is just used to verify
        # that the forward pass goes through and that its outputs are in the
        # correct format.
        import itertools

        single_dual_nodes = True
        undirected_dual_edges = True
        dataset = CosegDualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/coseg_config_A/')),
                                  categories=['aliens'],
                                  single_dual_nodes=single_dual_nodes,
                                  undirected_dual_edges=undirected_dual_edges,
                                  return_sample_indices=True)
        batch_size = 4
        num_classes = 4
        # Limit the test to the first few batches to keep it fast.
        max_num_batches = 5
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
        # Test with pooling: `fractions_primal_edges_to_keep` contains two
        # non-None entries, so pooling is enabled in the segmenter.
        mesh_segmenter = DualPrimalUNetMeshSegmenter(
            in_channels_primal=1,
            in_channels_dual=7,
            conv_primal_out_res=[32, 64, 128],
            conv_dual_out_res=[32, 64, 128],
            num_classes=num_classes,
            single_dual_nodes=single_dual_nodes,
            undirected_dual_edges=undirected_dual_edges,
            fractions_primal_edges_to_keep=[0.7, 0.7],
            num_res_blocks=2,
            heads_encoder=3,
            heads_decoder=1,
            concat_primal=False,
            concat_dual=True,
            negative_slope_primal=0.2,
            negative_slope_dual=0.2,
            dropout_primal=0,
            dropout_dual=0,
            bias_primal=True,
            bias_dual=True,
            add_self_loops_to_dual_graph=False,
            return_node_to_cluster=False,
            log_ratios_new_old_primal_edges=True)

        # `islice` stops after `max_num_batches` batches without loading an
        # extra, unused batch from the data loader.
        for primal_graph, dual_graph, petdni in itertools.islice(
                data_loader, max_num_batches):
            output_scores, _ = mesh_segmenter(
                primal_graph_batch=primal_graph,
                dual_graph_batch=dual_graph,
                primal_edge_to_dual_node_idx_batch=petdni)
            # Verify that for each node in the input primal graph one score per
            # class is returned.
            self.assertEqual(output_scores.shape,
                             (primal_graph.num_nodes, num_classes))
コード例 #4
0
    def test_batch_formation(self):
        """Check batch sizes and that primal-node labels match the ground truth.

        For every sample in every (shuffled) batch, the class labels of the
        primal-graph nodes that belong to that sample are compared with the
        labels read from the sample's '.eseg' segmentation file.
        """
        dataset = CosegDualPrimal(root=osp.abspath(
            osp.join(current_dir, '../common_data/coseg_config_A/')),
                                  categories=['aliens'],
                                  single_dual_nodes=True,
                                  undirected_dual_edges=True,
                                  return_sample_indices=True)
        segmentation_data_root = osp.join(
            current_dir, '../common_data/coseg_config_A/raw/coseg_aliens/seg')
        batch_size = 13  # 169 = 13 ** 2

        self.assertEqual(len(dataset), 169)
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           return_sample_indices=True)
        for (primal_graph_batch, dual_graph_batch, _,
             sample_indices) in data_loader:
            self.assertEqual(primal_graph_batch.num_graphs, batch_size)
            self.assertEqual(dual_graph_batch.num_graphs, batch_size)
            # Check that the class labels associated to each of the primal nodes
            # in the batch match the ones in the segmentation data file
            # associated to that sample.
            for sample_index_in_batch, sample_index in enumerate(
                    sample_indices):
                # - Load the ground-truth labels. The processed training file
                #   names are indexed with a stride of 3 per sample
                #   (NOTE(review): presumably three processed files per
                #   sample — confirm against the dataset class).
                base_filename = dataset.processed_file_names_train[
                    3 * sample_index].rpartition('/')[-1].split('_')[0]
                gt_label_file = osp.join(segmentation_data_root,
                                         f'{base_filename}.eseg')
                gt_labels = np.loadtxt(gt_label_file, dtype='float64')
                # - Find the class labels of the nodes in the batch that belong
                #   to the current sample.
                indices_nodes_in_sample = (
                    primal_graph_batch.batch == sample_index_in_batch
                ).nonzero().view(-1)
                class_labels_nodes_in_sample = primal_graph_batch.y[
                    indices_nodes_in_sample].numpy()
                # - Verify that the class labels match the ground-truth ones.
                self.assertTrue(
                    np.array_equal(class_labels_nodes_in_sample, gt_labels),
                    f"Class labels of sample {sample_index} do not match the "
                    f"ground-truth labels in '{gt_label_file}'.")
コード例 #5
0
    def test_sample_indices_no_shuffle(self):
        """Without shuffling, batch ``k`` holds the consecutive indices of the
        ``k``-th group of 8 samples, i.e., ``8k, ..., 8k + 7``.
        """
        shrec_root = osp.abspath(
            osp.join(current_dir, '../common_data/shrec2016/'))
        dataset = Shrec2016DualPrimal(root=shrec_root,
                                      categories=['shark'],
                                      single_dual_nodes=False,
                                      undirected_dual_edges=True,
                                      vertices_scale_mean=1.,
                                      vertices_scale_var=0.1,
                                      edges_flip_fraction=0.5,
                                      slide_vertices_fraction=0.2,
                                      num_augmentations=4,
                                      return_sample_indices=True)
        self.assertEqual(len(dataset), 64)

        batch_size = 8
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           return_sample_indices=True)
        for batch_number, (_primal, _dual, _petdni,
                           indices_in_batch) in enumerate(data_loader):
            first_index = batch_number * batch_size
            expected_indices = list(range(first_index,
                                          first_index + batch_size))
            self.assertEqual(indices_in_batch, expected_indices)
コード例 #6
0
 def test_shrec_multiple_classes(self):
     """Check the class labels of a SHREC dataset with multiple categories.

     With the categories 'shark' and 'gorilla', each sample in each batch must
     have as ground-truth class index the index of one of these two categories
     in `Shrec2016DualPrimal.valid_categories`.
     """
     root_shrec = osp.abspath(
         osp.join(current_dir, '../common_data/shrec2016_shark_gorilla/'))
     dataset = Shrec2016DualPrimal(root=root_shrec,
                                   train=True,
                                   categories=['shark', 'gorilla'],
                                   single_dual_nodes=False,
                                   undirected_dual_edges=True,
                                   vertices_scale_mean=1.,
                                   vertices_scale_var=0.1,
                                   edges_flip_fraction=0.5,
                                   slide_vertices_fraction=0.2,
                                   num_augmentations=2)
     batch_size = 4
     data_loader = DualPrimalDataLoader(dataset=dataset,
                                        batch_size=batch_size,
                                        shuffle=True)
     dataset_size = len(dataset)
     # There are 16 'gorilla' shapes and 16 'shark' shapes in the 'train'
     # split of SHREC2016. Counting 2 versions per shape due to data
     # augmentation, one has (16 + 16) * 2 = 64 shapes in total.
     self.assertEqual(dataset_size, 64)
     # Check the ground-truth class index of the samples in the batch.
     shark_class_index = Shrec2016DualPrimal.valid_categories.index('shark')
     gorilla_class_index = Shrec2016DualPrimal.valid_categories.index(
         'gorilla')
     allowed_class_indices = [shark_class_index, gorilla_class_index]
     for (primal_graph_batch, _, _) in data_loader:
         # One class label per sample in the batch.
         self.assertEqual(primal_graph_batch.y.size(), (batch_size, ))
         for sample_class_index in primal_graph_batch.y.tolist():
             self.assertIn(sample_class_index, allowed_class_indices)
コード例 #7
0
    def _initialize_components(self, network_parameters, dataset_parameters,
                               data_loader_parameters, loss_parameters):
        r"""Instantiates and initializes: network, dataset, data loader and
        loss.

        Note: the input dictionaries are modified in place:
        - `network_parameters['return_node_to_cluster']` is set to True when
          clusterized meshes should be saved;
        - when the training-set statistics are used for standardization, the
          keys 'primal_mean_train', 'primal_std_train', 'dual_mean_train' and
          'dual_std_train' are popped from `dataset_parameters`, and
          `dataset_parameters['compute_node_feature_stats']` is set to False;
        - `data_loader_parameters['return_sample_indices']` is set to True.

        Args:
            network_parameters, dataset_parameters, data_loader_parameters,
            loss_parameters (dict): Input parameters used to construct and
                initialize the network, the dataset, the data loader and the
                loss. If `loss_parameters` is None, no loss is created.

        Returns:
            None.
        """
        # Initialize model.
        assert ('should_initialize_weights' not in network_parameters), (
            "Network parameters should not contain the parameter "
            "'should_initialize_weights', as weights will be automatically "
            "initialized or not, depending on whether training is resumed "
            "from a previous job or not.")
        if (self.__verbose):
            print("Initializing network...")
        if (self.__save_clusterized_meshes):
            # Only the first of the following parameters that is present and
            # not None is inspected; the network is considered to contain at
            # least one pooling layer if that parameter has at least one
            # non-None entry.
            network_contains_at_least_one_pooling_layer = False
            for pooling_parameter_name in (
                    'num_primal_edges_to_keep',
                    'fractions_primal_edges_to_keep',
                    'primal_attention_coeffs_thresholds'):
                pooling_thresholds = network_parameters.get(
                    pooling_parameter_name)
                if (pooling_thresholds is not None):
                    network_contains_at_least_one_pooling_layer = any(
                        threshold is not None
                        for threshold in pooling_thresholds)
                    break
            assert (network_contains_at_least_one_pooling_layer), (
                "Please use at least one pooling layer in the test model to "
                "save the clusterized meshes.")
            # Add to the input parameters of the network the flag that specifies
            # that the node-to-cluster correspondences should be returned.
            network_parameters['return_node_to_cluster'] = True

        self.__net = create_model(should_initialize_weights=False,
                                  **network_parameters)
        if ('log_ratios_new_old_primal_nodes' in network_parameters and
                network_parameters['log_ratios_new_old_primal_nodes'] is True):
            self.__are_ratios_new_old_primal_nodes_logged = True
        else:
            self.__are_ratios_new_old_primal_nodes_logged = False
        # Move network to GPU if necessary.
        if (self.__use_gpu):
            self.__net.to("cuda")
        else:
            self.__net.to("cpu")
        # Initialize dataset.
        if (self.__verbose):
            print("Initializing dataset...")
        if (dataset_parameters['train'] == True):
            print(
                "\033[93mNote: running evaluation on a 'train' split! If you "
                "instead want to use the 'test' split of the dataset, please "
                "set the dataset parameter 'train' as False.\033[0m")
            self.__split = 'train'
        else:
            self.__split = 'test'
        if (self.__standardize_features_using_training_set):
            assert (
                'compute_node_feature_stats' not in dataset_parameters
                or not dataset_parameters['compute_node_feature_stats']
            ), ("Setting argument 'standardize_features_using_training_set' of "
                "the test job to True is incompatible with dataset parameter "
                "'compute_node_feature_stats' = True.")
            # Perform input-feature normalization using the statistics from
            # the training set.
            print("\033[92mWill perform input-feature standardization using "
                  "the provided mean and standard deviation of the "
                  "primal-graph-/dual-graph- node features of the training "
                  f"set (file '{self.__training_params_filename}').\033[0m")
            primal_mean = dataset_parameters.pop('primal_mean_train')
            primal_std = dataset_parameters.pop('primal_std_train')
            dual_mean = dataset_parameters.pop('dual_mean_train')
            dual_std = dataset_parameters.pop('dual_std_train')
            dataset_parameters['compute_node_feature_stats'] = False
            dataset, _ = create_dataset(**dataset_parameters)
        else:
            if ('compute_node_feature_stats' in dataset_parameters
                    and not dataset_parameters['compute_node_feature_stats']):
                # No feature standardization.
                dataset, _ = create_dataset(**dataset_parameters)
                primal_mean = primal_std = dual_mean = dual_std = None
                print("\033[93mNote: no input-feature standardization will be "
                      "performed! If you wish to use standardization instead, "
                      "please set the argument "
                      "'standardize_features_using_training_set' of the test "
                      "job to True or set the dataset-parameter "
                      "`compute_node_feature_stats` to True.\033[0m")
            else:
                print("\033[93mNote: input-feature standardization will be "
                      "performed using the mean and standard deviation of the "
                      "primal-graph-/dual-graph- node features of the test "
                      "set! If you wish to use those of the training set "
                      "instead, please set the argument "
                      "'standardize_features_using_training_set' of the test "
                      "job to True.\033[0m")
                dataset, (primal_mean, primal_std, dual_mean,
                          dual_std) = create_dataset(**dataset_parameters)
        # Initialize data loader.
        assert (len(
            set(['primal_mean', 'primal_std', 'dual_mean', 'dual_std'])
            & set(data_loader_parameters)) == 0), (
                "Data-loader parameters should not contain any of the "
                "following parameters, as they will be automatically computed "
                "from the dataset or restored from the previous training job, "
                "if set to do so: 'primal_mean', "
                "'primal_std', 'dual_mean', 'dual_std'.")
        if (self.__verbose):
            print("Initializing data loader...")
        # Add to the input parameters of the data-loader the flag that specifies
        # that the indices of the sample in the dataset should be returned when
        # iterating on it.
        data_loader_parameters['return_sample_indices'] = True

        self.__data_loader = DualPrimalDataLoader(dataset=dataset,
                                                  primal_mean=primal_mean,
                                                  primal_std=primal_std,
                                                  dual_mean=dual_mean,
                                                  dual_std=dual_std,
                                                  **data_loader_parameters)
        # Initialize loss.
        if (loss_parameters is not None):
            if (self.__verbose):
                print("Initializing loss...")
            self.__loss = create_loss(**loss_parameters)
コード例 #8
0
 def test_forward_pass(self):
     # NOTE: this is not an actual unit test, but it is just used to verify
     # that the forward pass goes through and the output features and graphs
     # have the expected format.
     single_dual_nodes = False
     undirected_dual_edges = True
     dataset = Shrec2016DualPrimal(
         root=osp.join(current_dir, '../common_data/shrec2016'),
         categories=['shark'],
         single_dual_nodes=single_dual_nodes,
         undirected_dual_edges=undirected_dual_edges,
         vertices_scale_mean=1.,
         vertices_scale_var=0.1,
         edges_flip_fraction=0.5,
         slide_vertices_fraction=0.2,
         num_augmentations=1)
     batch_size = 4
     data_loader = DualPrimalDataLoader(dataset=dataset,
                                        batch_size=batch_size,
                                        shuffle=True)
     # Down-convolutional layer.
     out_channels_primal = 5
     out_channels_dual = 5
     out_channels_primal_after_encoder = 7
     out_channels_dual_after_encoder = 7
     heads = 1
     concat_primal = True
     concat_dual = True
     negative_slope_primal = 0.2
     negative_slope_dual = 0.2
     dropout_primal = 0
     dropout_dual = 0
     bias_primal = True
     bias_dual = True
     add_self_loops_to_dual_graph = False
     down_conv = DualPrimalDownConv(
         in_channels_primal=1,
         in_channels_dual=4,
         out_channels_primal=out_channels_primal,
         out_channels_dual=out_channels_dual,
         heads=heads,
         concat_primal=concat_primal,
         concat_dual=concat_dual,
         negative_slope_primal=negative_slope_primal,
         negative_slope_dual=negative_slope_dual,
         dropout_primal=dropout_primal,
         dropout_dual=dropout_dual,
         bias_primal=bias_primal,
         bias_dual=bias_dual,
         single_dual_nodes=single_dual_nodes,
         undirected_dual_edges=undirected_dual_edges,
         add_self_loops_to_dual_graph=add_self_loops_to_dual_graph,
         num_primal_edges_to_keep=600,
         num_skips=3,
         return_old_dual_node_to_new_dual_node=True,
         return_graphs_before_pooling=True)
     # Dual-primal convolutional layer.
     conv = DualPrimalConv(
         in_channels_primal=out_channels_primal,
         in_channels_dual=out_channels_dual,
         out_channels_primal=out_channels_primal_after_encoder,
         out_channels_dual=out_channels_dual_after_encoder,
         single_dual_nodes=single_dual_nodes,
         undirected_dual_edges=undirected_dual_edges,
         heads=heads,
         concat_primal=concat_primal,
         concat_dual=concat_dual,
         negative_slope_primal=negative_slope_primal,
         negative_slope_dual=negative_slope_dual,
         dropout_primal=dropout_primal,
         dropout_dual=dropout_dual,
         bias_primal=bias_primal,
         bias_dual=bias_dual,
         add_self_loops_to_dual_graph=add_self_loops_to_dual_graph)
     # Up-convolutional layer.
     up_conv = DualPrimalUpConv(
         in_channels_primal=out_channels_primal_after_encoder,
         in_channels_dual=out_channels_dual_after_encoder,
         out_channels_primal=out_channels_primal,
         out_channels_dual=out_channels_dual,
         single_dual_nodes=single_dual_nodes,
         undirected_dual_edges=undirected_dual_edges,
         concat_data_from_before_pooling=True)
     # Forward pass. The checks are performed on every batch, rather than only
     # on the last one.
     num_batches = 0
     for primal_graph, dual_graph, petdni in data_loader:
         (primal_graph_after_pooling, dual_graph_after_pooling,
          petdni_after_pooling, log_info, primal_graph_before_pooling,
          dual_graph_before_pooling) = down_conv(
              primal_graph_batch=primal_graph,
              dual_graph_batch=dual_graph,
              primal_edge_to_dual_node_idx_batch=petdni)
         self.assertIsNotNone(log_info)
         (primal_graph_after_pooling.x, dual_graph_after_pooling.x) = conv(
             x_primal=primal_graph_after_pooling.x,
             x_dual=dual_graph_after_pooling.x,
             edge_index_primal=primal_graph_after_pooling.edge_index,
             edge_index_dual=dual_graph_after_pooling.edge_index,
             primal_edge_to_dual_node_idx=petdni_after_pooling)
         (primal_graph_batch_out, dual_graph_batch_out,
          primal_edge_to_dual_node_idx_batch_out) = up_conv(
              primal_graph_batch=primal_graph_after_pooling,
              dual_graph_batch=dual_graph_after_pooling,
              primal_edge_to_dual_node_idx_batch=petdni_after_pooling,
              pooling_log=log_info,
              primal_graph_batch_before_pooling=primal_graph_before_pooling,
              dual_graph_batch_before_pooling=dual_graph_before_pooling)
         # Check that the original graphs and the 'unpooled' ones match in
         # connectivity.
         self.assertTrue(
             torch.equal(primal_graph.edge_index,
                         primal_graph_batch_out.edge_index))
         self.assertTrue(
             torch.equal(dual_graph.edge_index,
                         dual_graph_batch_out.edge_index))
         self.assertEqual(petdni, primal_edge_to_dual_node_idx_batch_out)
         # Check that the number of output channels of the new features
         # matches the one of the features in the original graphs.
         self.assertEqual(primal_graph.num_node_features,
                          primal_graph_after_pooling.num_node_features)
         self.assertEqual(dual_graph.num_node_features,
                          dual_graph_after_pooling.num_node_features)
         num_batches += 1
     # Make sure the checks above were actually performed at least once.
     self.assertGreater(num_batches, 0)
                      dual_graph_after_pooling.num_node_features)
コード例 #9
0
    def test_forward_pass(self):
        # NOTE: this is not an actual unit test, but it is just used to verify
        # that the forward pass goes through.
        single_dual_nodes = False
        undirected_dual_edges = True
        dataset = Shrec2016DualPrimal(
            root=osp.join(current_dir, '../common_data/shrec2016'),
            categories=['shark'],
            single_dual_nodes=single_dual_nodes,
            undirected_dual_edges=undirected_dual_edges,
            vertices_scale_mean=1.,
            vertices_scale_var=0.1,
            edges_flip_fraction=0.5,
            slide_vertices_fraction=0.2,
            num_augmentations=1)
        batch_size = 4
        data_loader = DualPrimalDataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
        # Arguments shared by all the down-convolutional layers tested below;
        # each configuration can add arguments or override these ones.
        common_down_conv_kwargs = dict(
            in_channels_primal=1,
            in_channels_dual=4,
            out_channels_primal=5,
            out_channels_dual=5,
            heads=1,
            concat_primal=True,
            concat_dual=True,
            negative_slope_primal=0.2,
            negative_slope_dual=0.2,
            dropout_primal=0,
            dropout_dual=0,
            bias_primal=True,
            bias_dual=True,
            single_dual_nodes=single_dual_nodes,
            undirected_dual_edges=undirected_dual_edges,
            add_self_loops_to_dual_graph=False,
            num_skips=3)
        # Each entry is (description, configuration-specific arguments, whether
        # pooling is enabled, i.e., whether a pooling log is expected).
        test_configurations = [
            ('no pooling', {}, False),
            ('num_primal_edges_to_keep', {
                'num_primal_edges_to_keep': 600
            }, True),
            ('fraction_primal_edges_to_keep', {
                'fraction_primal_edges_to_keep': 0.7
            }, True),
            ('primal_attention_coeff_threshold', {
                'primal_attention_coeff_threshold': 0.5
            }, True),
            ('no pooling of consecutive edges', {
                'num_primal_edges_to_keep': 600,
                'allow_pooling_consecutive_edges': False
            }, True),
            ('multiple heads, non-decreasing attention coefficients', {
                'heads': 3,
                'num_primal_edges_to_keep': 600,
                'use_decreasing_attention_coefficients': False
            }, True),
        ]
        for (description, specific_kwargs,
             pooling_enabled) in test_configurations:
            with self.subTest(description):
                down_conv = DualPrimalDownConv(**{
                    **common_down_conv_kwargs,
                    **specific_kwargs
                })
                for primal_graph, dual_graph, petdni in data_loader:
                    (_, _, _, log_info, _, _) = down_conv(
                        primal_graph_batch=primal_graph,
                        dual_graph_batch=dual_graph,
                        primal_edge_to_dual_node_idx_batch=petdni)
                    # A pooling log is returned if and only if pooling is
                    # performed.
                    if (pooling_enabled):
                        self.assertIsNotNone(log_info)
                    else:
                        self.assertIsNone(log_info)