def _init(rng):
    """Initializes model parameters using an all-zeros padded example.

    Args:
      rng: Value forwarded as the first argument of `model_def.init`
        (presumably a JAX PRNG key).

    Returns:
      The initial parameters returned by `model_def.init`.
    """
    # The stochastic scope is only a dummy so that random perturbations can
    # run during initialization, which is why it uses a fixed key.
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
        zero_example = jax.tree_map(
            jnp.array, graph_bundle.zeros_like_padded_example(padding_config))
        _, params = model_def.init(rng, zero_example)
    return params
  def test_variants_from_edges(self):
    """Checks `variants_from_edges` against a small hand-built graph."""
    padding = graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=5, num_input_tagged_nodes=0),
        max_initial_transitions=0,
        max_in_tagged_transitions=0,
        max_edges=8)

    # Six entries carry value 1; the last two entries have value 0 and fill
    # the operator up to the configured max_edges=8.
    edges = sparse_operator.SparseCoordOperator(
        input_indices=jnp.array([[0], [0], [0], [1], [1], [2], [0], [0]]),
        output_indices=jnp.array([[1, 2], [2, 3], [3, 0], [2, 0], [0, 2],
                                  [0, 3], [0, 0], [0, 0]]),
        values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0]))

    # Start from an all-zeros padded example, then declare 4 real nodes
    # (out of the 5 padded ones) and attach the edges above.
    example = dataclasses.replace(
        graph_bundle.zeros_like_padded_example(padding),
        graph_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=4, num_input_tagged_nodes=0),
        edges=edges)

    actual = edge_supervision_models.variants_from_edges(
        example,
        automaton_builder.EncodedGraphMetadata(
            num_nodes=5, num_input_tagged_nodes=0),
        variant_edge_type_indices=[2, 0],
        num_edge_types=3)

    expected_rows = [
        [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0]],
        [[1, 0, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]],
        [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
        [[0, 0, 1], [1, 0, 0], [1, 0, 0], [1, 0, 0]],
    ]
    expected = np.array(expected_rows, np.float32)

    # The result is padded to 5 nodes; compare only the 4x4 non-padded part.
    np.testing.assert_allclose(actual[:4, :4], expected)
def zeros_like_padded_example(config):
    """Creates a VarMisuseExample whose arrays are all zeros.

    Handy for initializing model parameters or for writing tests.

    Args:
      config: Padding configuration. This function reads `config.max_tokens`
        and `config.bundle_padding` (including
        `config.bundle_padding.static_max_metadata.num_nodes`).

    Returns:
      A zero-filled "example" with the sizes requested by `config`. The
      `bug_node_index` field is set to -1 and `repair_id` to 0.
    """
    max_tokens = config.max_tokens
    max_nodes = config.bundle_padding.static_max_metadata.num_nodes

    def _zero_indices():
        # Returns a fresh array on every call so that no two fields share
        # the same underlying buffer.
        return np.zeros(shape=(max_tokens, 1), dtype=np.int32)

    # Token values are integers, whereas the candidate operator below holds
    # float values.
    token_operator = sparse_operator.SparseCoordOperator(
        input_indices=_zero_indices(),
        output_indices=_zero_indices(),
        values=np.zeros(shape=(max_tokens, ), dtype=np.int32))
    candidate_operator = sparse_operator.SparseCoordOperator(
        input_indices=_zero_indices(),
        output_indices=_zero_indices(),
        values=np.zeros(shape=(max_tokens, ), dtype=np.float32))

    input_graph = GraphBundleWithTokens(
        bundle=graph_bundle.zeros_like_padded_example(config.bundle_padding),
        tokens=token_operator)

    return VarMisuseExample(
        input_graph=input_graph,
        bug_node_index=-1,
        repair_node_mask=np.zeros(shape=(max_nodes, ), dtype=np.float32),
        candidate_node_mask=np.zeros(shape=(max_nodes, ), dtype=np.float32),
        unique_candidate_operator=candidate_operator,
        repair_id=0)
 def _make_example(self):
   """Returns a padded example with 4 declared nodes and a fixed edge list."""
   padding = graph_bundle.PaddingConfig(
       static_max_metadata=automaton_builder.EncodedGraphMetadata(
           num_nodes=5, num_input_tagged_nodes=0),
       max_initial_transitions=0,
       max_in_tagged_transitions=0,
       max_edges=8)

   # Six entries carry value 1; the final two entries have value 0 and fill
   # the operator up to the configured max_edges=8.
   edges = sparse_operator.SparseCoordOperator(
       input_indices=jnp.array([[0], [0], [0], [0], [1], [2], [2], [0]]),
       output_indices=jnp.array([[1, 2], [2, 3], [2, 2], [3, 0], [0, 2],
                                 [0, 3], [0, 0], [0, 0]]),
       values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0]))

   # Overwrite the metadata and edges of an all-zeros padded example.
   return dataclasses.replace(
       graph_bundle.zeros_like_padded_example(padding),
       graph_metadata=automaton_builder.EncodedGraphMetadata(
           num_nodes=4, num_input_tagged_nodes=0),
       edges=edges)
    def test_component_shapes(self,
                              component,
                              embed_edges,
                              expected_dims,
                              extra_config=None):
        """Checks the output shapes produced by one end-to-end component.

        Args:
          component: Key into `end_to_end_stack.ALL_COMPONENTS`.
          embed_edges: Passed as `edges_are_embedded` of the graph context.
          expected_dims: Mapping with "node" and "edge" entries giving the
            expected size of the last axis of each output.
          extra_config: Optional additional gin config, parsed after CONFIG.
        """
        # Reset gin first so that configuration from other tests cannot leak
        # into this one.
        gin.clear_config()
        gin.parse_config(CONFIG)
        if extra_config:
            gin.parse_config(extra_config)

        num_nodes = 16

        def _metadata():
            return automaton_builder.EncodedGraphMetadata(
                num_nodes=num_nodes, num_input_tagged_nodes=32)

        zero_bundle = graph_bundle.zeros_like_padded_example(
            graph_bundle.PaddingConfig(
                static_max_metadata=_metadata(),
                max_initial_transitions=11,
                max_in_tagged_transitions=12,
                max_edges=13))

        # NOTE(review): out_edges is built with InEdgeType, exactly as in the
        # original test -- confirm whether that is intended.
        node_schema = graph_types.NodeSchema(
            in_edges=[graph_types.InEdgeType("in")],
            out_edges=[graph_types.InEdgeType("out")])

        graph_context = end_to_end_stack.SharedGraphContext(
            bundle=zero_bundle,
            static_metadata=_metadata(),
            edge_types_to_indices={"foo": 0},
            builder=automaton_builder.AutomatonBuilder(
                {graph_types.NodeType("node"): node_schema}),
            edges_are_embedded=embed_edges)

        # Run the computation with placeholder (all-zero) inputs.
        component_def = end_to_end_stack.ALL_COMPONENTS[component]
        outputs, _ = component_def.init(
            jax.random.PRNGKey(0),
            graph_context=graph_context,
            node_embeddings=jnp.zeros((num_nodes, NODE_DIM)),
            edge_embeddings=jnp.zeros((num_nodes, num_nodes, EDGE_DIM)))
        node_out, edge_out = outputs

        self.assertEqual(node_out.shape, (num_nodes, expected_dims["node"]))
        self.assertEqual(edge_out.shape,
                         (num_nodes, num_nodes, expected_dims["edge"]))
Ejemplo n.º 6
0
    def test_zeros_like_padded_example(self):
        """Checks that the zeros example matches a real padded example.

        Only shapes and dtypes of corresponding leaves are compared, not
        their values.
        """
        py_graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse("pass"))
        real_example = graph_bundle.convert_graph_with_edges(
            py_graph, [], builder=py_ast_graphs.BUILDER)

        padding_config = graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=16, num_input_tagged_nodes=34),
            max_initial_transitions=64,
            max_in_tagged_transitions=128,
            max_edges=4)

        reference = graph_bundle.pad_example(real_example, padding_config)
        zeros = graph_bundle.zeros_like_padded_example(padding_config)

        def _assert_same_shape_and_dtype(actual, wanted):
            actual_array = np.asarray(actual)
            wanted_array = np.asarray(wanted)
            self.assertEqual(actual_array.shape, wanted_array.shape)
            self.assertEqual(actual_array.dtype, wanted_array.dtype)

        # Walk both structures in lockstep, comparing leaf by leaf.
        jax.tree_multimap(_assert_same_shape_and_dtype, zeros, reference)
Ejemplo n.º 7
0
def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    batch_size_per_device=gin.REQUIRED,
    validation_example_count=gin.REQUIRED,
):
    """Train the maze automaton.

    Args:
      runner: Helper object that runs the experiment.
      dataset_paths: Dictionary of dataset paths, with keys:
        - "train_dataset": Path to training dataset files.
        - "eval_dataset": Path to validation dataset files.
      prefetch: Maximum number of examples to prefetch in a background
        thread. A falsy value disables prefetching.
      batch_size_per_device: Batch size for each device.
      validation_example_count: How many examples to use when computing
        validation metrics.

    Returns:
      Optimizer at the end of training (for interactive debugging).
    """
    device_count = jax.local_device_count()
    logging.info("Found %d devices: %s", device_count, jax.devices())
    batch_shape = (device_count, batch_size_per_device)

    def _attach_ids(items):
        # Store the example id inside the example itself, so that it can be
        # used to randomly choose a goal.
        for item in items:
            yield dataclasses.replace(
                item, example=(item.example, item.example_id))

    def _batched_with_ids(items):
        # Batches are shaped (devices, per-device batch size); leftover
        # examples are handled with the PAD_ZERO remainder behavior.
        return data_loading.batch(
            _attach_ids(items),
            batch_shape,
            remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)

    with contextlib.ExitStack() as stack:
        logging.info("Setting up datasets...")
        train_source = runner.build_sampling_iterator(
            dataset_paths["train_dataset"],
            example_type=graph_bundle.GraphBundle)
        valid_source_factory = runner.build_one_pass_iterator_factory(
            dataset_paths["eval_dataset"],
            example_type=graph_bundle.GraphBundle,
            truncate_at=validation_example_count)

        train_iterator = _batched_with_ids(train_source)

        def valid_iterator_factory():
            # Builds a new batched validation iterator on every call.
            return _batched_with_ids(valid_source_factory())

        if prefetch:
            # Registering the prefetcher on the exit stack ties its cleanup
            # to leaving this `with` block.
            train_iterator = stack.enter_context(
                data_loading.ThreadedPrefetcher(train_iterator, prefetch))

        logging.info("Setting up model...")
        padding_config = maze_task.PADDING_CONFIG
        max_metadata = padding_config.static_max_metadata
        model_def = automaton_layer.FiniteStateGraphAutomaton.partial(
            static_metadata=max_metadata, builder=maze_task.BUILDER)

        # Initialize parameters randomly, seeding from the wall clock in
        # milliseconds.
        init_key = jax.random.PRNGKey(int(time.time() * 1000))
        zero_example = graph_bundle.zeros_like_padded_example(padding_config)
        _, initial_params = model_def.init(
            init_key,
            zero_example.automaton_graph,
            dynamic_metadata=max_metadata)

        optimizer = flax.optim.Adam().create(
            flax.nn.Model(model_def, initial_params))

        validation_fn = train_util.build_averaging_validator(
            loss_fn, valid_iterator_factory)

        return runner.training_loop(
            optimizer=optimizer,
            train_iterator=train_iterator,
            loss_fn=loss_fn,
            validation_fn=validation_fn,
            extra_artifacts={"builder.pickle": maze_task.BUILDER})