    def test_encode(self):
        de = VarNamingClosedVocabDataEncoder(
            self.task.graphs_and_instances,
            excluded_edge_types=frozenset(),
            instance_to_datapoints_kwargs=dict())
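        # Apply the same graph fix-up and task-specific processing the
        # preprocessing pipeline uses before producing datapoints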
        for graph, instances in self.task.graphs_and_instances:
            VarNamingClosedVocab.fix_up_edges(graph, instances, frozenset())
            VarNamingClosedVocab.extra_graph_processing(graph, instances, de)
            for instance in tqdm(instances):
                dporig = VarNamingClosedVocab.instance_to_datapoint(
                    graph, instance, de, max_nodes_per_graph=50)
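                # Encode a deep copy so dporig keeps the raw, unencoded values
                # for comparison against the encoded datapoint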
                dp = deepcopy(dporig)
                de.encode(dp)
                self.assertEqual(list(dp.edges.keys()),
                                 sorted(list(de.all_edge_types)),
                                 "Not all adjacency matrices were created")
                for edge_type, adj_mat in dp.edges.items():
                    np.testing.assert_equal(
                        adj_mat.todense(),
                        dporig.subgraph.get_adjacency_matrix(
                            edge_type).todense())
                    self.assertIsInstance(
                        adj_mat, sp.sparse.coo_matrix,
                        "Encoding produces adjacency matrix of wrong type")

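                # Node types: shapes preserved, each type mapped through the
                # encoder's type vocabulary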
                self.assertEqual(
                    len(dporig.node_types), len(dp.node_types),
                    "Type for some node got lost during encoding")
                self.assertEqual(
                    [len(i) for i in dporig.node_types],
                    [len(i) for i in dp.node_types],
                    "Some type for some node got lost during encoding")
                for i in range(len(dp.node_types)):
                    for j in range(len(dp.node_types[i])):
                        self.assertEqual(
                            dp.node_types[i][j],
                            de.all_node_types[dporig.node_types[i][j]],
                            "Some node type got encoded wrong")

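                # Node names: shapes preserved, each subtoken mapped through
                # the encoder's name-subtoken vocabulary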
                self.assertEqual(
                    len(dporig.node_names), len(dp.node_names),
                    "Name for some node got lost during encoding")
                self.assertEqual(
                    [len(i) for i in dporig.node_names],
                    [len(i) for i in dp.node_names],
                    "Some name subtoken for some node got lost during encoding")
                for i in range(len(dp.node_names)):
                    for j in range(len(dp.node_names[i])):
                        self.assertEqual(
                            dp.node_names[i][j], de.all_node_name_subtokens[
                                dporig.node_names[i][j]],
                            "Some node name got encoded wrong")

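                # The label is encoded with the same name-subtoken vocabulary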
                self.assertEqual(
                    len(dporig.label), len(dp.label),
                    "Some label subtoken got lost during encoding")
                for i in range(len(dp.label)):
                    self.assertEqual(
                        dp.label[i],
                        de.all_node_name_subtokens[dporig.label[i]])

    def test_preprocess_task_existing_encoding_basic_functionality(self):
        VarNamingClosedVocab.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder='new',
            data_encoder_kwargs=dict(),
            instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
        de = VarNamingClosedVocabDataEncoder.load(
            os.path.join(
                self.output_dataset_dir,
                '{}.pkl'.format(VarNamingClosedVocabDataEncoder.__name__)))
        VarNamingClosedVocab.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder=de,
            instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
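        # A data encoder of the wrong type must be rejected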
        with self.assertRaises(AssertionError):
            de = BaseDataEncoder(dict(), frozenset())
            VarNamingClosedVocab.preprocess_task(
                self.task,
                output_dir=self.output_dataset_dir,
                n_jobs=30,
                data_encoder=de,
                instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))

    @staticmethod
    def instance_to_datapoint(graph: AugmentedAST,
                              instance,
                              data_encoder: VarNamingClosedVocabDataEncoder,
                              max_nodes_per_graph: int = None):
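        """Turn one (variable name, usage locations) instance into a DataPoint.

        The DataPoint holds the containing subgraph (with the variable's
        usages anonymized), per-node type and name information, and the real
        name's subtokens as the label; values are left unencoded.
        """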
        var_name, locs = instance

        name_me_flag = data_encoder.name_me_flag
        internal_node_flag = data_encoder.internal_node_flag

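        # Work on the subgraph containing all usage locations, capped at
        # max_nodes_per_graph nodes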
        subgraph = graph.get_containing_subgraph(locs, max_nodes_per_graph)

        # Flag the variables to be named
        for loc in locs:
            subgraph.nodes[loc]['identifier'] = name_me_flag
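            # Prune adjacent edge types that would make the name too easy to
            # guess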
            edges_to_prune = subgraph.all_adjacent_edges(
                loc, too_useful_edge_types)
            subgraph._graph.remove_edges_from(edges_to_prune)

        # Remove any disconnected subtoken nodes (they could come from
        # subtokens that appear only in the variable's name, and would thus
        # be unfair hints)
        for node, data in list(subgraph.nodes):
            if data['type'] == data_encoder.subtoken_flag and subgraph._graph.degree(
                    node) == 0:
                subgraph._graph.remove_node(node)

        # Assemble node types, node names, and label
        subgraph.node_ids_to_ints_from_0()
        node_types = []
        node_names = []
        for node, data in sorted(subgraph.nodes):
            if subgraph.is_variable_node(node):
                node_types.append(
                    sorted(list(set(re.split(r'[,.]', data['reference'])))))
                node_names.append(data['identifier'])
            else:
                node_types.append([data['type']])
                if data['type'] == data_encoder.subtoken_flag:
                    node_names.append(data['identifier'])
                else:
                    node_names.append(internal_node_flag)

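        # The prediction target is the real name, split into subtokens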
        label = data_encoder.name_to_subtokens(var_name)

        return data_encoder.DataPoint(subgraph, node_types, node_names,
                                      var_name, label, graph.origin_file,
                                      data_encoder.encoder_hash)

    def test_init_finds_all_relevant_dataset_information(self):
        de = VarNamingClosedVocabDataEncoder(
            self.task.graphs_and_instances,
            excluded_edge_types=frozenset(),
            instance_to_datapoints_kwargs=dict())
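        # The encoder should discover every edge type in the dataset and
        # build dense, sequential integer vocabularies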
        self.assertCountEqual(de.all_edge_types, list(all_edge_types),
                              "DataEncoder found weird edge types")
        self.assertTrue(
            sorted(de.all_node_name_subtokens.values()) == list(
                range(len(de.all_node_name_subtokens))),
            "DataEncoder didn't use sequential integers for its name subtoken encoding"
        )
        self.assertTrue(
            sorted(de.all_node_types.values()) == list(
                range(len(de.all_node_types))),
            "DataEncoder didn't use sequential integers for its type encoding")
        self.assertEqual(de.all_node_types['__PAD__'], 0)