def pad_example(example, config, allow_failure=False):
    """Pad an example to the static shape described by `config`.

    Args:
        example: The example to pad.
        config: Padding configuration. Supplies `max_tokens` for the token
            array and `bundle_padding` for the graph bundle and node masks.
        allow_failure: If True, return None rather than raising when the
            example does not fit in the requested padding.

    Returns:
        A padded `VarMisuseExample`, or None when padding failed and
        `allow_failure` is True.

    Raises:
        ValueError: If the example has more tokens than `config.max_tokens`
            and `allow_failure` is False. The same flag is forwarded to
            `graph_bundle.pad_example`, which presumably raises likewise when
            the graph itself is too large.
    """
    input_graph = example.input_graph
    token_count = input_graph.tokens.values.shape[0]
    if token_count > config.max_tokens:
        if not allow_failure:
            raise ValueError("Example has too many tokens.")
        return None

    padded_bundle = graph_bundle.pad_example(
        input_graph.bundle, config.bundle_padding, allow_failure)
    if padded_bundle is None:
        # The bundle padder signalled failure; propagate it to the caller.
        return None

    padded_graph = GraphBundleWithTokens(
        bundle=padded_bundle,
        tokens=input_graph.tokens.pad_nonzeros(config.max_tokens),
    )

    # Every per-node field is padded to the same static node count.
    max_nodes = config.bundle_padding.static_max_metadata.num_nodes
    padded_repair_mask = jax_util.pad_to(example.repair_node_mask, max_nodes)
    padded_candidate_mask = jax_util.pad_to(
        example.candidate_node_mask, max_nodes)
    padded_operator = example.unique_candidate_operator.pad_nonzeros(max_nodes)

    return VarMisuseExample(
        input_graph=padded_graph,
        bug_node_index=example.bug_node_index,
        repair_node_mask=padded_repair_mask,
        candidate_node_mask=padded_candidate_mask,
        unique_candidate_operator=padded_operator,
        repair_id=example.repair_id,
    )
def test_zeros_like_padded_example(self):
    """Checks `zeros_like_padded_example` against a really padded example.

    Builds a graph bundle from a trivial program, pads it, and verifies that
    the placeholder generated from the padding config alone has the same
    shape and dtype as the padded example at every leaf.
    """
    tree = gast.parse("pass")
    py_graph, _ = py_ast_graphs.py_ast_to_graph(tree)
    example = graph_bundle.convert_graph_with_edges(
        py_graph, [], builder=py_ast_graphs.BUILDER)
    padding_config = graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=16, num_input_tagged_nodes=34),
        max_initial_transitions=64,
        max_in_tagged_transitions=128,
        max_edges=4)
    padded_example = graph_bundle.pad_example(example, padding_config)
    generated = graph_bundle.zeros_like_padded_example(padding_config)

    def _check(x, y):
        # Only shape and dtype are compared; the values are expected to differ.
        x = np.asarray(x)
        y = np.asarray(y)
        self.assertEqual(x.shape, y.shape)
        self.assertEqual(x.dtype, y.dtype)

    # `jax.tree_multimap` was deprecated and later removed from JAX; newer
    # releases take multiple trees via `jax.tree_util.tree_map`. Prefer the
    # original name when it still exists so older JAX versions keep working.
    map_over_trees = getattr(jax, "tree_multimap", None)
    if map_over_trees is None:
        map_over_trees = jax.tree_util.tree_map
    map_over_trees(_check, generated, padded_example)
def test_pad_example(self):
    """Checks that padding a graph bundle yields the configured static shapes."""
    source = textwrap.dedent("""\
        def foo():
          x = 5
          return x
        """)
    tree = gast.parse(source)
    py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

    # Two extra edges between AST nodes, given as (source, dest, edge_type).
    function_def = tree.body[0]
    assign_stmt = function_def.body[0]
    return_stmt = function_def.body[1]
    ast_edges = [
        (return_stmt, function_def, 1),
        (return_stmt.value, assign_stmt.targets[0], 2),
    ]
    converted_edges = []
    for source_node, dest_node, edge_type in ast_edges:
        converted_edges.append((
            ast_to_node_id[id(source_node)],
            ast_to_node_id[id(dest_node)],
            edge_type,
        ))

    example = graph_bundle.convert_graph_with_edges(
        py_graph, converted_edges, builder=py_ast_graphs.BUILDER)
    static_metadata = automaton_builder.EncodedGraphMetadata(
        num_nodes=16, num_input_tagged_nodes=34)
    padding_config = graph_bundle.PaddingConfig(
        static_max_metadata=static_metadata,
        max_initial_transitions=64,
        max_in_tagged_transitions=128,
        max_edges=4)
    padded_example = graph_bundle.pad_example(example, padding_config)

    # Metadata is not affected by padding.
    expected_metadata = automaton_builder.EncodedGraphMetadata(
        num_nodes=11, num_input_tagged_nodes=27)
    self.assertEqual(padded_example.graph_metadata, expected_metadata)

    # Everything else is padded.
    self.assertEqual(padded_example.node_types.shape, (16,))

    padded_edges = padded_example.edges
    np.testing.assert_array_equal(
        padded_edges.input_indices, [[1], [2], [0], [0]])
    np.testing.assert_array_equal(
        padded_edges.output_indices, [[9, 2], [10, 6], [0, 0], [0, 0]])
    np.testing.assert_array_equal(padded_edges.values, [1, 1, 0, 0])

    automaton_graph = padded_example.automaton_graph
    self.assertEqual(
        automaton_graph.initial_to_in_tagged.values.shape, (64,))
    self.assertEqual(automaton_graph.initial_to_special.shape, (16,))
    self.assertEqual(
        automaton_graph.in_tagged_to_in_tagged.values.shape, (128,))
    self.assertEqual(automaton_graph.in_tagged_to_special.shape, (34,))

    # Transition matrix also becomes padded once it is built.
    # (Note that we pass the padded static metadata to the transition matrix
    # builder, since the encoded graph has been padded.)
    builder = py_ast_graphs.BUILDER
    routing_params = builder.initialize_routing_params(
        None, 1, 1, noise_factor=0)
    transition_matrix = builder.build_transition_matrix(
        routing_params, automaton_graph, padding_config.static_max_metadata)

    self.assertEqual(
        transition_matrix.initial_to_in_tagged.shape, (1, 16, 1, 34, 1))
    self.assertEqual(
        transition_matrix.initial_to_special.shape, (1, 16, 1, 3))
    self.assertEqual(
        transition_matrix.in_tagged_to_in_tagged.shape, (1, 34, 1, 34, 1))
    self.assertEqual(
        transition_matrix.in_tagged_to_special.shape, (1, 34, 1, 3))
    self.assertEqual(transition_matrix.in_tagged_node_indices.shape, (34,))