def test_preprocess_task_existing_encoding_basic_functionality_excluded_edges(self):
    """Preprocess with syntax-only edge exclusions, then reuse the encoder.

    Steps:
      1. Run preprocessing with a 'new' data encoder and the syntax-only
         excluded edge types.
      2. Load the encoder pickle from the output directory and check its
         excluded/all edge types.
      3. Check every edge type of every written datapoint (ignoring a
         leading 'reverse_') is one of the syntax-only edge types.
      4. Re-run preprocessing passing the loaded encoder instance.
      5. Expect an AssertionError when a BaseDataEncoder is used instead.
    """
    out_dir = self.output_dataset_dir

    def run_preprocess(data_encoder, **extra_kwargs):
        # Every call gets its own freshly built data_encoder_kwargs dict.
        VarNamingCharCNN.preprocess_task(
            self.task,
            output_dir=out_dir,
            n_jobs=30,
            data_encoder=data_encoder,
            excluded_edge_types=syntax_only_excluded_edge_types,
            data_encoder_kwargs=dict(
                max_name_encoding_length=self.max_name_encoding_length),
            **extra_kwargs)

    # Step 1: initial run; only this run limits the nodes per graph.
    run_preprocess(
        'new',
        instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))

    # Step 2: load the encoder file named after the encoder class.
    encoder_path = os.path.join(
        out_dir, '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__))
    de = VarNamingCharCNNDataEncoder.load(encoder_path)
    self.assertEqual(de.excluded_edge_types, syntax_only_excluded_edge_types)

    expected_edge_types = list(syntax_only_edge_types)
    expected_edge_types += ['reverse_' + name for name in syntax_only_edge_types]
    self.assertCountEqual(de.all_edge_types, expected_edge_types)

    # Step 3: everything in the output directory except the encoder pickle
    # is treated as a datapoint file.
    datapoint_paths = [
        os.path.join(out_dir, fname)
        for fname in os.listdir(out_dir)
        if fname != 'VarNamingCharCNNDataEncoder.pkl'
    ]
    for dp_path in datapoint_paths:
        loaded = de.load_datapoint(dp_path)
        for edge_type in loaded.edges.keys():
            if edge_type.startswith('reverse_'):
                base_type = edge_type[len('reverse_'):]
            else:
                base_type = edge_type
            self.assertIn(base_type, syntax_only_edge_types)

    # Step 4: an already-built encoder instance is accepted.
    run_preprocess(de)

    # Step 5: a plain BaseDataEncoder must be rejected.
    with self.assertRaises(AssertionError):
        wrong_encoder = BaseDataEncoder(dict(), frozenset())
        run_preprocess(wrong_encoder)
def test_preprocess_task_existing_encoding_basic_functionality(self):
    """Preprocess with a new encoder, then reuse the loaded encoder.

    Steps:
      1. Run preprocessing with a 'new' data encoder.
      2. Load the encoder pickle from the output directory.
      3. Re-run preprocessing passing the loaded encoder instance.
      4. Expect an AssertionError when a BaseDataEncoder is used instead.
    """
    out_dir = self.output_dataset_dir

    def rerun_with(data_encoder):
        # NOTE(review): excluded_edge_types is passed inside
        # data_encoder_kwargs here, and only on the re-runs, not on the
        # initial 'new' run -- confirm this is intended.
        VarNamingCharCNN.preprocess_task(
            self.task,
            output_dir=out_dir,
            n_jobs=30,
            data_encoder=data_encoder,
            data_encoder_kwargs=dict(
                excluded_edge_types=syntax_only_excluded_edge_types,
                max_name_encoding_length=self.max_name_encoding_length))

    # Step 1: initial run builds a new encoder; only this run limits the
    # nodes per graph.
    VarNamingCharCNN.preprocess_task(
        self.task,
        output_dir=out_dir,
        n_jobs=30,
        data_encoder='new',
        data_encoder_kwargs=dict(
            max_name_encoding_length=self.max_name_encoding_length),
        instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))

    # Step 2: load the encoder file named after the encoder class.
    encoder_path = os.path.join(
        out_dir, '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__))
    de = VarNamingCharCNNDataEncoder.load(encoder_path)

    # Step 3: an already-built encoder instance is accepted.
    rerun_with(de)

    # Step 4: a plain BaseDataEncoder must be rejected.
    with self.assertRaises(AssertionError):
        wrong_encoder = BaseDataEncoder(dict(), frozenset())
        rerun_with(wrong_encoder)