def test_component_shapes(self, component, embed_edges, expected_dims,
                          extra_config=None):
  """Runs one stack component on placeholder inputs and checks output shapes.

  Args:
    component: Name of a component in `end_to_end_stack.ALL_COMPONENTS`.
    embed_edges: Forwarded as `edges_are_embedded` on the shared graph
      context; controls how the component treats the edge inputs.
    expected_dims: Mapping with "node" and "edge" keys giving the expected
      trailing dimension of the node and edge outputs.
    extra_config: Optional gin config string parsed on top of the
      module-level CONFIG before running the component.
  """
  gin.clear_config()
  gin.parse_config(CONFIG)
  if extra_config:
    gin.parse_config(extra_config)
  # Run the computation with placeholder inputs.
  (node_out, edge_out), _ = end_to_end_stack.ALL_COMPONENTS[component].init(
      jax.random.PRNGKey(0),
      graph_context=end_to_end_stack.SharedGraphContext(
          # All-zeros padded example: 16 nodes, 32 input-tagged nodes, and
          # small transition/edge capacities (11/12/13).
          bundle=graph_bundle.zeros_like_padded_example(
              graph_bundle.PaddingConfig(
                  static_max_metadata=automaton_builder.EncodedGraphMetadata(
                      num_nodes=16, num_input_tagged_nodes=32),
                  max_initial_transitions=11,
                  max_in_tagged_transitions=12,
                  max_edges=13)),
          static_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=16, num_input_tagged_nodes=32),
          edge_types_to_indices={"foo": 0},
          # Minimal schema: one node type with one in-edge and one out-edge.
          builder=automaton_builder.AutomatonBuilder({
              graph_types.NodeType("node"):
                  graph_types.NodeSchema(
                      in_edges=[graph_types.InEdgeType("in")],
                      # NOTE(review): `out_edges` wraps "out" in InEdgeType;
                      # presumably this should be OutEdgeType — confirm
                      # against graph_types before changing.
                      out_edges=[graph_types.InEdgeType("out")])
          }),
          edges_are_embedded=embed_edges),
      # Placeholder embeddings sized to the 16-node static metadata above.
      node_embeddings=jnp.zeros((16, NODE_DIM)),
      edge_embeddings=jnp.zeros((16, 16, EDGE_DIM)))
  self.assertEqual(node_out.shape, (16, expected_dims["node"]))
  self.assertEqual(edge_out.shape, (16, 16, expected_dims["edge"]))
def token_graph_model_core(
    input_graph,
    static_metadata,
    encoding_info,
    node_embedding_dim=gin.REQUIRED,
    edge_embedding_dim=gin.REQUIRED,
    forward_edge_types=gin.REQUIRED,
    reverse_edge_types=gin.REQUIRED,
    components=gin.REQUIRED,
):
  """Runs the core component stack over an input graph, without an output head.

  Builds initial node embeddings (positional/type plus token-operator terms),
  builds edge embeddings (learned if `edge_embedding_dim` is given, binary
  adjacency-style otherwise), then threads both through every configured
  component in order.

  Args:
    input_graph: Input graph for this example.
    static_metadata: Metadata about the padded size of this graph.
    encoding_info: How the example was encoded.
    node_embedding_dim: Dimension of node embedding space.
    edge_embedding_dim: Dimension of edge embedding space. If None, just
      concatenate each edge type's adjacency matrix.
    forward_edge_types: Edge types used in the forward direction, as a list
      of lists (groups are flattened before use).
    reverse_edge_types: Edge types used in the reverse direction, also a
      list of lists. Reversed edges get embeddings separate from forward
      ones; represent undirected edges as two opposite forward edges.
    components: Names of sublayers to run, keys into
      `end_to_end_stack.ALL_COMPONENTS`.

  Returns:
    Final node embeddings after running all components.
  """
  type_to_index = {
      edge_type: i for i, edge_type in enumerate(encoding_info.edge_types)
  }

  def _flat_indices(edge_type_groups):
    # Flattens the list-of-lists config format into edge type indices.
    return [
        type_to_index[type_str]
        for group in edge_type_groups
        for type_str in group
    ]

  num_node_types = len(encoding_info.builder.node_types)
  num_edge_types = len(encoding_info.edge_types)
  forward_indices = _flat_indices(forward_edge_types)
  reverse_indices = _flat_indices(reverse_edge_types)

  # Initial node embeddings are the sum of a positional/type term and a
  # token-operator term.
  positional_part = graph_layers.PositionalAndTypeNodeEmbedding(
      node_types=input_graph.bundle.node_types,
      num_node_types=num_node_types,
      embedding_dim=node_embedding_dim,
  )
  token_part = graph_layers.TokenOperatorNodeEmbedding(
      operator=input_graph.tokens,
      vocab_size=encoding_info.token_encoder.vocab_size,
      num_nodes=static_metadata.num_nodes,
      embedding_dim=node_embedding_dim,
  )
  node_embeddings = positional_part + token_part

  if edge_embedding_dim is None:
    # Binary edge embeddings: one adjacency slot per (direction, edge type).
    edge_embeddings = graph_layers.binary_index_edge_embeddings(
        edges=input_graph.bundle.edges,
        num_nodes=static_metadata.num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_indices,
        reverse_edge_type_indices=reverse_indices)
  else:
    # Learned initial edge embeddings.
    edge_embeddings = graph_layers.LearnableEdgeEmbeddings(
        edges=input_graph.bundle.edges,
        num_nodes=static_metadata.num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_indices,
        reverse_edge_type_indices=reverse_indices,
        embedding_dim=edge_embedding_dim)

  # Run the core component stack. Whether edge_embedding_dim was provided
  # determines whether components embed new edge types or concatenate them
  # (see end_to_end_stack for details).
  graph_context = end_to_end_stack.SharedGraphContext(
      bundle=input_graph.bundle,
      static_metadata=static_metadata,
      edge_types_to_indices=type_to_index,
      builder=encoding_info.builder,
      edges_are_embedded=edge_embedding_dim is not None)
  for component_name in components:
    node_embeddings, edge_embeddings = (
        end_to_end_stack.ALL_COMPONENTS[component_name](
            graph_context, node_embeddings, edge_embeddings))
  return node_embeddings