def bug_conditional_output_head(
    node_embeddings,
    output_mask):
  """Builds a joint bug/repair distribution as log p(bug) + log p(repair | bug).

  A per-node score is normalized over axis 0 to give the log-probability of
  each bug location. A pairwise score matrix is normalized over axis 1 to give,
  for every row, the log-probability of each repair target. The two are added
  with broadcasting, so rows index bug locations and columns index repairs.

  Args:
    node_embeddings: Final node embeddings.
    output_mask: Boolean mask for the bug and repair targets. Positions where
      the mask is False receive a logit of -inf before normalization, both as
      bug locations and as repair targets.

  Returns:
    NDarray <float32[num_nodes, num_nodes]> of log-probabilities (normalized
      over non-padding nodes).
  """
  # NOTE(review): another call site in this file uses
  # `flax.deprecated.nn.Dense`; confirm which namespace is intended here.
  location_scores = flax.nn.Dense(
      node_embeddings, features=1, bias=False)
  # Drop the trailing singleton feature axis to get one score per node.
  location_scores = location_scores.squeeze(-1)
  location_log_probs = jax.nn.log_softmax(
      jnp.where(output_mask, location_scores, -jnp.inf), axis=0)

  pairwise_scores = graph_layers.BilinearPairwiseReadout(node_embeddings)
  # Only columns (repair targets) are masked here; masked-out rows are already
  # suppressed by the -inf bug-location term added below.
  repair_target_mask = output_mask[None, :]
  repair_log_probs = jax.nn.log_softmax(
      jnp.where(repair_target_mask, pairwise_scores, -jnp.inf), axis=1)

  return location_log_probs[:, None] + repair_log_probs
def embedding_variant_automaton(graph_context,
                                node_embeddings,
                                edge_embeddings,
                                num_variants=gin.REQUIRED):
    """Runs an automaton with variants based on node embeddings.

    Variant logits are produced by a bilinear pairwise readout of the node
    embeddings, passed through `side_outputs.encourage_discrete_logits`, and
    turned into weights with a softmax over the last axis. Those weights are
    then handed to `_shared_automaton_logic`.

    Args:
        graph_context: Input graph for this example.
        node_embeddings: Current node embeddings, as <float32[num_nodes,
            node_embedding_dim]>
        edge_embeddings: Current edge embeddings, as <float32[num_nodes,
            num_nodes, edge_embedding_dim]>
        num_variants: How many variants to use. Must be greater than one.

    Returns:
        New node and edge embeddings. Node embeddings will not be modified.
        Edge embeddings will be modified by adding a new edge type (either
        embedded or concatenated based on graph_context.edges_are_embedded).

    Raises:
        ValueError: If `num_variants` is less than or equal to one.
    """
    # The guard rejects num_variants == 1 as well, so the message must say
    # "more than one" rather than "at least one".
    if num_variants <= 1:
        raise ValueError(
            "Must have more than one variant to use embedding_variant_automaton."
        )

    # Generate variants using a pairwise readout of the node embeddings.
    variant_logits = graph_layers.BilinearPairwiseReadout(
        node_embeddings, num_variants, name="variant_logits")

    variant_logits = side_outputs.encourage_discrete_logits(
        variant_logits, distribution_type="categorical", name="variant_logits")
    # Normalize over the last axis (the softmax default).
    variant_weights = jax.nn.softmax(variant_logits, axis=-1)

    return _shared_automaton_logic(graph_context, node_embeddings,
                                   edge_embeddings, variant_weights)
def bilinear_joint_output_head(node_embeddings, output_mask):
    """Scores every node pair bilinearly and normalizes jointly over all pairs.

    Unlike a factorized head, a single log-softmax is taken over both axes at
    once, so the exponentiated output sums to one across the whole matrix.

    Args:
        node_embeddings: Final node embeddings.
        output_mask: Boolean mask for the bug and repair targets.

    Returns:
        NDarray <float32[num_nodes, num_nodes]> of log-probabilities
        (normalized over non-padding nodes).
    """
    pair_scores = graph_layers.BilinearPairwiseReadout(node_embeddings)

    # A (row, column) pair is kept only when both of its nodes are in the mask.
    row_valid = output_mask[:, None]
    column_valid = output_mask[None, :]
    valid_pairs = row_valid & column_valid

    # Excluded pairs get -inf so they carry zero probability after the softmax.
    masked_scores = jnp.where(valid_pairs, pair_scores, -jnp.inf)
    return jax.nn.log_softmax(masked_scores, axis=(0, 1))
    def apply(
        self,
        example,
        graph_metadata,
        edge_types_to_indices,
        forward_edge_types,
        reverse_edge_types,
        use_position_embeddings=gin.REQUIRED,
        learn_edge_embeddings=gin.REQUIRED,
        model_type=gin.REQUIRED,
        nodewise=False,
        nodewise_loop_chunk_size=None,
    ):
        """Single-forward-pass baseline model for edge-supervision task.

        This model propagates information through the graph, then does a
        pairwise bilinear readout to determine which edges to add.

        Args:
          example: Example to run the automaton on.
          graph_metadata: Statically-known metadata about the graph size. If
            encoded_graph is padded, this should reflect the padded size, not
            the original size.
          edge_types_to_indices: Mapping from edge type names to edge type
            indices.
          forward_edge_types: Edge types to use in the forward direction. As a
            list of lists to allow configuring groups of edges in config files;
            this will be flattened before use.
          reverse_edge_types: Edge types to use in the reverse direction. Note
            that reversed edge types are given a separate embedding from
            forward edge types; undirected edges should be represented by
            adding two edges in opposite directions and then only using
            `forward_edge_types`. Also a list of lists, as above.
          use_position_embeddings: Whether to add position embeddings to node
            embeddings.
          learn_edge_embeddings: Whether to learn an edge embedding for each
            edge type (instead of using a one-hot embedding).
          model_type: One of {"ggnn", "transformer", "nri_encoder"}
          nodewise: Whether to have separate sets of node embeddings for each
            possible start node.
          nodewise_loop_chunk_size: Optional integer, which must divide the
            number of nodes. Splits the nodes into chunks of this size, and
            runs the model on each of those splits in a loop; this reduces the
            memory usage. Only used when nodewise=True, and skipped while the
            module is initializing.

        Returns:
          <float32[num_nodes, num_nodes]> matrix of binary logits for a
          weighted adjacency matrix corresponding to the predicted output
          edges.

        Raises:
          ValueError: If `model_type` is not one of the supported values.
          AssertionError: If `nodewise` is set together with
            model_type="nri_encoder".
        """
        # Fail fast with a clear message. Without this check an unrecognized
        # model_type would fall through every branch of `run_steps` and surface
        # as an UnboundLocalError on `final_embeddings`.
        valid_model_types = ("ggnn", "transformer", "nri_encoder")
        if model_type not in valid_model_types:
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of "
                f"{valid_model_types}.")

        # Node types come directly from the schema (for parity with the automaton).
        num_node_types = len(py_ast_graphs.BUILDER.node_types)
        # Edge types are potentially task-specific.
        num_edge_types = len(edge_types_to_indices)
        # Flatten the configured groups of edge type names into index lists.
        # pylint: disable=g-complex-comprehension
        forward_edge_type_indices = [
            edge_types_to_indices[type_str] for group in forward_edge_types
            for type_str in group
        ]
        reverse_edge_type_indices = [
            edge_types_to_indices[type_str] for group in reverse_edge_types
            for type_str in group
        ]
        # pylint: enable=g-complex-comprehension

        # Embed the nodes.
        if use_position_embeddings:
            node_embeddings = graph_layers.PositionalAndTypeNodeEmbedding(
                node_types=example.node_types, num_node_types=num_node_types)
        else:
            node_embeddings = graph_layers.NodeTypeNodeEmbedding(
                node_types=example.node_types, num_node_types=num_node_types)

        # Embed the edges.
        if learn_edge_embeddings:
            edge_embeddings = graph_layers.LearnableEdgeEmbeddings(
                edges=example.edges,
                num_nodes=graph_metadata.num_nodes,
                num_edge_types=num_edge_types,
                forward_edge_type_indices=forward_edge_type_indices,
                reverse_edge_type_indices=reverse_edge_type_indices)
        else:
            edge_embeddings = graph_layers.binary_index_edge_embeddings(
                edges=example.edges,
                num_nodes=graph_metadata.num_nodes,
                num_edge_types=num_edge_types,
                forward_edge_type_indices=forward_edge_type_indices,
                reverse_edge_type_indices=reverse_edge_type_indices)

        def run_steps(node_embeddings):
            """Runs propagation and updates for the configured model type."""
            if model_type == "ggnn":
                final_embeddings = ggnn_steps(node_embeddings, edge_embeddings)
            elif model_type == "transformer":
                neighbor_mask = graph_layers.edge_mask(
                    edges=example.edges,
                    num_nodes=graph_metadata.num_nodes,
                    num_edge_types=num_edge_types,
                    forward_edge_type_indices=forward_edge_type_indices,
                    reverse_edge_type_indices=reverse_edge_type_indices)
                # Allow nodes to attend to themselves
                neighbor_mask = jnp.maximum(neighbor_mask,
                                            jnp.eye(graph_metadata.num_nodes))
                # Note: the real node count is read from the example's own
                # metadata, not from the `graph_metadata` argument.
                final_embeddings = transformer_steps(
                    node_embeddings,
                    edge_embeddings,
                    neighbor_mask,
                    num_real_nodes_per_graph=example.graph_metadata.num_nodes)
            elif model_type == "nri_encoder":
                final_embeddings = nri_steps(
                    node_embeddings,
                    edge_embeddings,
                    num_real_nodes_per_graph=example.graph_metadata.num_nodes)
            return final_embeddings

        if nodewise:
            assert model_type != "nri_encoder", "Nodewise NRI model is not defined."
            # Add in a learned start node embedding, and broadcast node states out to
            # <float32[num_nodes, num_nodes, node_embedding_dim]>
            node_embedding_dim = node_embeddings.shape[-1]
            start_node_embedding = self.param(
                "start_node_embedding",
                shape=(node_embedding_dim, ),
                initializer=jax.nn.initializers.normal())

            # Entry [i, i] (the start node of copy i) gets the start embedding
            # added; all other entries are plain copies of the node embeddings.
            stacked_node_embeddings = (jnp.broadcast_to(
                node_embeddings[None, :, :],
                (graph_metadata.num_nodes, graph_metadata.num_nodes,
                 node_embedding_dim)).at[
                     jnp.arange(graph_metadata.num_nodes),
                     jnp.arange(graph_metadata.num_nodes)].add(
                         jnp.broadcast_to(
                             start_node_embedding,
                             (graph_metadata.num_nodes, node_embedding_dim))))

            # final_embeddings_from_each_source will be
            # [num_nodes, num_nodes, node_embedding_dim]
            if nodewise_loop_chunk_size and not self.is_initializing():
                # Group start nodes into chunks: vmap within a chunk, and loop
                # over chunks sequentially with jax.lax.map.
                grouped_stacked_node_embeddings = stacked_node_embeddings.reshape(
                    (-1, nodewise_loop_chunk_size, graph_metadata.num_nodes,
                     node_embedding_dim))
                grouped_final_embeddings_from_each_source = jax.lax.map(
                    jax.vmap(run_steps), grouped_stacked_node_embeddings)
                # Merge the (chunk, within-chunk) axes back into one axis.
                final_embeddings_from_each_source = (
                    grouped_final_embeddings_from_each_source.reshape(
                        (graph_metadata.num_nodes, ) +
                        grouped_final_embeddings_from_each_source.shape[2:]))
            else:
                final_embeddings_from_each_source = jax.vmap(run_steps)(
                    stacked_node_embeddings)

            # Extract predictions with a linear transformation.
            logits = flax.deprecated.nn.Dense(
                final_embeddings_from_each_source,
                features=1,
                name="target_readout",
            ).squeeze(-1)

        elif model_type == "nri_encoder":
            # Propagate the node embeddings as-is, then use NRI to construct edges.
            final_embeddings = run_steps(node_embeddings)
            logits = graph_layers.NRIReadout(final_embeddings)

        else:
            # Propagate the node embeddings as-is, then extract edges pairwise.
            final_embeddings = run_steps(node_embeddings)
            logits = graph_layers.BilinearPairwiseReadout(final_embeddings)

        return logits