def test_batchify_and_unbatchify_are_inverses(self):
    # Preprocess the task so there are datapoints on disk to batchify
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder='new',
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length),
                                     instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
    with open(os.path.join(self.output_dataset_dir,
                           '{}.pkl'.format(VarNamingCharCNN.DataEncoder.__name__)), 'rb') as f:
        de = pickle.load(f)
    model = VarNamingCharCNNGGNN(data_encoder=de,
                                 hidden_size=17,
                                 type_emb_size=5,
                                 name_emb_size=7,
                                 n_msg_pass_iters=1,
                                 max_name_length=8)
    model.collect_params().initialize('Xavier', ctx=mx.cpu())

    datapoints = [os.path.join(self.output_dataset_dir, i) for i in os.listdir(self.output_dataset_dir)
                  if 'Encoder.pkl' not in i]
    batch_size = 64
    for b in tqdm(range(int(math.ceil(len(datapoints) / batch_size)))):
        batch_dp_paths = datapoints[batch_size * b: batch_size * (b + 1)]
        batch_dps = [de.load_datapoint(path) for path in batch_dp_paths]

        # Run the model on the batchified datapoints and check the output shape
        batchified = model.batchify(batch_dp_paths, ctx=mx.cpu())
        model_output = model(batchified.data)
        self.assertEqual(len(model_output.shape), 3, "model_output is the wrong size")
        self.assertEqual(model_output.shape[0], len(batchified.data.batch_sizes),
                         "model_output has wrong batch dimension")
        self.assertEqual(model_output.shape[1], model.max_name_length,
                         "model_output is outputting wrong length names")
        self.assertEqual(model_output.shape[2], len(de.all_node_name_subtokens),
                         "model_output's output dimension is off")

        # Unbatchify and confirm nothing was lost or reordered in the round trip
        unbatchified = model.unbatchify(batchified, model_output)
        self.assertEqual(len(batch_dps), len(unbatchified), "We lost some datapoints somewhere")
        self.assertEqual(sum(len(dp.node_names) for dp in batch_dps),
                         sum(batchified.data.batch_sizes).asscalar())
        self.assertEqual(sum(len(dp.node_types) for dp in batch_dps),
                         sum(batchified.data.batch_sizes).asscalar())
        self.assertEqual(len(batchified.data.target_locations),
                         sum(dp.node_names.count('__NAME_ME!__') for dp in batch_dps),
                         "Some target location went missing")
        for adj_mat in batchified.data.edges.values():
            self.assertEqual(adj_mat.shape,
                             (sum(len(dp.node_names) for dp in batch_dps),
                              sum(len(dp.node_names) for dp in batch_dps)),
                             "Batchified adjacency matrix is wrong size")
        for i, (dp, (prediction, label)) in enumerate(zip(batch_dps, unbatchified)):
            self.assertEqual(len(dp.node_types), len(dp.node_names),
                             "node_types and node_names arrays are different lengths")
            self.assertEqual(len(dp.node_types), batchified.data.batch_sizes[i],
                             "batch_sizes doesn't match datapoint's array size")
            self.assertEqual(de.name_to_subtokens(dp.real_variable_name), label,
                             "Something got labeled wrong")
def test_preprocess_task_for_model(self):
    task = VarNamingTask.from_gml_files(self.test_gml_files)
    task_filepath = os.path.join(self.output_dataset_dir, 'VarNamingTask.pkl')
    task.save(task_filepath)
    VarNamingCharCNN.preprocess_task(task=task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder='new',
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length),
                                     instance_to_datapoints_kwargs=dict(max_nodes_per_graph=100))
    self.assertNotIn('jobs.txt', os.listdir(self.output_dataset_dir),
                     "The jobs.txt file from process_graph_to_datapoints_with_xargs didn't get deleted")
    self.assertTrue(all(len(i) > 10 for i in os.listdir(self.output_dataset_dir)),
                    "Hacky check for whether pickled jobs didn't get deleted")

    # Re-encode the same task with the DataEncoder produced by the first pass and
    # confirm the resulting datapoints are identical
    reencoding_dir = os.path.join(self.output_dataset_dir, 're-encoding')
    os.mkdir(reencoding_dir)
    data_encoder = VarNamingCharCNN.DataEncoder.load(
        os.path.join(self.output_dataset_dir, 'VarNamingCharCNNDataEncoder.pkl'))
    self.assertCountEqual(data_encoder.all_edge_types,
                          list(all_edge_types) + ['reverse_{}'.format(i) for i in all_edge_types],
                          "DataEncoder found weird edge types")
    VarNamingCharCNN.preprocess_task(task=task,
                                     output_dir=reencoding_dir,
                                     n_jobs=30,
                                     data_encoder=data_encoder)

    orig_datapoints = []
    for file in os.listdir(self.output_dataset_dir):
        if file not in ['VarNamingCharCNNDataEncoder.pkl', 'VarNamingTask.pkl', 're-encoding']:
            with open(os.path.join(self.output_dataset_dir, file), 'rb') as f:
                dp = pickle.load(f)
            self.assertCountEqual(dp.edges.keys(),
                                  list(all_edge_types) + ['reverse_{}'.format(i) for i in all_edge_types],
                                  'We lost some edge types')
            orig_datapoints.append((dp.node_types, dp.node_names, dp.real_variable_name, dp.label,
                                    dp.origin_file, dp.encoder_hash, dp.edges.keys()))
    reencoded_datapoints = []
    for file in os.listdir(reencoding_dir):
        with open(os.path.join(reencoding_dir, file), 'rb') as f:
            dp = pickle.load(f)
        reencoded_datapoints.append((dp.node_types, dp.node_names, dp.real_variable_name, dp.label,
                                     dp.origin_file, dp.encoder_hash, dp.edges.keys()))
    self.assertCountEqual(orig_datapoints, reencoded_datapoints)
def test_preprocess_task_existing_encoding_basic_functionality_excluded_edges(self):
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder='new',
                                     excluded_edge_types=syntax_only_excluded_edge_types,
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length),
                                     instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
    de = VarNamingCharCNNDataEncoder.load(
        os.path.join(self.output_dataset_dir, '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__)))
    self.assertEqual(de.excluded_edge_types, syntax_only_excluded_edge_types)
    self.assertCountEqual(de.all_edge_types,
                          list(syntax_only_edge_types) + ['reverse_' + i for i in syntax_only_edge_types])

    # Every datapoint on disk should only contain syntax-only edge types (and their reverses)
    datapoints = [os.path.join(self.output_dataset_dir, i) for i in os.listdir(self.output_dataset_dir)
                  if i != 'VarNamingCharCNNDataEncoder.pkl']
    for dp in datapoints:
        datapoint = de.load_datapoint(dp)
        for e in datapoint.edges.keys():
            if e.startswith('reverse_'):
                self.assertIn(e[len('reverse_'):], syntax_only_edge_types)
            else:
                self.assertIn(e, syntax_only_edge_types)

    # Re-running with the existing DataEncoder should work
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder=de,
                                     excluded_edge_types=syntax_only_excluded_edge_types,
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length))
    # But passing a DataEncoder of the wrong type should fail the type check
    with self.assertRaises(AssertionError):
        de = BaseDataEncoder(dict(), frozenset())
        VarNamingCharCNN.preprocess_task(self.task,
                                         output_dir=self.output_dataset_dir,
                                         n_jobs=30,
                                         data_encoder=de,
                                         excluded_edge_types=syntax_only_excluded_edge_types,
                                         data_encoder_kwargs=dict(
                                             max_name_encoding_length=self.max_name_encoding_length))
def test_preprocess_task_existing_encoding_basic_functionality(self):
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder='new',
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length),
                                     instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))
    de = VarNamingCharCNNDataEncoder.load(
        os.path.join(self.output_dataset_dir, '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__)))

    # Re-running with the DataEncoder produced above should succeed
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder=de,
                                     data_encoder_kwargs=dict(
                                         excluded_edge_types=syntax_only_excluded_edge_types,
                                         max_name_encoding_length=self.max_name_encoding_length))
    # A DataEncoder of the wrong type should fail the type check
    with self.assertRaises(AssertionError):
        de = BaseDataEncoder(dict(), frozenset())
        VarNamingCharCNN.preprocess_task(self.task,
                                         output_dir=self.output_dataset_dir,
                                         n_jobs=30,
                                         data_encoder=de,
                                         data_encoder_kwargs=dict(
                                             excluded_edge_types=syntax_only_excluded_edge_types,
                                             max_name_encoding_length=self.max_name_encoding_length))
def test_preprocess_task_type_check_basic_functionality(self):
    # Passing something that isn't a VarNamingTask instance should fail the type check
    task = Task
    with self.assertRaises(AssertionError):
        VarNamingCharCNN.preprocess_task(task)