def bug_conditional_output_head(node_embeddings, output_mask):
  """Computes a factorized joint probability distribution.

  First computes a normalized distribution over bug locations, then combines it
  with a normalized distribution over repairs given bug locations.

  Args:
    node_embeddings: Final node embeddings.
    output_mask: Boolean mask for the bug and repair targets.

  Returns:
    NDarray <float32[num_nodes, num_nodes]> of log-probabilities (normalized
    over non-padding nodes). Entry [i, j] is
    log p(bug = i) + log p(repair = j | bug = i).
  """
  # Bug-location distribution: one scalar score per node, with masked nodes
  # excluded before normalizing over axis 0.
  # Uses the same `flax.deprecated.nn` namespace as the rest of this file; the
  # bare `flax.nn` alias is deprecated.
  bug_logits = flax.deprecated.nn.Dense(
      node_embeddings, features=1, bias=False).squeeze(-1)
  bug_logits = jnp.where(output_mask, bug_logits, -jnp.inf)
  bug_log_probs = jax.nn.log_softmax(bug_logits, axis=0)

  # Repair distribution conditioned on the bug location: row i holds the
  # scores for bug location i, and masked repair targets (columns) are removed
  # before normalizing each row.
  repair_logits = graph_layers.BilinearPairwiseReadout(node_embeddings)
  repair_logits = jnp.where(output_mask[None, :], repair_logits, -jnp.inf)
  repair_log_probs = jax.nn.log_softmax(repair_logits, axis=1)

  # Broadcasting the bug term along columns yields the joint log-probability.
  # NOTE(review): an all-False output_mask makes every entry NaN.
  return bug_log_probs[:, None] + repair_log_probs
def embedding_variant_automaton(graph_context,
                                node_embeddings,
                                edge_embeddings,
                                num_variants=gin.REQUIRED):
  """Runs an automaton with variants based on node embeddings.

  Args:
    graph_context: Input graph for this example.
    node_embeddings: Current node embeddings, as <float32[num_nodes,
      node_embedding_dim]>
    edge_embeddings: Current edge embeddings, as <float32[num_nodes, num_nodes,
      edge_embedding_dim]>
    num_variants: How many variants to use. Must be greater than 1.

  Returns:
    New node and edge embeddings. Node embeddings will not be modified. Edge
    embeddings will be modified by adding a new edge type (either embedded or
    concatenated based on graph_context.edges_are_embedded).

  Raises:
    ValueError: If `num_variants` is less than 2.
  """
  # The guard rejects a single variant, so the message must say "two", not
  # "one" as it previously did.
  if num_variants <= 1:
    raise ValueError(
        "Must have at least two variants to use embedding_variant_automaton; "
        f"got num_variants={num_variants}.")

  # Generate variants using a pairwise readout of the node embeddings.
  variant_logits = graph_layers.BilinearPairwiseReadout(
      node_embeddings, num_variants, name="variant_logits")
  variant_logits = side_outputs.encourage_discrete_logits(
      variant_logits, distribution_type="categorical", name="variant_logits")
  # Softmax over the last axis (presumably the variant axis of the readout).
  variant_weights = jax.nn.softmax(variant_logits)
  return _shared_automaton_logic(graph_context, node_embeddings,
                                 edge_embeddings, variant_weights)
def bilinear_joint_output_head(node_embeddings, output_mask):
  """Scores every (bug, repair) node pair jointly via a bilinear readout.

  Args:
    node_embeddings: Final node embeddings.
    output_mask: Boolean mask for the bug and repair targets.

  Returns:
    NDarray <float32[num_nodes, num_nodes]> of log-probabilities, normalized
    jointly over both axes (i.e. over all pairs of non-padding nodes).
  """
  pair_scores = graph_layers.BilinearPairwiseReadout(node_embeddings)
  # A pair is valid only when both its row node and column node are unmasked.
  row_ok = output_mask[:, None]
  col_ok = output_mask[None, :]
  masked_scores = jnp.where(row_ok & col_ok, pair_scores, -jnp.inf)
  # Normalize over the whole matrix, not per row.
  return jax.nn.log_softmax(masked_scores, axis=(0, 1))
def apply(
    self,
    example,
    graph_metadata,
    edge_types_to_indices,
    forward_edge_types,
    reverse_edge_types,
    use_position_embeddings=gin.REQUIRED,
    learn_edge_embeddings=gin.REQUIRED,
    model_type=gin.REQUIRED,
    nodewise=False,
    nodewise_loop_chunk_size=None,
):
  """Single-forward-pass baseline model for edge-supervision task.

  This model propagates information through the graph, then does a pairwise
  bilinear readout to determine which edges to add.

  Args:
    example: Example to run the automaton on.
    graph_metadata: Statically-known metadata about the graph size. If
      encoded_graph is padded, this should reflect the padded size, not the
      original size.
    edge_types_to_indices: Mapping from edge type names to edge type indices.
    forward_edge_types: Edge types to use in the forward direction. As a list of
      lists to allow configuring groups of edges in config files; this will be
      flattened before use.
    reverse_edge_types: Edge types to use in the reverse direction. Note that
      reversed edge types are given a separate embedding from forward edge
      types; undirected edges should be represented by adding two edges in
      opposite directions and then only using `forward_edge_types`. Also a list
      of lists, as above.
    use_position_embeddings: Whether to add position embeddings to node
      embeddings.
    learn_edge_embeddings: Whether to learn an edge embedding for each edge type
      (instead of using a one-hot embedding).
    model_type: One of {"ggnn", "transformer", "nri_encoder"}
    nodewise: Whether to have separate sets of node embeddings for each possible
      start node.
    nodewise_loop_chunk_size: Optional integer, which must divide the number of
      nodes. Splits the nodes into chunks of this size, and runs the model on
      each of those splits in a loop; this reduces the memory usage. Only used
      when nodewise=True.

  Returns:
    <float32[num_nodes, num_nodes]> matrix of binary logits for a weighted
    adjacency matrix corresponding to the predicted output edges.

  Raises:
    ValueError: If `model_type` is not one of the supported values, or if
      `nodewise_loop_chunk_size` does not divide the number of nodes.
  """
  # Fail fast: an unknown model_type would otherwise leave `final_embeddings`
  # unbound inside run_steps and surface as an UnboundLocalError.
  if model_type not in ("ggnn", "transformer", "nri_encoder"):
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of "
        "'ggnn', 'transformer', 'nri_encoder'.")

  num_nodes = graph_metadata.num_nodes

  # Node types come directly from the schema (for parity with the automaton).
  num_node_types = len(py_ast_graphs.BUILDER.node_types)
  # Edge types are potentially task-specific.
  num_edge_types = len(edge_types_to_indices)
  # pylint: disable=g-complex-comprehension
  forward_edge_type_indices = [
      edge_types_to_indices[type_str]
      for group in forward_edge_types
      for type_str in group
  ]
  reverse_edge_type_indices = [
      edge_types_to_indices[type_str]
      for group in reverse_edge_types
      for type_str in group
  ]
  # pylint: enable=g-complex-comprehension

  # Embed the nodes.
  if use_position_embeddings:
    node_embeddings = graph_layers.PositionalAndTypeNodeEmbedding(
        node_types=example.node_types, num_node_types=num_node_types)
  else:
    node_embeddings = graph_layers.NodeTypeNodeEmbedding(
        node_types=example.node_types, num_node_types=num_node_types)

  # Embed the edges.
  if learn_edge_embeddings:
    edge_embeddings = graph_layers.LearnableEdgeEmbeddings(
        edges=example.edges,
        num_nodes=num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_edge_type_indices,
        reverse_edge_type_indices=reverse_edge_type_indices)
  else:
    edge_embeddings = graph_layers.binary_index_edge_embeddings(
        edges=example.edges,
        num_nodes=num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_edge_type_indices,
        reverse_edge_type_indices=reverse_edge_type_indices)

  def run_steps(node_embeddings):
    """Runs propagation and updates."""
    # model_type was validated above, so exactly one branch applies.
    # NOTE(review): `example.graph_metadata` (rather than the padded
    # `graph_metadata` argument) is presumably the real, unpadded node count.
    if model_type == "ggnn":
      final_embeddings = ggnn_steps(node_embeddings, edge_embeddings)
    elif model_type == "transformer":
      neighbor_mask = graph_layers.edge_mask(
          edges=example.edges,
          num_nodes=num_nodes,
          num_edge_types=num_edge_types,
          forward_edge_type_indices=forward_edge_type_indices,
          reverse_edge_type_indices=reverse_edge_type_indices)
      # Allow nodes to attend to themselves
      neighbor_mask = jnp.maximum(neighbor_mask, jnp.eye(num_nodes))
      final_embeddings = transformer_steps(
          node_embeddings,
          edge_embeddings,
          neighbor_mask,
          num_real_nodes_per_graph=example.graph_metadata.num_nodes)
    else:  # model_type == "nri_encoder"
      final_embeddings = nri_steps(
          node_embeddings,
          edge_embeddings,
          num_real_nodes_per_graph=example.graph_metadata.num_nodes)
    return final_embeddings

  if nodewise:
    assert model_type != "nri_encoder", "Nodewise NRI model is not defined."
    # Add in a learned start node embedding, and broadcast node states out to
    # <float32[num_nodes, num_nodes, node_embedding_dim]>
    node_embedding_dim = node_embeddings.shape[-1]
    start_node_embedding = self.param(
        "start_node_embedding",
        shape=(node_embedding_dim,),
        initializer=jax.nn.initializers.normal())
    # One copy of the node states per possible start node...
    stacked_node_embeddings = jnp.broadcast_to(
        node_embeddings[None, :, :],
        (num_nodes, num_nodes, node_embedding_dim))
    # ...where copy i has the start embedding added to node i only (diagonal).
    diagonal = jnp.arange(num_nodes)
    stacked_node_embeddings = stacked_node_embeddings.at[diagonal,
                                                         diagonal].add(
                                                             jnp.broadcast_to(
                                                                 start_node_embedding,
                                                                 (num_nodes,
                                                                  node_embedding_dim)))

    # final_embeddings_from_each_source will be
    # [num_nodes, num_nodes, node_embedding_dim]
    if nodewise_loop_chunk_size and not self.is_initializing():
      # The reshape below requires an exact split; report misconfiguration
      # clearly instead of failing inside reshape.
      if num_nodes % nodewise_loop_chunk_size:
        raise ValueError(
            f"nodewise_loop_chunk_size={nodewise_loop_chunk_size} must divide "
            f"the number of nodes ({num_nodes}).")
      grouped_stacked_node_embeddings = stacked_node_embeddings.reshape(
          (-1, nodewise_loop_chunk_size, num_nodes, node_embedding_dim))
      # Sequential loop over chunks bounds peak memory; vmap within a chunk.
      grouped_final_embeddings_from_each_source = jax.lax.map(
          jax.vmap(run_steps), grouped_stacked_node_embeddings)
      final_embeddings_from_each_source = (
          grouped_final_embeddings_from_each_source.reshape(
              (num_nodes,) +
              grouped_final_embeddings_from_each_source.shape[2:]))
    else:
      final_embeddings_from_each_source = jax.vmap(run_steps)(
          stacked_node_embeddings)

    # Extract predictions with a linear transformation.
    logits = flax.deprecated.nn.Dense(
        final_embeddings_from_each_source,
        features=1,
        name="target_readout",
    ).squeeze(-1)

  elif model_type == "nri_encoder":
    # Propagate the node embeddings as-is, then use NRI to construct edges.
    final_embeddings = run_steps(node_embeddings)
    logits = graph_layers.NRIReadout(final_embeddings)

  else:
    # Propagate the node embeddings as-is, then extract edges pairwise.
    final_embeddings = run_steps(node_embeddings)
    logits = graph_layers.BilinearPairwiseReadout(final_embeddings)

  return logits