def test_name_to_1_hot(self, s):
    """Round-trip a name through one-hot encoding/decoding across sizes and flag combinations.

    `s` is supplied by the test harness (presumably a hypothesis @given strategy — TODO confirm).
    Checks that the special/internal indicator rows are set exactly when requested,
    and that the decoded string matches the snake_case-normalized, truncated original.
    """
    for size in [5, 1, 31, 100]:
        for special in [True, False]:
            for internal in [True, False]:
                one_hot = BaseDataEncoder.name_to_1_hot(
                    s, size, special, internal)
                decoded = BaseDataEncoder.name_from_1_hot(one_hot)
                if special:
                    # Row 38 appears to be the "special" indicator row — every column set.
                    self.assertTrue(all(one_hot[38, :] == 1))
                    # Exactly one hot entry per column.
                    self.assertTrue(one_hot.sum() == one_hot.shape[1])
                else:
                    self.assertTrue(all(one_hot[38, :] == 0))
                if internal and not special:
                    # Row 39 appears to be the "internal" indicator row (special takes precedence).
                    self.assertTrue(all(one_hot[39, :] == 1))
                    self.assertTrue(one_hot.sum() == one_hot.shape[1])
                else:
                    self.assertTrue(all(one_hot[39, :] == 0))
                if not special and not internal:
                    # Decoded text must not contain the indicator markers.
                    self.assertNotIn('S', decoded)
                    self.assertNotIn('I', decoded)
                    # Normalize the original the same way the encoder does:
                    # CamelCase -> snake_case (module-level regexes), lowercased, truncated.
                    orig = first_cap_re.sub(r'\1_\2', s)
                    orig = all_cap_re.sub(r'\1_\2', orig).lower()
                    orig = orig[:size]
                    # 'U' marks an unknown character; otherwise positions must match.
                    # NOTE(review): the else-branch tests decoded[i] (i.e. 'U') against
                    # `decoder` — possibly `orig[i] not in decoder` was intended; verify.
                    self.assertTrue(
                        all(decoded[i] == orig[i] if decoded[i] != 'U' else decoded[i] not in decoder for i in
                            range(len(decoded))))
def test_preprocess_task_existing_encoding_basic_functionality(self):
    """preprocess_task accepts a fresh encoder and a reloaded one, but rejects a mismatched encoder type."""
    dp_kwargs = dict(max_nodes_per_graph=20)
    # Pass 1: build a brand-new data encoder from scratch.
    FITBClosedVocab.preprocess_task(
        self.task,
        output_dir=self.output_dataset_dir,
        n_jobs=30,
        data_encoder='new',
        data_encoder_kwargs=dict(),
        instance_to_datapoints_kwargs=dp_kwargs)
    # Pass 2: reload the encoder that pass 1 pickled and reuse it.
    encoder_path = os.path.join(
        self.output_dataset_dir,
        '{}.pkl'.format(FITBClosedVocabDataEncoder.__name__))
    reloaded_encoder = FITBClosedVocabDataEncoder.load(encoder_path)
    FITBClosedVocab.preprocess_task(
        self.task,
        output_dir=self.output_dataset_dir,
        n_jobs=30,
        data_encoder=reloaded_encoder,
        instance_to_datapoints_kwargs=dp_kwargs)
    # An encoder of the wrong class must be rejected with an AssertionError.
    with self.assertRaises(AssertionError):
        wrong_encoder = BaseDataEncoder(dict(), frozenset())
        FITBClosedVocab.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder=wrong_encoder,
            instance_to_datapoints_kwargs=dp_kwargs)
def test_preprocess_task_existing_encoding_basic_functionality(self):
    """preprocess_task accepts a fresh CharCNN encoder and a reloaded one, but rejects a mismatched encoder type."""
    # Pass 1: build a brand-new data encoder from scratch.
    FITBCharCNN.preprocess_task(
        self.task,
        output_dir=self.output_dataset_dir,
        n_jobs=30,
        data_encoder='new',
        data_encoder_kwargs=dict(
            max_name_encoding_length=self.max_name_encoding_length),
        instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
    # Pass 2: reload the encoder that pass 1 pickled and reuse it.
    encoder_path = os.path.join(
        self.output_dataset_dir,
        '{}.pkl'.format(FITBCharCNNDataEncoder.__name__))
    reloaded_encoder = FITBCharCNNDataEncoder.load(encoder_path)
    FITBCharCNN.preprocess_task(
        self.task,
        output_dir=self.output_dataset_dir,
        n_jobs=30,
        data_encoder=reloaded_encoder,
        data_encoder_kwargs=dict(
            excluded_edge_types=syntax_only_excluded_edge_types,
            max_name_encoding_length=self.max_name_encoding_length))
    # An encoder of the wrong class must be rejected with an AssertionError.
    with self.assertRaises(AssertionError):
        wrong_encoder = BaseDataEncoder(dict(), frozenset())
        FITBCharCNN.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder=wrong_encoder,
            data_encoder_kwargs=dict(
                excluded_edge_types=syntax_only_excluded_edge_types,
                max_name_encoding_length=self.max_name_encoding_length))
def test_preprocess_task_existing_encoding_basic_functionality_excluded_edges(self):
    """Preprocessing with excluded edge types keeps only syntax-only edges (and their reverses)
    in the saved encoder and in every emitted datapoint; a mismatched encoder type is rejected.
    """
    FITBCharCNN_dp_kwargs = dict(max_nodes_per_graph=20)
    # Pass 1: build a fresh encoder with the syntax-only exclusions.
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder='new',
                                     excluded_edge_types=syntax_only_excluded_edge_types,
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length),
                                     instance_to_datapoints_kwargs=FITBCharCNN_dp_kwargs)
    # Derive the pickle filename from the class name (was hard-coded as
    # 'VarNamingCharCNNDataEncoder.pkl' in the listing filter below, which would
    # silently break if the class were renamed).
    encoder_filename = '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__)
    de = VarNamingCharCNNDataEncoder.load(
        os.path.join(self.output_dataset_dir, encoder_filename))
    # The encoder must record the exclusions and only know the surviving edge
    # types plus their reverse_* counterparts.
    self.assertEqual(de.excluded_edge_types, syntax_only_excluded_edge_types)
    self.assertCountEqual(de.all_edge_types,
                          list(syntax_only_edge_types) + ['reverse_' + i for i in syntax_only_edge_types])
    # Every emitted datapoint must contain only syntax-only edges (forward or reversed).
    datapoints = [os.path.join(self.output_dataset_dir, i)
                  for i in os.listdir(self.output_dataset_dir)
                  if i != encoder_filename]
    for dp in datapoints:
        datapoint = de.load_datapoint(dp)
        for e in datapoint.edges.keys():
            if e.startswith('reverse_'):
                # Strip the 'reverse_' prefix before checking membership.
                self.assertIn(e[len('reverse_'):], syntax_only_edge_types)
            else:
                self.assertIn(e, syntax_only_edge_types)
    # Pass 2: reuse the reloaded encoder.
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder=de,
                                     excluded_edge_types=syntax_only_excluded_edge_types,
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length))
    # An encoder of the wrong class must be rejected with an AssertionError.
    with self.assertRaises(AssertionError):
        de = BaseDataEncoder(dict(), frozenset())
        VarNamingCharCNN.preprocess_task(self.task,
                                         output_dir=self.output_dataset_dir,
                                         n_jobs=30,
                                         data_encoder=de,
                                         excluded_edge_types=syntax_only_excluded_edge_types,
                                         data_encoder_kwargs=dict(
                                             max_name_encoding_length=self.max_name_encoding_length))
def test_name_to_subtokens_2(self):
    """An all-caps name with no internal case boundaries yields a single lowercased subtoken.

    Uses assertEqual (instead of assertTrue on an == expression) so a failure
    reports the actual and expected lists, not just 'False is not true'.
    """
    st = BaseDataEncoder.name_to_subtokens('AAA')
    self.assertEqual(st, ['aaa'])
def test_name_to_subtokens(self, s):
    """Subtokenizing then rejoining reproduces the lowercased name with underscores removed.

    `s` is supplied by the test harness (presumably a hypothesis @given strategy — TODO confirm).
    Uses assertEqual (instead of assertTrue on an == expression) so a failure
    reports the differing strings, not just 'False is not true'.
    """
    st = BaseDataEncoder.name_to_subtokens(s)
    self.assertEqual(''.join(st), s.lower().replace('_', ''))