def automaton_model(example,
                    graph_metadata,
                    edge_types_to_indices,
                    variant_edge_types=(),
                    platt_scale=False,
                    with_backtrack=True):
    """Predicts output edges by running a finite-state graph automaton.

    Model for the edge supervision task. The automaton runs over
    `example.automaton_graph` with a single output edge type. Its absorbing
    probabilities are converted into binary logits.

    Args:
        example: Example to run the automaton on. Its `automaton_graph` and
            `graph_metadata` attributes are read.
        graph_metadata: Statically-known metadata about the graph size. If
            `example.automaton_graph` is padded, this should reflect the
            padded size, not the original size.
        edge_types_to_indices: Mapping from edge type names to edge type
            indices. Its length is passed on as the total number of edge
            types.
        variant_edge_types: Names of edge types to use as variants. Assumes
            without checking that the given variants are mutually exclusive
            (at most one edge of one of these types exists between any pair
            of nodes). When empty, the automaton runs with
            `variant_weights=None`.
        platt_scale: Whether to pass the logits through
            `model_util.ScaleAndShift`. This can be viewed as a form of Platt
            scaling applied to the automaton logits. If True, the model's
            output probabilities may sum to more than 1, so that it can
            express one-to-many relations.
        with_backtrack: Whether the automaton can restart the search as an
            action.

    Returns:
        <float32[num_nodes, num_nodes]> matrix of binary logits for a weighted
        adjacency matrix corresponding to the predicted output edges.

    Raises:
        KeyError: If a name in `variant_edge_types` is not a key of
            `edge_types_to_indices` (raised by the mapping lookup).
    """
    variant_weights = None
    if variant_edge_types:
        # Translate variant edge type names into their integer indices.
        variant_indices = [
            edge_types_to_indices[edge_type] for edge_type in variant_edge_types
        ]
        variant_weights = variants_from_edges(example, graph_metadata,
                                              variant_indices,
                                              len(edge_types_to_indices))

    builder = automaton_builder.AutomatonBuilder(
        py_ast_graphs.SCHEMA, with_backtrack=with_backtrack)

    # Only one output edge is requested, so the leading axis (presumably the
    # output-edge axis) is squeezed away.
    absorbing_probs = automaton_layer.FiniteStateGraphAutomaton(
        encoded_graph=example.automaton_graph,
        variant_weights=variant_weights,
        static_metadata=graph_metadata,
        dynamic_metadata=example.graph_metadata,
        builder=builder,
        num_out_edges=1,
        share_states_across_edges=True,
    ).squeeze(axis=0)

    logits = model_util.safe_logit(absorbing_probs)
    return model_util.ScaleAndShift(logits) if platt_scale else logits
def _shared_automaton_logic(graph_context, node_embeddings, edge_embeddings,
                            variant_weights):
    """Runs the graph automaton and merges its output into the edge embeddings.

    Args:
        graph_context: Context object providing `bundle.automaton_graph`,
            `bundle.graph_metadata`, `static_metadata`, `builder`, and
            `edges_are_embedded`.
        node_embeddings: Node embeddings. Returned unchanged.
        edge_embeddings: Existing edge embeddings, combined with the
            automaton output via `_add_edges`.
        variant_weights: Variant weights forwarded to the automaton as-is
            (presumably may be None when there are no variants; confirm with
            callers).

    Returns:
        Tuple `(node_embeddings, updated_edge_embeddings)`.
    """
    bundle = graph_context.bundle
    automaton_output = automaton_layer.FiniteStateGraphAutomaton(
        encoded_graph=bundle.automaton_graph,
        variant_weights=variant_weights,
        dynamic_metadata=bundle.graph_metadata,
        static_metadata=graph_context.static_metadata,
        builder=graph_context.builder)

    # Move the leading axis (presumably the output-edge axis) to the end
    # before merging with the existing edge embeddings.
    rotated_output = automaton_output.transpose([1, 2, 0])
    updated_edge_embeddings = _add_edges(edge_embeddings, rotated_output,
                                         graph_context.edges_are_embedded)
    return node_embeddings, updated_edge_embeddings
 def apply(self, dummy_ignored):
   """Applies a FiniteStateGraphAutomaton to closed-over test inputs.

   The graph and variant weights come from the enclosing scope
   (`encoded_graph` and `variant_weights()`). `dummy_ignored` is used only as
   the first argument to `jax.lax.tie_in` for every leaf of those inputs.

   Args:
     dummy_ignored: Value that the inputs are tied to; otherwise unused.

   Returns:
     The output of `automaton_layer.FiniteStateGraphAutomaton` (layer name
     "the_layer").
   """

   def tie_to_dummy(tree):
     # NOTE(review): jax.lax.tie_in returns its second argument. It was
     # deprecated once JAX enabled omnistaging; confirm the JAX version
     # this code targets.
     return jax.tree_map(
         lambda leaf: jax.lax.tie_in(dummy_ignored, leaf), tree)

   tied_graph = tie_to_dummy(encoded_graph)
   tied_variant_weights = tie_to_dummy(variant_weights())

   # Dynamic and static metadata use the same hard-coded sizes.
   return automaton_layer.FiniteStateGraphAutomaton(
       encoded_graph=tied_graph,
       variant_weights=tied_variant_weights,
       dynamic_metadata=automaton_builder.EncodedGraphMetadata(
           num_nodes=32, num_input_tagged_nodes=64),
       static_metadata=automaton_builder.EncodedGraphMetadata(
           num_nodes=32, num_input_tagged_nodes=64),
       builder=builder,
       num_out_edges=3,
       num_intermediate_states=4,
       share_states_across_edges=shared,
       use_gate_parameterization=use_gate,
       estimator_type=estimator_type,
       name="the_layer",
       **kwargs)