def test_encode(self):
    """Round-trip check: encoding a datapoint maps every field through the
    encoder's vocabularies without losing or corrupting anything."""
    encoder = VarNamingCharCNNDataEncoder(self.task.graphs_and_instances,
                                          excluded_edge_types=frozenset(),
                                          instance_to_datapoints_kwargs=dict(),
                                          max_name_encoding_length=self.max_name_encoding_length)
    for graph, instances in self.task.graphs_and_instances:
        VarNamingCharCNN.fix_up_edges(graph, instances, frozenset())
        VarNamingCharCNN.extra_graph_processing(graph, instances, encoder)
        for instance in tqdm(instances):
            raw_dp = VarNamingCharCNN.instance_to_datapoint(graph, instance, encoder,
                                                            max_nodes_per_graph=50)
            encoded_dp = deepcopy(raw_dp)
            encoder.encode(encoded_dp)

            # Every known edge type must get its own adjacency matrix.
            self.assertEqual(list(encoded_dp.edges.keys()),
                             sorted(list(encoder.all_edge_types)),
                             "Not all adjacency matrices were created")
            for edge_type, adj_mat in encoded_dp.edges.items():
                np.testing.assert_equal(adj_mat.todense(),
                                        raw_dp.subgraph.get_adjacency_matrix(edge_type).todense())
                self.assertIsInstance(adj_mat, sp.sparse.coo_matrix,
                                      "Encoding produces adjacency matrix of wrong type")

            # Node types: same outer and inner lengths, each entry mapped
            # through the encoder's node-type vocabulary.
            self.assertEqual(len(raw_dp.node_types), len(encoded_dp.node_types),
                             "Type for some node got lost during encoding")
            self.assertEqual([len(i) for i in raw_dp.node_types],
                             [len(i) for i in encoded_dp.node_types],
                             "Some type for some node got lost during encoding")
            for encoded_row, raw_row in zip(encoded_dp.node_types, raw_dp.node_types):
                for encoded_type, raw_type in zip(encoded_row, raw_row):
                    self.assertEqual(encoded_type,
                                     encoder.all_node_types[raw_type],
                                     "Some node type got encoded wrong")

            # Label subtokens: same length, each mapped through the
            # name-subtoken vocabulary.
            self.assertEqual(len(raw_dp.label), len(encoded_dp.label),
                             "Some label subtoken got lost during encoding")
            for encoded_subtoken, raw_subtoken in zip(encoded_dp.label, raw_dp.label):
                self.assertEqual(encoded_subtoken,
                                 encoder.all_node_name_subtokens[raw_subtoken])
def test_preprocess_task_existing_encoding_basic_functionality_excluded_edges(self):
    """Preprocessing with excluded edge types produces datapoints containing
    only syntax-only edges (and their reverses), reuses a saved encoder, and
    rejects an encoder of the wrong class."""
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder='new',
                                     excluded_edge_types=syntax_only_excluded_edge_types,
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length),
                                     instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))

    encoder_path = os.path.join(self.output_dataset_dir,
                                '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__))
    encoder = VarNamingCharCNNDataEncoder.load(encoder_path)
    self.assertEqual(encoder.excluded_edge_types, syntax_only_excluded_edge_types)
    expected_edge_types = list(syntax_only_edge_types) + ['reverse_' + t for t in syntax_only_edge_types]
    self.assertCountEqual(encoder.all_edge_types, expected_edge_types)

    # Every saved datapoint may only carry syntax-only edges or their reverses.
    datapoint_paths = [os.path.join(self.output_dataset_dir, fname)
                       for fname in os.listdir(self.output_dataset_dir)
                       if fname != 'VarNamingCharCNNDataEncoder.pkl']
    for path in datapoint_paths:
        datapoint = encoder.load_datapoint(path)
        for edge_type in datapoint.edges.keys():
            if edge_type.startswith('reverse_'):
                self.assertIn(edge_type[len('reverse_'):], syntax_only_edge_types)
            else:
                self.assertIn(edge_type, syntax_only_edge_types)

    # Re-running with the already-fitted encoder should work.
    VarNamingCharCNN.preprocess_task(self.task,
                                     output_dir=self.output_dataset_dir,
                                     n_jobs=30,
                                     data_encoder=encoder,
                                     excluded_edge_types=syntax_only_excluded_edge_types,
                                     data_encoder_kwargs=dict(
                                         max_name_encoding_length=self.max_name_encoding_length))

    # An encoder of the wrong class must be rejected.
    with self.assertRaises(AssertionError):
        encoder = BaseDataEncoder(dict(), frozenset())
        VarNamingCharCNN.preprocess_task(self.task,
                                         output_dir=self.output_dataset_dir,
                                         n_jobs=30,
                                         data_encoder=encoder,
                                         excluded_edge_types=syntax_only_excluded_edge_types,
                                         data_encoder_kwargs=dict(
                                             max_name_encoding_length=self.max_name_encoding_length))
def test_preprocess_task_existing_encoding_basic_functionality(self):
    """Preprocessing can reuse a previously saved encoder, and rejects an
    encoder that is not a VarNamingCharCNNDataEncoder."""
    VarNamingCharCNN.preprocess_task(
        self.task,
        output_dir=self.output_dataset_dir,
        n_jobs=30,
        data_encoder='new',
        data_encoder_kwargs=dict(
            max_name_encoding_length=self.max_name_encoding_length),
        instance_to_datapoints_kwargs=dict(max_nodes_per_graph=20))

    encoder_path = os.path.join(self.output_dataset_dir,
                                '{}.pkl'.format(VarNamingCharCNNDataEncoder.__name__))
    encoder = VarNamingCharCNNDataEncoder.load(encoder_path)

    # NOTE(review): here excluded_edge_types is nested inside
    # data_encoder_kwargs, while the sibling excluded-edges test passes it as
    # a top-level kwarg — confirm against preprocess_task's signature whether
    # it is honored when an existing encoder is supplied.
    VarNamingCharCNN.preprocess_task(
        self.task,
        output_dir=self.output_dataset_dir,
        n_jobs=30,
        data_encoder=encoder,
        data_encoder_kwargs=dict(
            excluded_edge_types=syntax_only_excluded_edge_types,
            max_name_encoding_length=self.max_name_encoding_length))

    # An encoder of the wrong class must be rejected.
    with self.assertRaises(AssertionError):
        encoder = BaseDataEncoder(dict(), frozenset())
        VarNamingCharCNN.preprocess_task(
            self.task,
            output_dir=self.output_dataset_dir,
            n_jobs=30,
            data_encoder=encoder,
            data_encoder_kwargs=dict(
                excluded_edge_types=syntax_only_excluded_edge_types,
                max_name_encoding_length=self.max_name_encoding_length))
def test_init_finds_all_relevant_dataset_information(self):
    """A freshly constructed encoder discovers the dataset's edge types,
    assigns sequential integer ids to node types, and reserves id 0 for
    the padding token."""
    encoder = VarNamingCharCNNDataEncoder(self.task.graphs_and_instances,
                                          excluded_edge_types=frozenset(),
                                          instance_to_datapoints_kwargs=dict(),
                                          max_name_encoding_length=self.max_name_encoding_length)
    self.assertCountEqual(encoder.all_edge_types, list(all_edge_types),
                          "DataEncoder found weird edge types")
    type_ids = sorted(encoder.all_node_types.values())
    expected_ids = list(range(len(encoder.all_node_types)))
    self.assertTrue(type_ids == expected_ids,
                    "DataEncoder didn't use sequential integers for its type encoding")
    self.assertEqual(encoder.max_name_encoding_length, self.max_name_encoding_length)
    self.assertEqual(encoder.all_node_types['__PAD__'], 0)