def automaton_model(example,
                    graph_metadata,
                    edge_types_to_indices,
                    variant_edge_types=(),
                    platt_scale=False,
                    with_backtrack=True):
  """Runs a finite-state graph automaton to predict edges (edge supervision).

  Args:
    example: Example to run the automaton on. Its `automaton_graph` is used as
      the encoded graph and its `graph_metadata` as the dynamic metadata.
    graph_metadata: Statically-known metadata about the graph size. If the
      example's automaton graph is padded, this should describe the padded
      size, not the original size.
    edge_types_to_indices: Mapping from edge type names to edge type indices.
    variant_edge_types: Names of edge types to use as variants. These are
      assumed, without checking, to be mutually exclusive: at most one edge of
      one of these types exists between any pair of nodes.
    platt_scale: If True, scale and shift the automaton logits. This is a form
      of Platt scaling. It lets the output probabilities sum to more than 1, so
      the model can express one-to-many relations.
    with_backtrack: Whether the automaton may restart its search as an action.

  Returns:
    <float32[num_nodes, num_nodes]> matrix of binary logits for a weighted
    adjacency matrix corresponding to the predicted output edges.
  """
  variant_weights = None
  if variant_edge_types:
    # Resolve the requested variant edge type names to their indices.
    chosen_indices = [edge_types_to_indices[name] for name in variant_edge_types]
    variant_weights = variants_from_edges(example, graph_metadata,
                                          chosen_indices,
                                          len(edge_types_to_indices))

  builder = automaton_builder.AutomatonBuilder(
      py_ast_graphs.SCHEMA, with_backtrack=with_backtrack)
  # A single output edge type is requested, so its leading axis is dropped.
  automaton_output = automaton_layer.FiniteStateGraphAutomaton(
      encoded_graph=example.automaton_graph,
      variant_weights=variant_weights,
      static_metadata=graph_metadata,
      dynamic_metadata=example.graph_metadata,
      builder=builder,
      num_out_edges=1,
      share_states_across_edges=True)
  absorbing_probs = automaton_output.squeeze(axis=0)

  logits = model_util.safe_logit(absorbing_probs)
  return model_util.ScaleAndShift(logits) if platt_scale else logits
def _shared_automaton_logic(graph_context, node_embeddings, edge_embeddings,
                            variant_weights):
  """Runs the graph automaton and merges its output into the edge embeddings.

  Args:
    graph_context: Context supplying `bundle.automaton_graph`,
      `bundle.graph_metadata`, `static_metadata`, `builder` and
      `edges_are_embedded`.
    node_embeddings: Node embeddings; returned unchanged.
    edge_embeddings: Edge embeddings that `_add_edges` combines with the
      automaton output.
    variant_weights: Variant weights forwarded to the automaton layer.

  Returns:
    Tuple `(node_embeddings, combined_edge_embeddings)`.
  """
  bundle = graph_context.bundle
  automaton_output = automaton_layer.FiniteStateGraphAutomaton(
      encoded_graph=bundle.automaton_graph,
      variant_weights=variant_weights,
      dynamic_metadata=bundle.graph_metadata,
      static_metadata=graph_context.static_metadata,
      builder=graph_context.builder)
  # Move axis 0 of the layer output to the last position before merging.
  # NOTE(review): presumably the layer emits [num_out_edges, nodes, nodes] and
  # _add_edges expects the edge-type axis last -- confirm.
  per_edge_weights = automaton_output.transpose([1, 2, 0])
  combined_edges = _add_edges(edge_embeddings, per_edge_weights,
                              graph_context.edges_are_embedded)
  return node_embeddings, combined_edges
def apply(self, dummy_ignored):
  """Applies the automaton layer to inputs tied to `dummy_ignored`.

  Each leaf of the enclosing-scope `encoded_graph` is passed through
  `jax.lax.tie_in` with `dummy_ignored`, and so is each leaf of the result of
  the enclosing-scope `variant_weights()`. This gives them a data dependence
  on the module input, so they are abstract whenever that input is traced.
  `builder`, `shared`, `use_gate`, `estimator_type` and `kwargs` are also read
  from the enclosing scope, which is not visible here.

  Args:
    dummy_ignored: Value used only as the `tie_in` dependency source.

  Returns:
    The output of `automaton_layer.FiniteStateGraphAutomaton`.
  """

  def tie_to_input(leaf):
    return jax.lax.tie_in(dummy_ignored, leaf)

  tied_graph = jax.tree_map(tie_to_input, encoded_graph)
  tied_variants = jax.tree_map(tie_to_input, variant_weights())
  # Static and dynamic metadata share the same fixed sizes; each is still
  # built as its own instance.
  metadata_sizes = dict(num_nodes=32, num_input_tagged_nodes=64)
  return automaton_layer.FiniteStateGraphAutomaton(
      encoded_graph=tied_graph,
      variant_weights=tied_variants,
      dynamic_metadata=automaton_builder.EncodedGraphMetadata(**metadata_sizes),
      static_metadata=automaton_builder.EncodedGraphMetadata(**metadata_sizes),
      builder=builder,
      num_out_edges=3,
      num_intermediate_states=4,
      share_states_across_edges=shared,
      use_gate_parameterization=use_gate,
      estimator_type=estimator_type,
      name="the_layer",
      **kwargs)