def test_preprocess_task_existing_encoding_basic_functionality(self):
    """preprocess_task accepts 'new', accepts a compatible encoder loaded
    from disk, and rejects a data encoder of the wrong class."""
    # First pass with a freshly-created encoder; this also writes the
    # encoder pickle into the output directory.
    VarNamingClosedVocab.preprocess_task(
        self.task,
        output_dir=self.output_dataset_dir,
        n_jobs=30,
        data_encoder='new',
        data_encoder_kwargs=dict(),
        instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
    encoder_path = os.path.join(
        self.output_dataset_dir,
        '{}.pkl'.format(VarNamingClosedVocabDataEncoder.__name__))
    reloaded_encoder = VarNamingClosedVocabDataEncoder.load(encoder_path)
    # Second pass reuses the encoder we just loaded back from disk.
    VarNamingClosedVocab.preprocess_task(
        self.task,
        output_dir=self.output_dataset_dir,
        n_jobs=30,
        data_encoder=reloaded_encoder,
        instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
    # A plain BaseDataEncoder is the wrong encoder type for this model,
    # so preprocessing must fail its type assertion.
    with self.assertRaises(AssertionError):
        wrong_encoder = BaseDataEncoder(dict(), frozenset())
        VarNamingClosedVocab.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder=wrong_encoder,
            instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
    def test_preprocess_task_for_model(self):
        """End-to-end preprocessing: datapoints get written, temporary job
        files get cleaned up, and re-encoding with the saved encoder
        reproduces the same datapoints."""

        def summarize(dp):
            # The fields that must survive a re-encoding round trip.
            return (dp.node_types, dp.node_names, dp.real_variable_name,
                    dp.label, dp.origin_file, dp.encoder_hash,
                    dp.edges.keys())

        task = VarNamingTask.from_gml_files(self.test_gml_files)
        task_filepath = os.path.join(self.output_dataset_dir,
                                     'VarNamingTask.pkl')
        task.save(task_filepath)
        VarNamingClosedVocab.preprocess_task(
            task=task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder='new',
            data_encoder_kwargs=dict(),
            instance_to_datapoints_kwargs=dict(max_nodes_per_graph=100))
        output_files = os.listdir(self.output_dataset_dir)
        self.assertNotIn(
            'jobs.txt', output_files,
            "The jobs.txt file from process_graph_to_datapoints_with_xargs didn't get deleted"
        )
        self.assertTrue(
            all(len(name) > 10 for name in output_files),
            "Hacky check for if pickled jobs didn't get deleted")
        reencoding_dir = os.path.join(self.output_dataset_dir, 're-encoding')
        os.mkdir(reencoding_dir)
        data_encoder = VarNamingClosedVocab.DataEncoder.load(
            os.path.join(self.output_dataset_dir,
                         'VarNamingClosedVocabDataEncoder.pkl'))
        # Every edge type plus its synthetic reverse edge.
        expected_edge_types = list(all_edge_types) + [
            'reverse_{}'.format(i) for i in all_edge_types
        ]
        self.assertCountEqual(data_encoder.all_edge_types,
                              expected_edge_types,
                              "DataEncoder found weird edge types")
        # Re-encode the same task with the saved encoder into a fresh dir.
        VarNamingClosedVocab.preprocess_task(task=task,
                                             output_dir=reencoding_dir,
                                             n_jobs=30,
                                             data_encoder=data_encoder)
        non_datapoint_files = {
            'VarNamingClosedVocabDataEncoder.pkl', 'VarNamingTask.pkl',
            're-encoding'
        }
        orig_datapoints = []
        for file in os.listdir(self.output_dataset_dir):
            if file in non_datapoint_files:
                continue
            with open(os.path.join(self.output_dataset_dir, file),
                      'rb') as f:
                dp = pickle.load(f)
            self.assertCountEqual(dp.edges.keys(), expected_edge_types,
                                  'We lost some edge types')
            orig_datapoints.append(summarize(dp))

        reencoded_datapoints = []
        for file in os.listdir(reencoding_dir):
            with open(os.path.join(reencoding_dir, file), 'rb') as f:
                dp = pickle.load(f)
            reencoded_datapoints.append(summarize(dp))
        # Same multiset of datapoints regardless of encoding pass.
        self.assertCountEqual(orig_datapoints, reencoded_datapoints)
    def test_instance_to_datapoint(self):
        """Validate instance_to_datapoint for both a syntax-only edge set and
        an unrestricted one: datapoint shape, variable-name masking, removal
        of too-useful edge types, per-node name/type lists, and the label."""
        for excluded_edge_types in [
                syntax_only_excluded_edge_types,
                frozenset()
        ]:
            de = VarNamingClosedVocab.DataEncoder(
                self.task.graphs_and_instances,
                excluded_edge_types=excluded_edge_types,
                instance_to_datapoints_kwargs=dict())
            for graph, instances in tqdm(self.task.graphs_and_instances):
                VarNamingClosedVocab.fix_up_edges(graph, instances,
                                                  excluded_edge_types)
                VarNamingClosedVocab.extra_graph_processing(
                    graph, instances, de)
                for instance in instances:
                    dp = VarNamingClosedVocab.instance_to_datapoint(
                        graph, instance, de, max_nodes_per_graph=100)
                    self.assertEqual(type(dp), VarNamingClosedVocabDataPoint)
                    # One type list and one name list per subgraph node.
                    self.assertEqual(len(dp.subgraph.nodes),
                                     len(dp.node_types))
                    self.assertEqual(len(dp.subgraph.nodes),
                                     len(dp.node_names))
                    # Nodes whose identifier was replaced by the mask flag.
                    name_me_nodes = [
                        i for i in dp.subgraph.nodes_that_represent_variables
                        if i[1]['identifier'] == de.name_me_flag
                    ]
                    self.assertTrue(
                        all(
                            dp.subgraph.is_variable_node(i[0])
                            for i in name_me_nodes),
                        "Some non-variable got masked")
                    # Exactly the instance's variable occurrences are masked.
                    self.assertEqual(
                        len([i[0] for i in name_me_nodes]), len(instance[1]),
                        "Wrong number of variables got their names masked")
                    self.assertEqual(
                        1, len(set([i[1]['text'] for i in name_me_nodes])),
                        "Not all name-masked nodes contain the same name")
                    self.assertTrue(
                        all([
                            i[1]['text'] == dp.real_variable_name
                            for i in name_me_nodes
                        ]), "Some nodes have the wrong name")

                    # Masked nodes must not keep edges that would leak the
                    # answer (too-useful edge types).
                    for node, _ in name_me_nodes:
                        for et in too_useful_edge_types:
                            self.assertNotIn(et, [
                                e[3]['type']
                                for e in dp.subgraph.all_adjacent_edges(node)
                            ])

                    # Per-node sanity of the parallel name/type lists.
                    for i, (names, types) in enumerate(
                            zip(dp.node_names, dp.node_types)):
                        self.assertEqual(type(names), list)
                        self.assertGreaterEqual(len(names), 1)
                        self.assertEqual(type(types), list)
                        self.assertGreaterEqual(len(types), 1)
                        if dp.subgraph.is_variable_node(i):
                            # Types come from splitting the node's 'reference'
                            # attribute on commas/periods.
                            self.assertCountEqual(
                                set(
                                    re.split(r'[,.]',
                                             dp.subgraph[i]['reference'])),
                                types)
                            if names != [de.name_me_flag]:
                                # Unmasked: each subtoken appears in the
                                # lowercased identifier.
                                for name in names:
                                    self.assertIn(
                                        name,
                                        dp.subgraph[i]['identifier'].lower())
                            else:
                                self.assertEqual(len(names), 1)

                    # The label is the subtoken split of the real name.
                    self.assertEqual(
                        dp.label,
                        de.name_to_subtokens(name_me_nodes[0][1]['text']),
                        "Label is wrong")

                    # Encoding must preserve the core edge types, and the
                    # datapoint must round-trip through save.
                    de.encode(dp)
                    self.assertIn('AST', dp.edges.keys())
                    self.assertIn('NEXT_TOKEN', dp.edges.keys())
                    de.save_datapoint(dp, self.output_dataset_dir)
 def test_preprocess_task_type_check_basic_functionality(self):
     """preprocess_task must reject anything that is not a Task instance."""
     # Passing the Task class itself (not an instance) should trip the
     # type assertion inside preprocess_task.
     with self.assertRaises(AssertionError):
         VarNamingClosedVocab.preprocess_task(Task)