def test_binary_index_edge_embeddings(self):
  """Checks one-hot edge embeddings against a hand-computed dense tensor."""
  edges, fwd_indices, rev_indices = self._setup_edges()
  result = graph_layers.binary_index_edge_embeddings(
      edges,
      num_nodes=4,
      num_edge_types=3,
      forward_edge_type_indices=fwd_indices,
      reverse_edge_type_indices=rev_indices)
  # Expected nonzero entries as (source, destination, edge_type) triples.
  nonzero_entries = [
      (2, 0, 0),
      (0, 1, 1),
      (1, 2, 1),
      (0, 2, 1),
      (1, 0, 2),
      (2, 1, 2),
      (2, 0, 2),
  ]
  expected = np.zeros([4, 4, 3], np.float32)
  for source, dest, edge_type in nonzero_entries:
    expected[source, dest, edge_type] = 1
  np.testing.assert_allclose(result, expected)
def token_graph_model_core(
    input_graph,
    static_metadata,
    encoding_info,
    node_embedding_dim=gin.REQUIRED,
    edge_embedding_dim=gin.REQUIRED,
    forward_edge_types=gin.REQUIRED,
    reverse_edge_types=gin.REQUIRED,
    components=gin.REQUIRED,
):
  """Transforms an input graph into final node embeddings, with no output head.

  Args:
    input_graph: Input graph for this example.
    static_metadata: Metadata about the padded size of this graph.
    encoding_info: How the example was encoded.
    node_embedding_dim: Dimension of node embedding space.
    edge_embedding_dim: Dimension of edge embedding space. If None, just
      concatenate each edge type's adjacency matrix.
    forward_edge_types: Edge types to use in the forward direction. As a list of
      lists to allow configuring groups of edges in config files; this will be
      flattened before use.
    reverse_edge_types: Edge types to use in the reverse direction. Note that
      reversed edge types are given a separate embedding from forward edge
      types; undirected edges should be represented by adding two edges in
      opposite directions and then only using `forward_edge_types`. Also a list
      of lists, as above.
    components: List of sublayer types. Each element should be the name of one
      of the components defined above.

  Returns:
    Final node embeddings after running all components.
  """
  num_node_types = len(encoding_info.builder.node_types)
  num_edge_types = len(encoding_info.edge_types)
  edge_types_to_indices = {
      edge_type: i for i, edge_type in enumerate(encoding_info.edge_types)
  }

  def _flatten_to_indices(type_groups):
    """Flattens a list of lists of edge type names into a flat index list."""
    # pylint: disable=g-complex-comprehension
    return [
        edge_types_to_indices[type_str]
        for group in type_groups
        for type_str in group
    ]
    # pylint: enable=g-complex-comprehension

  forward_edge_type_indices = _flatten_to_indices(forward_edge_types)
  reverse_edge_type_indices = _flatten_to_indices(reverse_edge_types)

  # Build initial node embeddings from position/type and token information.
  position_and_type = graph_layers.PositionalAndTypeNodeEmbedding(
      node_types=input_graph.bundle.node_types,
      num_node_types=num_node_types,
      embedding_dim=node_embedding_dim,
  )
  token_operator = graph_layers.TokenOperatorNodeEmbedding(
      operator=input_graph.tokens,
      vocab_size=encoding_info.token_encoder.vocab_size,
      num_nodes=static_metadata.num_nodes,
      embedding_dim=node_embedding_dim,
  )
  node_embeddings = position_and_type + token_operator

  if edge_embedding_dim is None:
    # Build binary edge embeddings (one-hot over edge types).
    edge_embeddings = graph_layers.binary_index_edge_embeddings(
        edges=input_graph.bundle.edges,
        num_nodes=static_metadata.num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_edge_type_indices,
        reverse_edge_type_indices=reverse_edge_type_indices)
  else:
    # Learn initial edge embeddings.
    edge_embeddings = graph_layers.LearnableEdgeEmbeddings(
        edges=input_graph.bundle.edges,
        num_nodes=static_metadata.num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_edge_type_indices,
        reverse_edge_type_indices=reverse_edge_type_indices,
        embedding_dim=edge_embedding_dim)

  # Run the core component stack.
  # Depending on whether edge_embedding_dim is provided, we either concatenate
  # new edge types or embed them (see end_to_end_stack for details).
  graph_context = end_to_end_stack.SharedGraphContext(
      bundle=input_graph.bundle,
      static_metadata=static_metadata,
      edge_types_to_indices=edge_types_to_indices,
      builder=encoding_info.builder,
      edges_are_embedded=edge_embedding_dim is not None)
  for component_name in components:
    component_fn = end_to_end_stack.ALL_COMPONENTS[component_name]
    node_embeddings, edge_embeddings = component_fn(
        graph_context, node_embeddings, edge_embeddings)

  return node_embeddings
def apply(
    self,
    example,
    graph_metadata,
    edge_types_to_indices,
    forward_edge_types,
    reverse_edge_types,
    use_position_embeddings=gin.REQUIRED,
    learn_edge_embeddings=gin.REQUIRED,
    model_type=gin.REQUIRED,
    nodewise=False,
    nodewise_loop_chunk_size=None,
):
  """Single-forward-pass baseline model for edge-supervision task.

  This model propagates information through the graph, then does a pairwise
  bilinear readout to determine which edges to add.

  Args:
    example: Example to run the automaton on.
    graph_metadata: Statically-known metadata about the graph size. If
      encoded_graph is padded, this should reflect the padded size, not the
      original size.
    edge_types_to_indices: Mapping from edge type names to edge type indices.
    forward_edge_types: Edge types to use in the forward direction. As a list
      of lists to allow configuring groups of edges in config files; this will
      be flattened before use.
    reverse_edge_types: Edge types to use in the reverse direction. Note that
      reversed edge types are given a separate embedding from forward edge
      types; undirected edges should be represented by adding two edges in
      opposite directions and then only using `forward_edge_types`. Also a list
      of lists, as above.
    use_position_embeddings: Whether to add position embeddings to node
      embeddings.
    learn_edge_embeddings: Whether to learn an edge embedding for each edge
      type (instead of using a one-hot embedding).
    model_type: One of {"ggnn", "transformer", "nri_encoder"}
    nodewise: Whether to have separate sets of node embeddings for each
      possible start node.
    nodewise_loop_chunk_size: Optional integer, which must divide the number of
      nodes. Splits the nodes into chunks of this size, and runs the model on
      each of those splits in a loop; this reduces the memory usage. Only used
      when nodewise=True.

  Returns:
    <float32[num_nodes, num_nodes]> matrix of binary logits for a weighted
    adjacency matrix corresponding to the predicted output edges.

  Raises:
    ValueError: If `model_type` is not one of the supported model types.
  """
  # Node types come directly from the schema (for parity with the automaton).
  num_node_types = len(py_ast_graphs.BUILDER.node_types)
  # Edge types are potentially task-specific.
  num_edge_types = len(edge_types_to_indices)
  # pylint: disable=g-complex-comprehension
  forward_edge_type_indices = [
      edge_types_to_indices[type_str]
      for group in forward_edge_types
      for type_str in group
  ]
  reverse_edge_type_indices = [
      edge_types_to_indices[type_str]
      for group in reverse_edge_types
      for type_str in group
  ]
  # pylint: enable=g-complex-comprehension

  # Embed the nodes.
  if use_position_embeddings:
    node_embeddings = graph_layers.PositionalAndTypeNodeEmbedding(
        node_types=example.node_types, num_node_types=num_node_types)
  else:
    node_embeddings = graph_layers.NodeTypeNodeEmbedding(
        node_types=example.node_types, num_node_types=num_node_types)

  # Embed the edges.
  if learn_edge_embeddings:
    edge_embeddings = graph_layers.LearnableEdgeEmbeddings(
        edges=example.edges,
        num_nodes=graph_metadata.num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_edge_type_indices,
        reverse_edge_type_indices=reverse_edge_type_indices)
  else:
    edge_embeddings = graph_layers.binary_index_edge_embeddings(
        edges=example.edges,
        num_nodes=graph_metadata.num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_edge_type_indices,
        reverse_edge_type_indices=reverse_edge_type_indices)

  def run_steps(node_embeddings):
    """Runs propagation and updates for the configured model type."""
    if model_type == "ggnn":
      final_embeddings = ggnn_steps(node_embeddings, edge_embeddings)
    elif model_type == "transformer":
      neighbor_mask = graph_layers.edge_mask(
          edges=example.edges,
          num_nodes=graph_metadata.num_nodes,
          num_edge_types=num_edge_types,
          forward_edge_type_indices=forward_edge_type_indices,
          reverse_edge_type_indices=reverse_edge_type_indices)
      # Allow nodes to attend to themselves
      neighbor_mask = jnp.maximum(neighbor_mask,
                                  jnp.eye(graph_metadata.num_nodes))
      final_embeddings = transformer_steps(
          node_embeddings,
          edge_embeddings,
          neighbor_mask,
          num_real_nodes_per_graph=example.graph_metadata.num_nodes)
    elif model_type == "nri_encoder":
      final_embeddings = nri_steps(
          node_embeddings,
          edge_embeddings,
          num_real_nodes_per_graph=example.graph_metadata.num_nodes)
    else:
      # Previously an unknown model_type fell through every branch and raised
      # an opaque UnboundLocalError on the return; fail loudly instead.
      raise ValueError(
          f"Unknown model_type {model_type!r}; expected one of "
          "'ggnn', 'transformer', 'nri_encoder'.")
    return final_embeddings

  if nodewise:
    assert model_type != "nri_encoder", "Nodewise NRI model is not defined."
    # Add in a learned start node embedding, and broadcast node states out to
    # <float32[num_nodes, num_nodes, node_embedding_dim]>
    node_embedding_dim = node_embeddings.shape[-1]
    start_node_embedding = self.param(
        "start_node_embedding",
        shape=(node_embedding_dim,),
        initializer=jax.nn.initializers.normal())
    stacked_node_embeddings = (jnp.broadcast_to(
        node_embeddings[None, :, :],
        (graph_metadata.num_nodes, graph_metadata.num_nodes,
         node_embedding_dim)).at[
             jnp.arange(graph_metadata.num_nodes),
             jnp.arange(graph_metadata.num_nodes)].add(
                 jnp.broadcast_to(
                     start_node_embedding,
                     (graph_metadata.num_nodes, node_embedding_dim))))
    # final_embeddings_from_each_source will be
    # [num_nodes, num_nodes, node_embedding_dim]
    if nodewise_loop_chunk_size and not self.is_initializing():
      # Run the model over chunks of start nodes in a sequential loop to
      # reduce peak memory usage.
      grouped_stacked_node_embeddings = stacked_node_embeddings.reshape(
          (-1, nodewise_loop_chunk_size, graph_metadata.num_nodes,
           node_embedding_dim))
      grouped_final_embeddings_from_each_source = jax.lax.map(
          jax.vmap(run_steps), grouped_stacked_node_embeddings)
      final_embeddings_from_each_source = (
          grouped_final_embeddings_from_each_source.reshape(
              (graph_metadata.num_nodes,) +
              grouped_final_embeddings_from_each_source.shape[2:]))
    else:
      final_embeddings_from_each_source = jax.vmap(run_steps)(
          stacked_node_embeddings)
    # Extract predictions with a linear transformation.
    logits = flax.deprecated.nn.Dense(
        final_embeddings_from_each_source,
        features=1,
        name="target_readout",
    ).squeeze(-1)
  elif model_type == "nri_encoder":
    # Propagate the node embeddings as-is, then use NRI to construct edges.
    final_embeddings = run_steps(node_embeddings)
    logits = graph_layers.NRIReadout(final_embeddings)
  else:
    # Propagate the node embeddings as-is, then extract edges pairwise.
    final_embeddings = run_steps(node_embeddings)
    logits = graph_layers.BilinearPairwiseReadout(final_embeddings)

  return logits