Code example #1
    def setUp(self):
        self.train_gml_dir = os.path.join(test_s3shared_path, 'test_dataset',
                                          'seen_repos', 'train_graphs')
        self.test_gml_dir = os.path.join(test_s3shared_path, 'test_dataset',
                                         'seen_repos', 'test_graphs')
        self.train_output_dataset_dir = os.path.join(test_s3shared_path,
                                                     'train_model_dataset')
        os.makedirs(self.train_output_dataset_dir, exist_ok=True)
        self.test_output_dataset_dir = os.path.join(test_s3shared_path,
                                                    'test_model_dataset')
        os.makedirs(self.test_output_dataset_dir, exist_ok=True)
        self.train_log_dir = os.path.join(test_s3shared_path, 'train_logs',
                                          get_time())
        self.test_log_dir = os.path.join(test_s3shared_path, 'test_logs',
                                         get_time())
        # Collect the absolute paths of all .gml graph files in each split.
        self.train_gml_files = []
        for file in os.listdir(self.train_gml_dir):
            if file.endswith('.gml'):
                self.train_gml_files.append(
                    os.path.abspath(os.path.join(self.train_gml_dir, file)))
        self.test_gml_files = []
        for file in os.listdir(self.test_gml_dir):
            if file.endswith('.gml'):
                self.test_gml_files.append(
                    os.path.abspath(os.path.join(self.test_gml_dir, file)))

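        # Build a VarNamingTask from each file list and pickle it for later tests.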
        train_task = VarNamingTask.from_gml_files(self.train_gml_files)
        self.train_task_filepath = os.path.join(self.train_gml_dir,
                                                'TrainVarNamingTask.pkl')
        train_task.save(self.train_task_filepath)
        test_task = VarNamingTask.from_gml_files(self.test_gml_files)
        self.test_task_filepath = os.path.join(self.test_gml_dir,
                                               'TestVarNamingTask.pkl')
        test_task.save(self.test_task_filepath)
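The log directories in example #1 are timestamped with get_time(). That helper is not shown in the snippet; a minimal sketch of what it plausibly returns, assuming it is a thin wrapper that produces a filesystem-safe timestamp string:

    from datetime import datetime

    def get_time():
        # Hypothetical stand-in for the project's get_time() helper:
        # a timestamp string that is safe to use as a directory name.
        return datetime.now().strftime('%Y-%m-%d_%H-%M-%S')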
Code example #2
    def setUp(self):
        self.gml_dir = os.path.join(test_s3shared_path, 'test_dataset', 'repositories')
        self.output_dataset_dir = os.path.join(test_s3shared_path, 'VarNaming_CharCNN_dataset')
        # Collect the absolute paths of all .gml graph files.
        self.test_gml_files = []
        for file in os.listdir(self.gml_dir):
            if file.endswith('.gml'):
                self.test_gml_files.append(os.path.abspath(os.path.join(self.gml_dir, file)))
        self.task = VarNamingTask.from_gml_files(self.test_gml_files)
        self.max_name_encoding_length = 10
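Every setUp in these examples repeats the same loop to gather graph files. A refactoring sketch, with the hypothetical helper name find_gml_files (not part of the original code):

    import os

    def find_gml_files(directory):
        # Collect absolute paths of all .gml files in a directory; mirrors
        # the loop repeated in the setUp methods in these examples.
        return [os.path.abspath(os.path.join(directory, f))
                for f in os.listdir(directory)
                if f.endswith('.gml')]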
Code example #3
    def setUp(self):
        self.gml_dir = os.path.join(test_s3shared_path, 'test_dataset',
                                    'repositories')
        self.output_dataset_dir = os.path.join(
            test_s3shared_path, 'VarNaming_Fixed_Vocab_dataset')
        os.makedirs(self.output_dataset_dir, exist_ok=True)
        # Collect the absolute paths of all .gml graph files.
        self.test_gml_files = []
        for file in os.listdir(self.gml_dir):
            if file.endswith('.gml'):
                self.test_gml_files.append(
                    os.path.abspath(os.path.join(self.gml_dir, file)))
        self.task = VarNamingTask.from_gml_files(self.test_gml_files)
        self.max_name_encoding_length = 10  # read by the test below; value mirrors code example #2
    def test_preprocess_task_for_model_no_subtoken_edges(self):
        task = VarNamingTask.from_gml_files(self.test_gml_files)
        task_filepath = os.path.join(self.output_dataset_dir, 'VarNamingTask.pkl')
        task.save(task_filepath)
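        # First pass: preprocess the task with a freshly created data encoder.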
        VarNamingGSCVocab.preprocess_task(task=task,
                                          output_dir=self.output_dataset_dir,
                                          n_jobs=30,
                                          data_encoder='new',
                                          data_encoder_kwargs=dict(
                                              max_name_encoding_length=self.max_name_encoding_length,
                                              add_edges=False),
                                          instance_to_datapoints_kwargs=dict(max_nodes_per_graph=100))
        self.assertNotIn('jobs.txt', os.listdir(self.output_dataset_dir),
                         "The jobs.txt file from process_graph_to_datapoints_with_xargs didn't get deleted")
        self.assertTrue(all(len(i) > 10 for i in os.listdir(self.output_dataset_dir)),
                        "Hacky check for if pickled jobs didn't get deleted")
        reencoding_dir = os.path.join(self.output_dataset_dir, 're-encoding')
        os.mkdir(reencoding_dir)
        data_encoder = VarNamingGSCVocab.DataEncoder.load(os.path.join(self.output_dataset_dir,
                                                                       'VarNamingGSCVocabDataEncoder.pkl'))
        self.assertCountEqual(data_encoder.all_edge_types,
                              list(all_edge_types) + ['reverse_{}'.format(i) for i in all_edge_types],
                              "DataEncoder found weird edge types")
        VarNamingGSCVocab.preprocess_task(task=task,
                                          output_dir=reencoding_dir,
                                          n_jobs=30,
                                          data_encoder=data_encoder)
        orig_datapoints = []
        for file in os.listdir(self.output_dataset_dir):
            if file not in ['VarNamingGSCVocabDataEncoder.pkl', 'VarNamingTask.pkl', 're-encoding']:
                with open(os.path.join(self.output_dataset_dir, file), 'rb') as f:
                    dp = pickle.load(f)
                    self.assertNotIn('SUBTOKEN_USE', dp.edges.keys())
                    self.assertNotIn('reverse_SUBTOKEN_USE', dp.edges.keys())
                    self.assertCountEqual(dp.edges.keys(),
                                          list(all_edge_types) + ['reverse_{}'.format(i) for i in all_edge_types],
                                          'We lost some edge types')
                    orig_datapoints.append(
                        (dp.real_variable_name, dp.origin_file, dp.encoder_hash, dp.edges.keys()))

        reencoded_datapoints = []
        for file in os.listdir(reencoding_dir):
            with open(os.path.join(reencoding_dir, file), 'rb') as f:
                dp = pickle.load(f)
                self.assertNotIn('SUBTOKEN_USE', dp.edges.keys())
                self.assertNotIn('reverse_SUBTOKEN_USE', dp.edges.keys())
                reencoded_datapoints.append(
                    (dp.real_variable_name, dp.origin_file, dp.encoder_hash, dp.edges.keys()))
        self.assertEqual(len(orig_datapoints), len(reencoded_datapoints))
        self.assertCountEqual(orig_datapoints, reencoded_datapoints)
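The test compares first-pass and re-encoded datapoints via four-field tuples that embed dp.edges.keys(), a live dict view; since dict views are unhashable, assertCountEqual falls back to slow pairwise matching. A sketch of a hashable summary instead (helper name summarize_datapoint is hypothetical):

    def summarize_datapoint(dp):
        # Hashable, order-independent summary of a datapoint; sorting the
        # edge types avoids comparing raw dict_keys views directly.
        return (dp.real_variable_name, dp.origin_file, dp.encoder_hash,
                tuple(sorted(dp.edges.keys())))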
Code example #4
    def setUp(self):
        super().setUp()
        self.task = VarNamingTask.from_gml_files(self.test_gml_files)
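This last fragment relies on a parent fixture: super().setUp() is expected to populate self.test_gml_files before the task is rebuilt. A minimal sketch of that inheritance pattern, with hypothetical class names:

    import unittest

    class BaseCase(unittest.TestCase):
        def setUp(self):
            # The real base fixture would gather the .gml files, as in the
            # earlier setUp examples; left minimal here.
            self.test_gml_files = []

    class ChildCase(BaseCase):
        def setUp(self):
            super().setUp()  # inherit the file list prepared by the parent
            # ...then rebuild task-specific state, as in the fragment above.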