def pad_example(example,
                config,
                allow_failure: bool = False):
  """Pads an example to the static shape dictated by `config`.

  The token array is padded (via `pad_nonzeros`) to `config.max_tokens`, the
  graph bundle is padded according to `config.bundle_padding`, and the fields
  `repair_node_mask`, `candidate_node_mask` and `unique_candidate_operator` are
  padded to `config.bundle_padding.static_max_metadata.num_nodes`. The fields
  `bug_node_index` and `repair_id` are copied over unchanged.

  Args:
    example: The example to pad.
    config: Padding configuration; `max_tokens` and `bundle_padding` are read
      from it.
    allow_failure: If True, an example with too many tokens yields None rather
      than an exception. Also forwarded to `graph_bundle.pad_example`.

  Returns:
    The padded example; or None if the example has too many tokens and
    `allow_failure` is True, or if `graph_bundle.pad_example` returns None.

  Raises:
    ValueError: If the example has more than `config.max_tokens` tokens and
      `allow_failure` is False. (`graph_bundle.pad_example` may presumably
      raise for an oversized graph as well -- not visible here.)
  """
  graph = example.input_graph
  has_too_many_tokens = graph.tokens.values.shape[0] > config.max_tokens
  if has_too_many_tokens and allow_failure:
    return None
  if has_too_many_tokens:
    raise ValueError("Example has too many tokens.")

  padded_bundle = graph_bundle.pad_example(
      graph.bundle, config.bundle_padding, allow_failure)
  if padded_bundle is None:
    # Propagate a failed bundle padding as a None result.
    return None

  max_nodes = config.bundle_padding.static_max_metadata.num_nodes
  padded_graph = GraphBundleWithTokens(
      bundle=padded_bundle,
      tokens=graph.tokens.pad_nonzeros(config.max_tokens),
  )
  repair_mask = jax_util.pad_to(example.repair_node_mask, max_nodes)
  candidate_mask = jax_util.pad_to(example.candidate_node_mask, max_nodes)
  candidate_operator = example.unique_candidate_operator.pad_nonzeros(
      max_nodes)
  return VarMisuseExample(
      input_graph=padded_graph,
      bug_node_index=example.bug_node_index,
      repair_node_mask=repair_mask,
      candidate_node_mask=candidate_mask,
      unique_candidate_operator=candidate_operator,
      repair_id=example.repair_id)
Example #2
0
    def test_zeros_like_padded_example(self):
        """Checks zeros_like_padded_example against a real padded example.

        Converts the trivial program `pass` to a graph example, pads it with a
        fixed `PaddingConfig`, and builds a zero-filled example from the same
        config. The two pytrees are then walked in lockstep, and each pair of
        leaves must have the same shape and dtype.
        """
        tree = gast.parse("pass")
        py_graph, _ = py_ast_graphs.py_ast_to_graph(tree)
        example = graph_bundle.convert_graph_with_edges(
            py_graph, [], builder=py_ast_graphs.BUILDER)

        padding_config = graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=16, num_input_tagged_nodes=34),
            max_initial_transitions=64,
            max_in_tagged_transitions=128,
            max_edges=4)

        padded_example = graph_bundle.pad_example(example, padding_config)
        generated = graph_bundle.zeros_like_padded_example(padding_config)

        def _check(x, y):
            x = np.asarray(x)
            y = np.asarray(y)
            self.assertEqual(x.shape, y.shape)
            self.assertEqual(x.dtype, y.dtype)

        # `jax.tree_multimap` was deprecated and then removed from JAX;
        # `jax.tree_util.tree_map` accepts several trees and maps over them in
        # lockstep, raising if their structures are incompatible.
        jax.tree_util.tree_map(_check, generated, padded_example)
Example #3
0
    def test_pad_example(self):
        """Pads a small function's graph and checks every padded array shape.

        Builds the graph for a tiny `foo` function plus two extra edges, pads
        it with a fixed `PaddingConfig`, then checks four things. The graph
        metadata must be unchanged. Node types, edges and the automaton-graph
        arrays must take the padded sizes. The transition matrix built from
        the padded graph must have the padded shapes.
        """
        source = textwrap.dedent("""\
          def foo():
            x = 5
            return x
          """)
        tree = gast.parse(source)
        py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

        foo_def = tree.body[0]
        assign_stmt = foo_def.body[0]
        return_stmt = foo_def.body[1]
        # Two extra edges, expressed on AST nodes: `return x` -> `foo` (type 1)
        # and the returned `x` -> the assigned `x` (type 2).
        ast_edges = [
            (return_stmt, foo_def, 1),
            (return_stmt.value, assign_stmt.targets[0], 2),
        ]
        converted_edges = [
            (ast_to_node_id[id(src)], ast_to_node_id[id(dst)], edge_type)
            for src, dst, edge_type in ast_edges
        ]
        builder = py_ast_graphs.BUILDER
        example = graph_bundle.convert_graph_with_edges(
            py_graph, converted_edges, builder=builder)

        padding_config = graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=16, num_input_tagged_nodes=34),
            max_initial_transitions=64,
            max_in_tagged_transitions=128,
            max_edges=4)
        padded = graph_bundle.pad_example(example, padding_config)

        # Padding must leave the graph metadata untouched...
        self.assertEqual(
            padded.graph_metadata,
            automaton_builder.EncodedGraphMetadata(
                num_nodes=11, num_input_tagged_nodes=27))

        # ...while the arrays grow to the sizes in `padding_config`.
        self.assertEqual(padded.node_types.shape, (16,))

        padded_edges = padded.edges
        np.testing.assert_array_equal(padded_edges.input_indices,
                                      [[1], [2], [0], [0]])
        np.testing.assert_array_equal(padded_edges.output_indices,
                                      [[9, 2], [10, 6], [0, 0], [0, 0]])
        np.testing.assert_array_equal(padded_edges.values, [1, 1, 0, 0])

        automaton_graph = padded.automaton_graph
        self.assertEqual(automaton_graph.initial_to_in_tagged.values.shape,
                         (64,))
        self.assertEqual(automaton_graph.initial_to_special.shape, (16,))
        self.assertEqual(automaton_graph.in_tagged_to_in_tagged.values.shape,
                         (128,))
        self.assertEqual(automaton_graph.in_tagged_to_special.shape, (34,))

        # The transition matrix built from the padded automaton graph is padded
        # as well. It receives the padded static metadata, not the original
        # one, because the encoded graph now has the padded sizes.
        routing_params = builder.initialize_routing_params(
            None, 1, 1, noise_factor=0)
        transition_matrix = builder.build_transition_matrix(
            routing_params, automaton_graph, padding_config.static_max_metadata)

        expected_shapes = [
            (transition_matrix.initial_to_in_tagged, (1, 16, 1, 34, 1)),
            (transition_matrix.initial_to_special, (1, 16, 1, 3)),
            (transition_matrix.in_tagged_to_in_tagged, (1, 34, 1, 34, 1)),
            (transition_matrix.in_tagged_to_special, (1, 34, 1, 3)),
            (transition_matrix.in_tagged_node_indices, (34,)),
        ]
        for array, expected_shape in expected_shapes:
            self.assertEqual(array.shape, expected_shape)