Example #1
    def setUp(self):
        self.train_gml_dir = os.path.join(test_s3shared_path, 'test_dataset',
                                          'seen_repos', 'train_graphs')
        self.test_gml_dir = os.path.join(test_s3shared_path, 'test_dataset',
                                         'seen_repos', 'test_graphs')
        self.train_output_dataset_dir = os.path.join(test_s3shared_path,
                                                     'train_model_dataset')
        os.makedirs(self.train_output_dataset_dir, exist_ok=True)
        self.test_output_dataset_dir = os.path.join(test_s3shared_path,
                                                    'test_model_dataset')
        os.makedirs(self.test_output_dataset_dir, exist_ok=True)
        self.train_log_dir = os.path.join(test_s3shared_path, 'train_logs',
                                          get_time())
        self.test_log_dir = os.path.join(test_s3shared_path, 'test_logs',
                                         get_time())
        self.train_gml_files = []
        for file in os.listdir(self.train_gml_dir)[:10]:
            if file[-4:] == '.gml':
                self.train_gml_files.append(
                    os.path.abspath(os.path.join(self.train_gml_dir, file)))
        self.test_gml_files = []
        for file in os.listdir(self.test_gml_dir)[:10]:
            if file[-4:] == '.gml':
                self.test_gml_files.append(
                    os.path.abspath(os.path.join(self.test_gml_dir, file)))

        train_task = FITBTask.from_gml_files(self.train_gml_files)
        self.train_task_filepath = os.path.join(self.train_gml_dir,
                                                'TrainFITBTask.pkl')
        train_task.save(self.train_task_filepath)
        test_task = FITBTask.from_gml_files(self.test_gml_files)
        self.test_task_filepath = os.path.join(self.test_gml_dir,
                                               'TestFITBTask.pkl')
        test_task.save(self.test_task_filepath)
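
The train and test file lists above are built with two identical loops. A small helper like the hypothetical collect_gml_files below (a sketch, not part of the original suite) captures that pattern, including the [:10] cap on directory entries:

import os

def collect_gml_files(gml_dir, limit=None):
    """Hypothetical helper: absolute paths of the .gml files in gml_dir.

    Mirrors the loops in setUp above; `limit` reproduces the [:10] slice.
    Note that os.listdir returns entries in arbitrary order, and the
    original slices before filtering, so the cap applies to directory
    entries, not to .gml files specifically.
    """
    entries = os.listdir(gml_dir)
    if limit is not None:
        entries = entries[:limit]
    return [os.path.abspath(os.path.join(gml_dir, name))
            for name in entries if name.endswith('.gml')]

With that helper, the two loops reduce to collect_gml_files(self.train_gml_dir, limit=10) and collect_gml_files(self.test_gml_dir, limit=10).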
Example #2
    def test_preprocess_task_for_model_with_FixedVocab(self):
        task = FITBTask.from_gml_files(self.test_gml_files)
        task_filepath = os.path.join(self.output_dataset_dir, 'FITBTask.pkl')
        task.save(task_filepath)
        # First pass: preprocess the task with a freshly fit data encoder.
        preprocess_task_for_model(
            234,
            'FITBTask',
            task_filepath,
            'FITBFixedVocabGGNN',
            dataset_output_dir=self.output_dataset_dir,
            n_jobs=15,
            excluded_edge_types=frozenset(),
            data_encoder='new',
            data_encoder_kwargs=dict(),
            instance_to_datapoints_kwargs=dict(max_nodes_per_graph=100))
        self.assertNotIn(
            'jobs.txt', os.listdir(self.output_dataset_dir),
            "The jobs.txt file from process_graph_to_datapoints_with_xargs didn't get deleted"
        )
        self.assertTrue(
            all(len(i) > 10 for i in os.listdir(self.output_dataset_dir)),
            "Hacky check for if pickled jobs didn't get deleted")
        # Second pass: re-encode the same task with the saved encoder.
        reencoding_dir = os.path.join(self.output_dataset_dir, 're-encoding')
        os.mkdir(reencoding_dir)
        preprocess_task_for_model(
            234,
            'FITBTask',
            task_filepath,
            'FITBFixedVocabGGNN',
            dataset_output_dir=reencoding_dir,
            n_jobs=15,
            excluded_edge_types=frozenset(),
            data_encoder=os.path.join(self.output_dataset_dir,
                                      'FITBFixedVocabDataEncoder.pkl'))
        orig_datapoints = []
        for file in os.listdir(self.output_dataset_dir):
            if file not in [
                    'FITBFixedVocabDataEncoder.pkl', 'FITBTask.pkl',
                    're-encoding'
            ]:
                with open(os.path.join(self.output_dataset_dir, file),
                          'rb') as f:
                    dp = pickle.load(f)
                    orig_datapoints.append(
                        (dp.node_types, dp.node_names, dp.label,
                         dp.origin_file, dp.encoder_hash, dp.edges.keys()))
        reencoded_datapoints = []
        for file in os.listdir(reencoding_dir):
            with open(os.path.join(reencoding_dir, file), 'rb') as f:
                dp = pickle.load(f)
                reencoded_datapoints.append(
                    (dp.node_types, dp.node_names, dp.label, dp.origin_file,
                     dp.encoder_hash, dp.edges.keys()))
        self.assertNotIn(
            'jobs.txt', os.listdir(reencoding_dir),
            "The jobs.txt file from process_graph_to_datapoints_with_xargs didn't get deleted"
        )
        self.assertTrue(all(len(i) > 10 for i in os.listdir(reencoding_dir)),
                        "Hacky check for if pickled jobs didn't get deleted")
        # Re-encoding must reproduce exactly the same datapoints.
        self.assertCountEqual(orig_datapoints, reencoded_datapoints)
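
The two datapoint-collection loops in the test above differ only in the directory and the skip list. A shared helper (hypothetical, shown here only as a sketch of the pattern; the dp attributes are the ones the original test reads) could factor them out:

import os
import pickle

def load_datapoint_summaries(dataset_dir, skip=()):
    # Hypothetical helper mirroring the loops above: unpickle every
    # datapoint file and keep the fields the test compares.
    summaries = []
    for name in os.listdir(dataset_dir):
        if name in skip:
            continue
        with open(os.path.join(dataset_dir, name), 'rb') as f:
            dp = pickle.load(f)
        summaries.append((dp.node_types, dp.node_names, dp.label,
                          dp.origin_file, dp.encoder_hash, dp.edges.keys()))
    return summaries

The final assertion would then compare load_datapoint_summaries(self.output_dataset_dir, skip=('FITBFixedVocabDataEncoder.pkl', 'FITBTask.pkl', 're-encoding')) against load_datapoint_summaries(reencoding_dir).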
Example #3
    def test_preprocess_task_for_model(self):
        task = FITBTask.from_gml_files(self.test_gml_files)
        task_filepath = os.path.join(self.output_dataset_dir, 'FITBTask.pkl')
        task.save(task_filepath)
        FITBGSCVocab.preprocess_task(
            task=task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder='new',
            data_encoder_kwargs=dict(
                max_name_encoding_length=self.max_name_encoding_length),
            instance_to_datapoints_kwargs=dict(max_nodes_per_graph=100))
        self.assertNotIn(
            'jobs.txt', os.listdir(self.output_dataset_dir),
            "The jobs.txt file from process_graph_to_datapoints_with_xargs didn't get deleted"
        )
        self.assertTrue(
            all(len(i) > 10 for i in os.listdir(self.output_dataset_dir)),
            "Hacky check for if pickled jobs didn't get deleted")
        reencoding_dir = os.path.join(self.output_dataset_dir, 're-encoding')
        os.mkdir(reencoding_dir)
        data_encoder = FITBGSCVocab.DataEncoder.load(
            os.path.join(self.output_dataset_dir,
                         'FITBGSCVocabDataEncoder.pkl'))
        self.assertCountEqual(
            data_encoder.all_edge_types,
            list(all_edge_types) +
            ['reverse_{}'.format(i) for i in all_edge_types] +
            ['SUBTOKEN_USE', 'reverse_SUBTOKEN_USE'],
            "DataEncoder found weird edge types")
        FITBGSCVocab.preprocess_task(task=task,
                                     output_dir=reencoding_dir,
                                     n_jobs=30,
                                     data_encoder=data_encoder)
        orig_datapoints = []
        for file in os.listdir(self.output_dataset_dir):
            if file not in [
                    'FITBGSCVocabDataEncoder.pkl', 'FITBTask.pkl',
                    're-encoding'
            ]:
                with open(os.path.join(self.output_dataset_dir, file),
                          'rb') as f:
                    dp = pickle.load(f)
                    self.assertCountEqual(
                        dp.edges.keys(),
                        list(all_edge_types) +
                        ['reverse_{}'.format(i) for i in all_edge_types] +
                        ['SUBTOKEN_USE', 'reverse_SUBTOKEN_USE'],
                        'We lost some edge types')
                    orig_datapoints.append(
                        (dp.origin_file, dp.encoder_hash, dp.edges.keys()))

        reencoded_datapoints = []
        for file in os.listdir(reencoding_dir):
            with open(os.path.join(reencoding_dir, file), 'rb') as f:
                dp = pickle.load(f)
                reencoded_datapoints.append(
                    (dp.origin_file, dp.encoder_hash, dp.edges.keys()))
        self.assertCountEqual(orig_datapoints, reencoded_datapoints)
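
The expected edge-type set appears twice in the test above, once for the encoder and once per datapoint. Assuming all_edge_types is the iterable the original module imports, it can be computed once, as in this sketch:

# Sketch only: all_edge_types is assumed to come from the module under test.
expected_edge_types = (list(all_edge_types)
                       + ['reverse_{}'.format(t) for t in all_edge_types]
                       + ['SUBTOKEN_USE', 'reverse_SUBTOKEN_USE'])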
Example #4
    def setUp(self):
        self.gml_dir = os.path.join(test_s3shared_path, 'test_dataset', 'repositories')
        self.output_dataset_dir = os.path.join(test_s3shared_path, 'FITB_Closed_Vocab_dataset')
        os.makedirs(self.output_dataset_dir, exist_ok=True)
        self.test_gml_files = []
        for file in os.listdir(self.gml_dir):
            if file[-4:] == '.gml':
                self.test_gml_files.append(os.path.abspath(os.path.join(self.gml_dir, file)))
        self.task = FITBTask.from_gml_files(self.test_gml_files)
Example #5
    def setUp(self):
        self.gml_dir = os.path.join(test_s3shared_path, 'test_dataset', 'repositories')
        self.output_dataset_dir = os.path.join(test_s3shared_path, 'FITB_NameGraphVocab_dataset')
        self.test_gml_files = []
        for file in os.listdir(self.gml_dir):
            if file[-4:] == '.gml':
                self.test_gml_files.append(os.path.abspath(os.path.join(self.gml_dir, file)))
        self.task = FITBTask.from_gml_files(self.test_gml_files)
        self.max_name_encoding_length = 10
Example #6
    def setUp(self):
        super().setUp()
        self.task = FITBTask.from_gml_files(self.test_gml_files)