    def test_component_shapes(self,
                              component,
                              embed_edges,
                              expected_dims,
                              extra_config=None):
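        # Reset gin and load the shared test configuration, plus any
        # per-test-case overrides.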
        gin.clear_config()
        gin.parse_config(CONFIG)
        if extra_config:
            gin.parse_config(extra_config)

        # Run the computation with placeholder inputs.
        (node_out,
         edge_out), _ = end_to_end_stack.ALL_COMPONENTS[component].init(
             jax.random.PRNGKey(0),
             graph_context=end_to_end_stack.SharedGraphContext(
                 bundle=graph_bundle.zeros_like_padded_example(
                     graph_bundle.PaddingConfig(
                         static_max_metadata=automaton_builder.
                         EncodedGraphMetadata(num_nodes=16,
                                              num_input_tagged_nodes=32),
                         max_initial_transitions=11,
                         max_in_tagged_transitions=12,
                         max_edges=13)),
                 static_metadata=automaton_builder.EncodedGraphMetadata(
                     num_nodes=16, num_input_tagged_nodes=32),
                 edge_types_to_indices={"foo": 0},
                 builder=automaton_builder.AutomatonBuilder({
                     graph_types.NodeType("node"):
                     graph_types.NodeSchema(
                         in_edges=[graph_types.InEdgeType("in")],
                         out_edges=[graph_types.OutEdgeType("out")])
                 }),
                 edges_are_embedded=embed_edges),
             node_embeddings=jnp.zeros((16, NODE_DIM)),
             edge_embeddings=jnp.zeros((16, 16, EDGE_DIM)))

        self.assertEqual(node_out.shape, (16, expected_dims["node"]))
        self.assertEqual(edge_out.shape, (16, 16, expected_dims["edge"]))
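    # A minimal sketch of how this test could be driven case-by-case with
    # absl's `parameterized` decorator. The component name and expected
    # dimensions below are illustrative assumptions, not the exact cases used
    # by the surrounding test suite; NODE_DIM and EDGE_DIM are the module-level
    # constants already used above.
    #
    #   @parameterized.named_parameters(
    #       dict(testcase_name="embedded_edges",
    #            component="some_component_name",
    #            embed_edges=True,
    #            expected_dims={"node": NODE_DIM, "edge": EDGE_DIM}),
    #   )
    #   def test_component_shapes(self, component, embed_edges, expected_dims,
    #                              extra_config=None):
    #       ...

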
def token_graph_model_core(
    input_graph,
    static_metadata,
    encoding_info,
    node_embedding_dim=gin.REQUIRED,
    edge_embedding_dim=gin.REQUIRED,
    forward_edge_types=gin.REQUIRED,
    reverse_edge_types=gin.REQUIRED,
    components=gin.REQUIRED,
):
    """Transforms an input graph into final node embeddings, with no output head.

    Args:
      input_graph: Input graph for this example.
      static_metadata: Metadata about the padded size of this graph.
      encoding_info: How the example was encoded.
      node_embedding_dim: Dimension of the node embedding space.
      edge_embedding_dim: Dimension of the edge embedding space. If None, just
        concatenate each edge type's adjacency matrix.
      forward_edge_types: Edge types to use in the forward direction, given as a
        list of lists to allow configuring groups of edges in config files; this
        will be flattened before use.
      reverse_edge_types: Edge types to use in the reverse direction. Note that
        reversed edge types are given a separate embedding from forward edge
        types; undirected edges should be represented by adding two edges in
        opposite directions and then only using `forward_edge_types`. Also a
        list of lists, as above.
      components: List of sublayer types. Each element should be the name of one
        of the components defined above (see the example gin bindings sketched
        after this function).

    Returns:
      Final node embeddings after running all components.
    """
    num_node_types = len(encoding_info.builder.node_types)
    num_edge_types = len(encoding_info.edge_types)
    edge_types_to_indices = {
        edge_type: i
        for i, edge_type in enumerate(encoding_info.edge_types)
    }
    # pylint: disable=g-complex-comprehension
    forward_edge_type_indices = [
        edge_types_to_indices[type_str] for group in forward_edge_types
        for type_str in group
    ]
    reverse_edge_type_indices = [
        edge_types_to_indices[type_str] for group in reverse_edge_types
        for type_str in group
    ]
    # pylint: enable=g-complex-comprehension

    # Build initial node embeddings.
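    # The positional/type embedding and the token-operator embedding are summed
    # elementwise, so both are built with the same `node_embedding_dim`.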
    node_embeddings = (graph_layers.PositionalAndTypeNodeEmbedding(
        node_types=input_graph.bundle.node_types,
        num_node_types=num_node_types,
        embedding_dim=node_embedding_dim,
    ) + graph_layers.TokenOperatorNodeEmbedding(
        operator=input_graph.tokens,
        vocab_size=encoding_info.token_encoder.vocab_size,
        num_nodes=static_metadata.num_nodes,
        embedding_dim=node_embedding_dim,
    ))

    if edge_embedding_dim is not None:
        # Learn initial edge embeddings.
        edge_embeddings = graph_layers.LearnableEdgeEmbeddings(
            edges=input_graph.bundle.edges,
            num_nodes=static_metadata.num_nodes,
            num_edge_types=num_edge_types,
            forward_edge_type_indices=forward_edge_type_indices,
            reverse_edge_type_indices=reverse_edge_type_indices,
            embedding_dim=edge_embedding_dim)
    else:
        # Build binary edge embeddings.
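        # Here each selected forward and reverse edge type contributes one
        # binary adjacency channel, so the edge feature dimension is the number
        # of selected edge type indices rather than `edge_embedding_dim`.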
        edge_embeddings = graph_layers.binary_index_edge_embeddings(
            edges=input_graph.bundle.edges,
            num_nodes=static_metadata.num_nodes,
            num_edge_types=num_edge_types,
            forward_edge_type_indices=forward_edge_type_indices,
            reverse_edge_type_indices=reverse_edge_type_indices)

    # Run the core component stack.
    # Depending on whether edge_embedding_dim is provided, we either concatenate
    # new edge types or embed them (see end_to_end_stack for details).
    graph_context = end_to_end_stack.SharedGraphContext(
        bundle=input_graph.bundle,
        static_metadata=static_metadata,
        edge_types_to_indices=edge_types_to_indices,
        builder=encoding_info.builder,
        edges_are_embedded=edge_embedding_dim is not None)
    for component in components:
        component_fn = end_to_end_stack.ALL_COMPONENTS[component]
        node_embeddings, edge_embeddings = component_fn(
            graph_context, node_embeddings, edge_embeddings)

    return node_embeddings
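

# A minimal sketch of gin bindings for the gin.REQUIRED parameters above. The
# dimensions, edge type names, and component names are illustrative
# assumptions; the available component names are the keys of
# end_to_end_stack.ALL_COMPONENTS in a given checkout.
#
#   token_graph_model_core.node_embedding_dim = 128
#   token_graph_model_core.edge_embedding_dim = 32
#   token_graph_model_core.forward_edge_types = [["EDGE_TYPE_A"], ["EDGE_TYPE_B"]]
#   token_graph_model_core.reverse_edge_types = [["EDGE_TYPE_A"]]
#   token_graph_model_core.components = ["some_component_name"]
#
# Setting `edge_embedding_dim = None` instead selects the binary adjacency-based
# edge embeddings in the else-branch above.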