Example #1
0
    def test_encode(self):
        """Encoding a datapoint must preserve its edges, node types and label.

        For every instance, builds a datapoint, encodes a deep copy of it in place
        with ``de.encode`` and compares the result against the unencoded original:

        * ``dp.edges`` holds one adjacency matrix per edge type in
          ``de.all_edge_types``, keyed in sorted order, each a
          ``scipy.sparse.coo_matrix`` equal to the original subgraph's adjacency
          matrix for that edge type;
        * every node type is mapped through ``de.all_node_types``;
        * every label subtoken is mapped through ``de.all_node_name_subtokens``.
        """
        de = VarNamingCharCNNDataEncoder(self.task.graphs_and_instances,
                                         excluded_edge_types=frozenset(),
                                         instance_to_datapoints_kwargs=dict(),
                                         max_name_encoding_length=self.max_name_encoding_length)
        for graph, instances in self.task.graphs_and_instances:
            VarNamingCharCNN.fix_up_edges(graph, instances, frozenset())
            VarNamingCharCNN.extra_graph_processing(graph, instances, de)
            for instance in tqdm(instances):
                dporig = VarNamingCharCNN.instance_to_datapoint(graph, instance, de, max_nodes_per_graph=50)
                # Encode a copy so the unencoded original stays available for comparison.
                dp = deepcopy(dporig)
                de.encode(dp)

                # Edges: one adjacency matrix per edge type, keys in sorted order.
                self.assertEqual(list(dp.edges), sorted(de.all_edge_types),
                                 "Not all adjacency matrices were created")
                for edge_type, adj_mat in dp.edges.items():
                    # Check the type first so a wrong matrix type fails with this
                    # message rather than somewhere inside the dense comparison.
                    self.assertIsInstance(adj_mat, sp.sparse.coo_matrix,
                                          "Encoding produces adjacency matrix of wrong type")
                    np.testing.assert_equal(adj_mat.todense(),
                                            dporig.subgraph.get_adjacency_matrix(edge_type).todense())

                # Node types: same nesting/lengths, each entry mapped through the type vocabulary.
                self.assertEqual(len(dporig.node_types), len(dp.node_types),
                                 "Type for some node got lost during encoding")
                self.assertEqual([len(i) for i in dporig.node_types], [len(i) for i in dp.node_types],
                                 "Some type for some node got lost during encoding")
                # Lengths were asserted equal above, so zip does not silently truncate.
                for orig_types, encoded_types in zip(dporig.node_types, dp.node_types):
                    for orig_type, encoded_type in zip(orig_types, encoded_types):
                        self.assertEqual(encoded_type, de.all_node_types[orig_type],
                                         "Some node type got encoded wrong")

                # Label: same length, each subtoken mapped through the subtoken vocabulary.
                self.assertEqual(len(dporig.label), len(dp.label),
                                 "Some label subtoken got lost during encoding")
                for orig_subtoken, encoded_subtoken in zip(dporig.label, dp.label):
                    self.assertEqual(encoded_subtoken, de.all_node_name_subtokens[orig_subtoken],
                                     "Some label subtoken got encoded wrong")
Example #2
0
 def test_preprocess_task_existing_encoding_basic_functionality_excluded_edges(self):
     """Preprocessing honours ``excluded_edge_types`` and accepts a reused encoder.

     The test runs ``preprocess_task`` with a new encoder that excludes
     ``syntax_only_excluded_edge_types``. It then reloads the pickled encoder
     from the output directory and checks that the encoder, and every
     datapoint written alongside it, only knows syntax-only edge types plus
     their ``reverse_`` variants. Next it re-runs preprocessing with the
     reloaded encoder. Finally it checks that passing a plain
     ``BaseDataEncoder`` raises ``AssertionError``.
     """
     encoder_filename = '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__)
     reverse_prefix = 'reverse_'

     VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30, data_encoder='new',
                                      excluded_edge_types=syntax_only_excluded_edge_types,
                                      data_encoder_kwargs=dict(
                                          max_name_encoding_length=self.max_name_encoding_length),
                                      instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     de = VarNamingCharCNNDataEncoder.load(os.path.join(self.output_dataset_dir, encoder_filename))
     self.assertEqual(de.excluded_edge_types, syntax_only_excluded_edge_types)
     self.assertCountEqual(de.all_edge_types,
                           list(syntax_only_edge_types) + [reverse_prefix + i for i in syntax_only_edge_types])

     # Every file in the output directory other than the pickled encoder is treated as a datapoint.
     datapoints = [os.path.join(self.output_dataset_dir, i) for i in os.listdir(self.output_dataset_dir) if
                   i != encoder_filename]
     for dp in datapoints:
         datapoint = de.load_datapoint(dp)
         for e in datapoint.edges.keys():
             if e.startswith(reverse_prefix):
                 # A reversed edge must reverse a syntax-only edge type.
                 self.assertIn(e[len(reverse_prefix):], syntax_only_edge_types)
             else:
                 self.assertIn(e, syntax_only_edge_types)

     # Reusing the saved encoder must be accepted.
     VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30, data_encoder=de,
                                      excluded_edge_types=syntax_only_excluded_edge_types,
                                      data_encoder_kwargs=dict(
                                          max_name_encoding_length=self.max_name_encoding_length))

     # preprocess_task is expected to reject a plain BaseDataEncoder. The encoder is
     # built outside assertRaises so that only preprocess_task can satisfy the expectation.
     wrong_encoder = BaseDataEncoder(dict(), frozenset())
     with self.assertRaises(AssertionError):
         VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30,
                                          data_encoder=wrong_encoder,
                                          excluded_edge_types=syntax_only_excluded_edge_types,
                                          data_encoder_kwargs=dict(
                                              max_name_encoding_length=self.max_name_encoding_length))
Example #3
0
 def test_preprocess_task_existing_encoding_basic_functionality(self):
     """Preprocessing works with a reused encoder and rejects a wrong encoder type.

     The test runs ``preprocess_task`` with a new encoder, reloads the pickled
     encoder from the output directory and re-runs preprocessing with it. It
     then checks that passing a plain ``BaseDataEncoder`` raises
     ``AssertionError``.
     """
     VarNamingCharCNN.preprocess_task(
         self.task,
         output_dir=self.output_dataset_dir,
         n_jobs=30,
         data_encoder='new',
         data_encoder_kwargs=dict(
             max_name_encoding_length=self.max_name_encoding_length),
         instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     de = VarNamingCharCNNDataEncoder.load(
         os.path.join(self.output_dataset_dir, '{}.pkl'.format(
             VarNamingCharCNNDataEncoder.__name__)))

     # Reusing the saved encoder must be accepted.
     # NOTE(review): excluded_edge_types is passed inside data_encoder_kwargs here, not as
     # the top-level keyword used elsewhere; unclear whether preprocess_task reads it when
     # an existing encoder is supplied — confirm intent.
     VarNamingCharCNN.preprocess_task(
         self.task,
         output_dir=self.output_dataset_dir,
         n_jobs=30,
         data_encoder=de,
         data_encoder_kwargs=dict(
             excluded_edge_types=syntax_only_excluded_edge_types,
             max_name_encoding_length=self.max_name_encoding_length))

     # preprocess_task is expected to reject a plain BaseDataEncoder. The encoder is
     # built outside assertRaises so that only preprocess_task can satisfy the expectation.
     wrong_encoder = BaseDataEncoder(dict(), frozenset())
     with self.assertRaises(AssertionError):
         VarNamingCharCNN.preprocess_task(
             self.task,
             output_dir=self.output_dataset_dir,
             n_jobs=30,
             data_encoder=wrong_encoder,
             data_encoder_kwargs=dict(
                 excluded_edge_types=syntax_only_excluded_edge_types,
                 max_name_encoding_length=self.max_name_encoding_length))
Example #4
0
 def test_init_finds_all_relevant_dataset_information(self):
     """Constructing the encoder collects the dataset's vocabularies.

     Checks that the encoder found exactly ``all_edge_types`` and that it
     encodes node types as the sequential integers ``0..n-1``, with
     ``'__PAD__'`` encoded as 0. Also checks that it keeps the requested
     ``max_name_encoding_length``.
     """
     de = VarNamingCharCNNDataEncoder(self.task.graphs_and_instances, excluded_edge_types=frozenset(),
                                      instance_to_datapoints_kwargs=dict(),
                                      max_name_encoding_length=self.max_name_encoding_length)
     self.assertCountEqual(de.all_edge_types, list(all_edge_types), "DataEncoder found weird edge types")
     # assertEqual (rather than assertTrue(a == b)) reports both lists and their diff on failure.
     self.assertEqual(sorted(de.all_node_types.values()), list(range(len(de.all_node_types))),
                      "DataEncoder didn't use sequential integers for its type encoding")
     self.assertEqual(de.max_name_encoding_length, self.max_name_encoding_length,
                      "DataEncoder didn't keep the requested max_name_encoding_length")
     self.assertEqual(de.all_node_types['__PAD__'], 0,
                      "DataEncoder didn't encode the '__PAD__' node type as 0")