示例#1
0
 def test_name_to_1_hot(self, s):
     """Round-trip ``s`` through the one-hot name encoding and check the flag rows.

     For each encoding length in ``[5, 1, 31, 100]`` and each combination of
     the ``special``/``internal`` flags, ``s`` is encoded with
     ``BaseDataEncoder.name_to_1_hot`` and decoded with
     ``BaseDataEncoder.name_from_1_hot``.

     Expectations checked:
       * row 38 is all ones iff ``special`` is set, and is then the only
         content of the encoding (its sum equals the encoding width);
       * row 39 is all ones iff ``internal`` is set and ``special`` is not,
         with the same "only content" check;
       * for a plain name (neither flag), the decoding contains no ``'S'`` or
         ``'I'`` and matches the snake-cased, lower-cased name truncated to
         the encoding length, except at positions decoded as ``'U'``
         (presumably the unknown-character placeholder), where the original
         character must not be in ``decoder``.
     """
     # The expected plain-name decoding only varies with ``size`` through
     # truncation, so normalise the name once.
     normalized = first_cap_re.sub(r'\1_\2', s)
     normalized = all_cap_re.sub(r'\1_\2', normalized).lower()
     for size in [5, 1, 31, 100]:
         expected = normalized[:size]
         for special in [True, False]:
             for internal in [True, False]:
                 one_hot = BaseDataEncoder.name_to_1_hot(
                     s, size, special, internal)
                 decoded = BaseDataEncoder.name_from_1_hot(one_hot)
                 if special:
                     self.assertTrue(all(one_hot[38, :] == 1))
                     # With row 38 all ones, a total equal to the width means
                     # nothing else is set.
                     self.assertEqual(one_hot.sum(), one_hot.shape[1])
                 else:
                     self.assertTrue(all(one_hot[38, :] == 0))
                 if internal and not special:
                     self.assertTrue(all(one_hot[39, :] == 1))
                     # Likewise, row 39 must be the only content.
                     self.assertEqual(one_hot.sum(), one_hot.shape[1])
                 else:
                     self.assertTrue(all(one_hot[39, :] == 0))
                 if not special and not internal:
                     self.assertNotIn('S', decoded)
                     self.assertNotIn('I', decoded)
                     for i, char in enumerate(decoded):
                         if char == 'U':
                             # 'U' is acceptable only where the ORIGINAL
                             # character is absent from ``decoder``. Testing
                             # ``decoded[i]`` here (always 'U') would never
                             # look at the name at all.
                             # NOTE(review): assumes ``decoder`` supports
                             # membership by character — confirm.
                             self.assertNotIn(expected[i], decoder)
                         else:
                             self.assertEqual(char, expected[i])
 def test_preprocess_task_existing_encoding_basic_functionality(self):
     """Exercise ``FITBClosedVocab.preprocess_task`` with new and existing encoders.

     Steps:
       1. Preprocess ``self.task`` with ``data_encoder='new'``.
       2. Load ``FITBClosedVocabDataEncoder`` from the pickle named after that
          class in ``self.output_dataset_dir``.
       3. Preprocess again with that loaded encoder.
       4. Check that passing a bare ``BaseDataEncoder`` instead raises
          ``AssertionError``.
     """
     FITBClosedVocab.preprocess_task(
         self.task,
         output_dir=self.output_dataset_dir,
         n_jobs=30,
         data_encoder='new',
         data_encoder_kwargs=dict(),
         instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     de = FITBClosedVocabDataEncoder.load(
         os.path.join(self.output_dataset_dir,
                      '{}.pkl'.format(FITBClosedVocabDataEncoder.__name__)))
     FITBClosedVocab.preprocess_task(
         self.task,
         output_dir=self.output_dataset_dir,
         n_jobs=30,
         data_encoder=de,
         instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     # Build the mismatched encoder outside ``assertRaises`` so that only
     # ``preprocess_task`` can satisfy the expected AssertionError.
     wrong_de = BaseDataEncoder(dict(), frozenset())
     with self.assertRaises(AssertionError):
         FITBClosedVocab.preprocess_task(
             self.task,
             output_dir=self.output_dataset_dir,
             n_jobs=30,
             data_encoder=wrong_de,
             instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
示例#3
0
 def test_preprocess_task_existing_encoding_basic_functionality(self):
     """Exercise ``FITBCharCNN.preprocess_task`` with new and existing encoders.

     Steps:
       1. Preprocess ``self.task`` with ``data_encoder='new'`` and
          ``self.max_name_encoding_length``.
       2. Load ``FITBCharCNNDataEncoder`` from the pickle named after that
          class in ``self.output_dataset_dir``.
       3. Preprocess again with that loaded encoder.
       4. Check that passing a bare ``BaseDataEncoder`` instead raises
          ``AssertionError``.
     """
     FITBCharCNN.preprocess_task(
         self.task,
         output_dir=self.output_dataset_dir,
         n_jobs=30,
         data_encoder='new',
         data_encoder_kwargs=dict(
             max_name_encoding_length=self.max_name_encoding_length),
         instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     de = FITBCharCNNDataEncoder.load(
         os.path.join(self.output_dataset_dir,
                      '{}.pkl'.format(FITBCharCNNDataEncoder.__name__)))
     FITBCharCNN.preprocess_task(
         self.task,
         output_dir=self.output_dataset_dir,
         n_jobs=30,
         data_encoder=de,
         data_encoder_kwargs=dict(
             excluded_edge_types=syntax_only_excluded_edge_types,
             max_name_encoding_length=self.max_name_encoding_length))
     # Build the mismatched encoder outside ``assertRaises`` so that only
     # ``preprocess_task`` can satisfy the expected AssertionError.
     wrong_de = BaseDataEncoder(dict(), frozenset())
     with self.assertRaises(AssertionError):
         FITBCharCNN.preprocess_task(
             self.task,
             output_dir=self.output_dataset_dir,
             n_jobs=30,
             data_encoder=wrong_de,
             data_encoder_kwargs=dict(
                 excluded_edge_types=syntax_only_excluded_edge_types,
                 max_name_encoding_length=self.max_name_encoding_length))
示例#4
0
 def test_preprocess_task_existing_encoding_basic_functionality_excluded_edges(self):
     """Check that ``excluded_edge_types`` is honoured by ``VarNamingCharCNN.preprocess_task``.

     Steps:
       1. Preprocess with a new encoder and ``syntax_only_excluded_edge_types``.
       2. Load the saved ``VarNamingCharCNNDataEncoder`` and verify its
          ``excluded_edge_types`` and ``all_edge_types``. The latter must be
          the syntax-only edge types plus their ``'reverse_'`` counterparts.
       3. Verify every saved datapoint only contains those edge types.
       4. Reuse the loaded encoder, then check that passing a bare
          ``BaseDataEncoder`` instead raises ``AssertionError``.
     """
     VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30, data_encoder='new',
                                      excluded_edge_types=syntax_only_excluded_edge_types,
                                      data_encoder_kwargs=dict(
                                          max_name_encoding_length=self.max_name_encoding_length),
                                      instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
     # Single source of truth for the encoder file, used both to load it and
     # to exclude it from the datapoint listing below.
     encoder_filename = '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__)
     de = VarNamingCharCNNDataEncoder.load(os.path.join(self.output_dataset_dir, encoder_filename))
     self.assertEqual(de.excluded_edge_types, syntax_only_excluded_edge_types)
     reverse_prefix = 'reverse_'
     self.assertCountEqual(de.all_edge_types,
                           list(syntax_only_edge_types) + [reverse_prefix + i for i in syntax_only_edge_types])
     # Every other file in the output directory is treated as a datapoint.
     datapoints = [os.path.join(self.output_dataset_dir, i) for i in os.listdir(self.output_dataset_dir)
                   if i != encoder_filename]
     for dp in datapoints:
         datapoint = de.load_datapoint(dp)
         for e in datapoint.edges.keys():
             # A reverse edge must mirror an allowed forward edge type.
             base_type = e[len(reverse_prefix):] if e.startswith(reverse_prefix) else e
             self.assertIn(base_type, syntax_only_edge_types)
     VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30, data_encoder=de,
                                      excluded_edge_types=syntax_only_excluded_edge_types,
                                      data_encoder_kwargs=dict(
                                          max_name_encoding_length=self.max_name_encoding_length))
     # Build the mismatched encoder outside ``assertRaises`` so that only
     # ``preprocess_task`` can satisfy the expected AssertionError.
     wrong_de = BaseDataEncoder(dict(), frozenset())
     with self.assertRaises(AssertionError):
         VarNamingCharCNN.preprocess_task(self.task, output_dir=self.output_dataset_dir, n_jobs=30,
                                          data_encoder=wrong_de,
                                          excluded_edge_types=syntax_only_excluded_edge_types,
                                          data_encoder_kwargs=dict(
                                              max_name_encoding_length=self.max_name_encoding_length))
示例#5
0
 def test_name_to_subtokens_2(self):
     """An all-caps name like ``'AAA'`` yields the single subtoken ``'aaa'``."""
     # assertEqual reports both values on failure, unlike assertTrue(a == b).
     self.assertEqual(BaseDataEncoder.name_to_subtokens('AAA'), ['aaa'])
示例#6
0
 def test_name_to_subtokens(self, s):
     """Subtokens of ``s`` re-join to ``s`` lower-cased with underscores removed."""
     st = BaseDataEncoder.name_to_subtokens(s)
     # assertEqual reports both strings on failure, unlike assertTrue(a == b).
     self.assertEqual(''.join(st), s.lower().replace('_', ''))