Esempio n. 1
0
    def test_encode(self):
        """Check that ``DataEncoder.encode`` faithfully encodes every datapoint.

        For each instance a datapoint is built, deep-copied, and the copy is
        encoded; the encoded copy is then compared with the unencoded original:

        * it has one adjacency matrix per edge type in ``de.all_edge_types``,
          keyed in sorted order, each a ``scipy.sparse.coo_matrix`` equal to the
          original subgraph's adjacency matrix for that edge type;
        * every node type equals ``de.all_node_types[<original type>]``;
        * every label subtoken equals ``de.all_node_name_subtokens[<original subtoken>]``.
        """
        de = VarNamingCharCNNDataEncoder(self.task.graphs_and_instances,
                                         excluded_edge_types=frozenset(),
                                         instance_to_datapoints_kwargs=dict(),
                                         max_name_encoding_length=self.max_name_encoding_length)
        for graph, instances in self.task.graphs_and_instances:
            # Return values are ignored: the test relies on these helpers
            # updating graph/instances in place.
            VarNamingCharCNN.fix_up_edges(graph, instances, frozenset())
            VarNamingCharCNN.extra_graph_processing(graph, instances, de)
            for instance in tqdm(instances):
                dporig = VarNamingCharCNN.instance_to_datapoint(graph, instance, de, max_nodes_per_graph=50)
                # encode() is run on a copy (its return value is not used, so it is
                # expected to update dp in place); dporig stays unencoded as the reference.
                dp = deepcopy(dporig)
                de.encode(dp)

                self.assertEqual(list(dp.edges), sorted(de.all_edge_types),
                                 "Not all adjacency matrices were created")
                for edge_type, adj_mat in dp.edges.items():
                    # Type check first, so a wrong type reports this message instead
                    # of failing inside .todense() / assert_equal.
                    self.assertIsInstance(adj_mat, sp.sparse.coo_matrix,
                                          "Encoding produces adjacency matrix of wrong type")
                    np.testing.assert_equal(adj_mat.todense(),
                                            dporig.subgraph.get_adjacency_matrix(edge_type).todense())

                self.assertEqual(len(dporig.node_types), len(dp.node_types),
                                 "Type for some node got lost during encoding")
                self.assertEqual([len(i) for i in dporig.node_types], [len(i) for i in dp.node_types],
                                 "Some type for some node got lost during encoding")
                # Outer and inner lengths were asserted equal above, so zip pairs every element.
                for orig_types, enc_types in zip(dporig.node_types, dp.node_types):
                    for orig_type, enc_type in zip(orig_types, enc_types):
                        self.assertEqual(enc_type, de.all_node_types[orig_type],
                                         "Some node type got encoded wrong")

                self.assertEqual(len(dporig.label), len(dp.label),
                                 "Some label subtoken got lost during encoding")
                for orig_subtoken, enc_subtoken in zip(dporig.label, dp.label):
                    self.assertEqual(enc_subtoken, de.all_node_name_subtokens[orig_subtoken],
                                     "Some label subtoken got encoded wrong")
Esempio n. 2
0
    def test_instance_to_datapoint(self):
        """Check ``VarNamingCharCNN.instance_to_datapoint`` under two edge-exclusion settings.

        Runs once with ``syntax_only_excluded_edge_types`` and once with no
        exclusions. For every instance the produced datapoint must:

        * be exactly a ``VarNamingCharCNNDataPoint`` with one name and one
          non-empty type list per subgraph node;
        * have exactly ``len(instance[1])`` variable nodes whose identifier is
          ``de.name_me_flag``, all carrying ``dp.real_variable_name`` as text;
        * have no ``too_useful_edge_types`` edge adjacent to a masked node;
        * give variable nodes their ``'reference'`` split on ``,``/``.`` as
          types and their identifier as name; give non-variable nodes
          ``de.internal_node_flag`` as name and a single type;
        * carry ``de.name_to_subtokens`` of the real name as its label.

        Each datapoint is then encoded (must contain ``'AST'`` and
        ``'NEXT_TOKEN'`` edges) and saved to ``self.output_dataset_dir``.
        """
        for excluded_edge_types in [syntax_only_excluded_edge_types, frozenset()]:
            # fix_up_edges / extra_graph_processing are called below only for their
            # effect on the graphs (return values are ignored). Give each pass a fresh
            # copy so the second pass does not run on graphs already processed by
            # the first.
            graphs_and_instances = deepcopy(self.task.graphs_and_instances)
            de = VarNamingCharCNN.DataEncoder(graphs_and_instances,
                                              excluded_edge_types=excluded_edge_types,
                                              instance_to_datapoints_kwargs=dict(),
                                              max_name_encoding_length=self.max_name_encoding_length)
            for graph, instances in tqdm(graphs_and_instances):
                VarNamingCharCNN.fix_up_edges(graph, instances, excluded_edge_types)
                VarNamingCharCNN.extra_graph_processing(graph, instances, de)
                for instance in instances:
                    dp = VarNamingCharCNN.instance_to_datapoint(graph, instance, de, max_nodes_per_graph=100)
                    self.assertIs(type(dp), VarNamingCharCNNDataPoint)
                    self.assertEqual(len(dp.subgraph.nodes), len(dp.node_types))
                    self.assertEqual(len(dp.subgraph.nodes), len(dp.node_names))

                    # (node, data) pairs of the variable nodes whose identifier was masked.
                    name_me_nodes = [i for i in dp.subgraph.nodes_that_represent_variables if
                                     i[1]['identifier'] == de.name_me_flag]
                    self.assertTrue(all(dp.subgraph.is_variable_node(i[0]) for i in name_me_nodes),
                                    "Some non-variable got masked")
                    self.assertEqual(len(name_me_nodes), len(instance[1]),
                                     "Wrong number of variables got their names masked")
                    self.assertEqual(1, len({i[1]['text'] for i in name_me_nodes}),
                                     "Not all name-masked nodes contain the same name")
                    self.assertTrue(all(i[1]['text'] == dp.real_variable_name for i in name_me_nodes),
                                    "Some nodes have the wrong name")

                    for node, _ in name_me_nodes:
                        # Computed once per node rather than once per forbidden edge type.
                        adjacent_edge_types = [e[3]['type'] for e in dp.subgraph.all_adjacent_edges(node)]
                        for et in too_useful_edge_types:
                            self.assertNotIn(et, adjacent_edge_types)

                    for i, (name, types) in enumerate(zip(dp.node_names, dp.node_types)):
                        self.assertIsInstance(name, str)
                        self.assertGreater(len(name), 0)
                        self.assertIsInstance(types, list)
                        self.assertGreaterEqual(len(types), 1)
                        if dp.subgraph.is_variable_node(i):
                            self.assertCountEqual(set(re.split(r'[,.]', dp.subgraph[i]['reference'])), types)
                            if name == de.name_me_flag:
                                # A masked name must sit on a node whose identifier is masked
                                # too (i.e. one of name_me_nodes), not merely equal the flag.
                                self.assertEqual(dp.subgraph[i]['identifier'], de.name_me_flag,
                                                 "Masked name on a node whose identifier is not masked")
                            else:
                                self.assertEqual(name, dp.subgraph[i]['identifier'])
                        else:
                            self.assertEqual(name, de.internal_node_flag)
                            self.assertEqual(len(types), 1)

                    # name_me_nodes is non-empty here: the set-size assertion above requires it.
                    self.assertEqual(dp.label, de.name_to_subtokens(name_me_nodes[0][1]['text']), "Label is wrong")

                    de.encode(dp)
                    self.assertIn('AST', dp.edges.keys())
                    self.assertIn('NEXT_TOKEN', dp.edges.keys())
                    de.save_datapoint(dp, self.output_dataset_dir)
 def extra_graph_processing(graph, instances, data_encoder):
     """Run ``VarNamingCharCNN.extra_graph_processing``, then attach subtoken nodes.

     Each variable node's ``'identifier'`` is split with
     ``data_encoder.name_to_subtokens``. For every subtoken a node is added via
     ``graph.add_node`` (``identifier`` = the subtoken, ``type`` =
     ``data_encoder.subtoken_flag``). It is linked to the variable node by an
     edge of type ``data_encoder.subtoken_edge_type`` (variable -> subtoken) and
     one of type ``data_encoder.subtoken_reverse_edge_type`` (subtoken -> variable).

     NOTE(review): this calls ``VarNamingCharCNN.extra_graph_processing``, so it
     presumably lives on a different class (e.g. a subclass); if it were that
     method itself the call would recurse forever — confirm.

     Args:
         graph: graph to process; iterating ``graph.nodes`` yields ``(node, data)`` pairs.
         instances: passed through ``VarNamingCharCNN.extra_graph_processing``.
         data_encoder: provides ``name_to_subtokens``, ``subtoken_flag`` and the
             two subtoken edge types.

     Returns:
         The ``(graph, instances)`` pair returned by
         ``VarNamingCharCNN.extra_graph_processing``, with the subtoken nodes and
         edges added to ``graph``.
     """
     graph, instances = VarNamingCharCNN.extra_graph_processing(graph, instances, data_encoder)
     # Walk a snapshot of the nodes: add_node is called during the walk, and the
     # snapshot keeps newly added subtoken nodes out of this loop.
     node_snapshot = list(graph.nodes)
     for var_node, var_data in node_snapshot:
         if not graph.is_variable_node(var_node):
             continue
         for subtoken in data_encoder.name_to_subtokens(var_data['identifier']):
             subtoken_node, _ = graph.add_node(subtoken,
                                               identifier=subtoken,
                                               type=data_encoder.subtoken_flag)
             graph.add_edge(var_node, subtoken_node, type=data_encoder.subtoken_edge_type)
             graph.add_edge(subtoken_node, var_node, type=data_encoder.subtoken_reverse_edge_type)
     return graph, instances