示例#1
0
 def test_batchify_and_unbatchify_are_inverses(self):
     """Batchify datapoints, run the model, then unbatchify, verifying that no
     datapoints, nodes, or target locations are lost along the round trip."""
     # Preprocess with a fresh encoder so datapoint files exist on disk.
     VarNamingCharCNN.preprocess_task(self.task,
                                      output_dir=self.output_dataset_dir,
                                      n_jobs=30,
                                      data_encoder='new',
                                      data_encoder_kwargs=dict(
                                          max_name_encoding_length=self.max_name_encoding_length),
                                      instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     encoder_path = os.path.join(self.output_dataset_dir,
                                 '{}.pkl'.format(VarNamingCharCNN.DataEncoder.__name__))
     with open(encoder_path, 'rb') as f:
         de = pickle.load(f)
     # Tiny model: only shape/consistency invariants matter here, not accuracy.
     model = VarNamingCharCNNGGNN(data_encoder=de,
                                  hidden_size=17,
                                  type_emb_size=5,
                                  name_emb_size=7,
                                  n_msg_pass_iters=1,
                                  max_name_length=8)
     model.collect_params().initialize('Xavier', ctx=mx.cpu())
     datapoints = [os.path.join(self.output_dataset_dir, fname)
                   for fname in os.listdir(self.output_dataset_dir)
                   if 'Encoder.pkl' not in fname]
     batch_size = 64
     n_batches = int(math.ceil(len(datapoints) / batch_size))
     for batch_idx in tqdm(range(n_batches)):
         batch_paths = datapoints[batch_size * batch_idx: batch_size * (batch_idx + 1)]
         batch_dps = [de.load_datapoint(path) for path in batch_paths]
         batchified = model.batchify(batch_paths, ctx=mx.cpu())
         model_output = model(batchified.data)
         # Output must be (batch, name length, subtoken vocabulary).
         self.assertEqual(len(model_output.shape), 3, "model_output is the wrong size")
         self.assertEqual(model_output.shape[0], len(batchified.data.batch_sizes),
                          "model_output has wrong batch dimension")
         self.assertEqual(model_output.shape[1], model.max_name_length,
                          "model_output is outputting wrong length names")
         self.assertEqual(model_output.shape[2], len(de.all_node_name_subtokens),
                          "model_output's output dimension is off")
         unbatchified = model.unbatchify(batchified, model_output)
         self.assertEqual(len(batch_dps), len(unbatchified), "We lost some datapoints somewhere")
         # Total node count reported by the batch must match the datapoints.
         total_nodes = sum(batchified.data.batch_sizes).asscalar()
         self.assertEqual(sum(len(dp.node_names) for dp in batch_dps), total_nodes)
         self.assertEqual(sum(len(dp.node_types) for dp in batch_dps), total_nodes)
         n_targets = sum(dp.node_names.count('__NAME_ME!__') for dp in batch_dps)
         self.assertEqual(len(batchified.data.target_locations), n_targets,
                          "Some target location went missing")
         # Every adjacency matrix covers all nodes in the batch, both dims.
         n_nodes = sum(len(dp.node_names) for dp in batch_dps)
         for adj_mat in batchified.data.edges.values():
             self.assertEqual(adj_mat.shape, (n_nodes, n_nodes),
                              "Batchified adjacency matrix is wrong size")
         for i, (dp, (prediction, label)) in enumerate(zip(batch_dps, unbatchified)):
             self.assertEqual(len(dp.node_types), len(dp.node_names),
                              "node_types and node_names arrays are different lengths")
             self.assertEqual(len(dp.node_types), batchified.data.batch_sizes[i],
                              "batch_sizes doesn't match datapoint's array size")
             self.assertEqual(de.name_to_subtokens(dp.real_variable_name), label, "Something got labeled wrong")
示例#2
0
    def test_preprocess_task_for_model(self):
        """End-to-end check of preprocess_task: fresh encoding, cleanup of
        temporary job files, and re-encoding with the saved encoder producing
        the same set of datapoints."""
        task = VarNamingTask.from_gml_files(self.test_gml_files)
        task_filepath = os.path.join(self.output_dataset_dir, 'VarNamingTask.pkl')
        task.save(task_filepath)
        VarNamingCharCNN.preprocess_task(task=task,
                                         output_dir=self.output_dataset_dir,
                                         n_jobs=30,
                                         data_encoder='new',
                                         data_encoder_kwargs=dict(
                                             max_name_encoding_length=self.max_name_encoding_length),
                                         instance_to_datapoints_kwargs=dict(max_nodes_per_graph=100))
        output_files = os.listdir(self.output_dataset_dir)
        self.assertNotIn('jobs.txt', output_files,
                         "The jobs.txt file from process_graph_to_datapoints_with_xargs didn't get deleted")
        self.assertTrue(all(len(name) > 10 for name in output_files),
                        "Hacky check for if pickled jobs didn't get deleted")
        reencoding_dir = os.path.join(self.output_dataset_dir, 're-encoding')
        os.mkdir(reencoding_dir)
        data_encoder = VarNamingCharCNN.DataEncoder.load(
            os.path.join(self.output_dataset_dir, 'VarNamingCharCNNDataEncoder.pkl'))
        # Every forward edge type should also appear with a 'reverse_' twin.
        expected_edge_types = list(all_edge_types) + ['reverse_{}'.format(i) for i in all_edge_types]
        self.assertCountEqual(data_encoder.all_edge_types, expected_edge_types,
                              "DataEncoder found weird edge types")
        VarNamingCharCNN.preprocess_task(task=task,
                                         output_dir=reencoding_dir,
                                         n_jobs=30,
                                         data_encoder=data_encoder)

        def dp_fingerprint(dp):
            # The datapoint fields that must survive re-encoding unchanged.
            return (dp.node_types, dp.node_names, dp.real_variable_name, dp.label,
                    dp.origin_file, dp.encoder_hash, dp.edges.keys())

        non_datapoint_files = {'VarNamingCharCNNDataEncoder.pkl', 'VarNamingTask.pkl', 're-encoding'}
        orig_datapoints = []
        for fname in os.listdir(self.output_dataset_dir):
            if fname in non_datapoint_files:
                continue
            with open(os.path.join(self.output_dataset_dir, fname), 'rb') as f:
                dp = pickle.load(f)
            self.assertCountEqual(dp.edges.keys(), expected_edge_types, 'We lost some edge types')
            orig_datapoints.append(dp_fingerprint(dp))

        reencoded_datapoints = []
        for fname in os.listdir(reencoding_dir):
            with open(os.path.join(reencoding_dir, fname), 'rb') as f:
                dp = pickle.load(f)
            reencoded_datapoints.append(dp_fingerprint(dp))
        self.assertCountEqual(orig_datapoints, reencoded_datapoints)
示例#3
0
 def test_preprocess_task_existing_encoding_basic_functionality_excluded_edges(self):
     """Preprocessing with excluded edge types keeps only syntax-only edges,
     both with a fresh encoder and with the reloaded one; an encoder of the
     wrong type must be rejected."""
     VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30, data_encoder='new',
                                      excluded_edge_types=syntax_only_excluded_edge_types,
                                      data_encoder_kwargs=dict(
                                          max_name_encoding_length=self.max_name_encoding_length),
                                      instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     encoder_file = os.path.join(self.output_dataset_dir,
                                 '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__))
     de = VarNamingCharCNNDataEncoder.load(encoder_file)
     self.assertEqual(de.excluded_edge_types, syntax_only_excluded_edge_types)
     expected_edge_types = list(syntax_only_edge_types) + ['reverse_' + i for i in syntax_only_edge_types]
     self.assertCountEqual(de.all_edge_types, expected_edge_types)
     datapoint_paths = [os.path.join(self.output_dataset_dir, fname)
                        for fname in os.listdir(self.output_dataset_dir)
                        if fname != 'VarNamingCharCNNDataEncoder.pkl']
     for path in datapoint_paths:
         datapoint = de.load_datapoint(path)
         for edge_type in datapoint.edges.keys():
             # Strip the 'reverse_' prefix (8 chars) before membership check.
             base_type = edge_type[8:] if edge_type.startswith('reverse_') else edge_type
             self.assertIn(base_type, syntax_only_edge_types)
     # Re-running with the already-fitted encoder must also succeed.
     VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30, data_encoder=de,
                                      excluded_edge_types=syntax_only_excluded_edge_types,
                                      data_encoder_kwargs=dict(
                                          max_name_encoding_length=self.max_name_encoding_length))
     # A bare BaseDataEncoder is not valid for this model family.
     with self.assertRaises(AssertionError):
         bad_encoder = BaseDataEncoder(dict(), frozenset())
         VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30,
                                          data_encoder=bad_encoder,
                                          excluded_edge_types=syntax_only_excluded_edge_types,
                                          data_encoder_kwargs=dict(
                                              max_name_encoding_length=self.max_name_encoding_length))
示例#4
0
 def test_preprocess_task_existing_encoding_basic_functionality(self):
     """Preprocessing succeeds with a fresh encoder and again with the saved
     one; passing a base-class encoder instance must raise AssertionError."""
     VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30,
                                      data_encoder='new',
                                      data_encoder_kwargs=dict(
                                          max_name_encoding_length=self.max_name_encoding_length),
                                      instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     encoder_path = os.path.join(self.output_dataset_dir,
                                 '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__))
     de = VarNamingCharCNNDataEncoder.load(encoder_path)
     # Second pass: reuse the encoder that was just fit and saved to disk.
     VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30,
                                      data_encoder=de,
                                      data_encoder_kwargs=dict(
                                          excluded_edge_types=syntax_only_excluded_edge_types,
                                          max_name_encoding_length=self.max_name_encoding_length))
     # A bare BaseDataEncoder is not a valid encoder for this model family.
     with self.assertRaises(AssertionError):
         bad_encoder = BaseDataEncoder(dict(), frozenset())
         VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30,
                                          data_encoder=bad_encoder,
                                          data_encoder_kwargs=dict(
                                              excluded_edge_types=syntax_only_excluded_edge_types,
                                              max_name_encoding_length=self.max_name_encoding_length))
示例#5
0
 def test_preprocess_task_type_check_basic_functionality(self):
     """preprocess_task must reject a task argument that is the Task class
     itself rather than a Task instance."""
     not_a_task_instance = Task
     with self.assertRaises(AssertionError):
         VarNamingCharCNN.preprocess_task(not_a_task_instance)