def automaton_model(example, graph_metadata, edge_types_to_indices, variant_edge_types=(), platt_scale=False, with_backtrack=True):
  """Automaton-based module for edge supervision task.

  Args:
    example: Example to run the automaton on.
    graph_metadata: Statically-known metadata about the graph size. If
      encoded_graph is padded, this should reflect the padded size, not the
      original size.
    edge_types_to_indices: Mapping from edge type names to edge type indices.
    variant_edge_types: Edge types to use as variants. Assumes without
      checking that the given variants are mutually exclusive (at most one
      edge of one of these types exists between any pair of nodes).
    platt_scale: Whether to scale and shift the logits produced by the
      automaton. This can be viewed as a form of Platt scaling applied to the
      automaton logits. If True, this allows the model's output probabilities
      to sum to more than 1, so that it can express one-to-many relations.
    with_backtrack: Whether the automaton can restart the search as an action.

  Returns:
    <float32[num_nodes, num_nodes]> matrix of binary logits for a weighted
    adjacency matrix corresponding to the predicted output edges.
  """
  # Only compute variant weights when variant edge types were requested;
  # otherwise the automaton layer runs without variants.
  variant_weights = None
  if variant_edge_types:
    variant_indices = [
        edge_types_to_indices[type_str] for type_str in variant_edge_types
    ]
    variant_weights = variants_from_edges(example, graph_metadata,
                                          variant_indices,
                                          len(edge_types_to_indices))

  # One output edge type; squeeze away the singleton edge-type axis so the
  # result is a plain [num_nodes, num_nodes] matrix.
  absorbing_probs = automaton_layer.FiniteStateGraphAutomaton(
      encoded_graph=example.automaton_graph,
      variant_weights=variant_weights,
      static_metadata=graph_metadata,
      dynamic_metadata=example.graph_metadata,
      builder=automaton_builder.AutomatonBuilder(
          py_ast_graphs.SCHEMA, with_backtrack=with_backtrack),
      num_out_edges=1,
      share_states_across_edges=True).squeeze(axis=0)

  logits = model_util.safe_logit(absorbing_probs)
  if platt_scale:
    # Learned affine correction of the automaton logits (Platt scaling).
    logits = model_util.ScaleAndShift(logits)
  return logits
def residual_layer_norm_update(node_states, messages):
  """Update node states using a residual step and layer norm.

  This is based on the update step in a normal transformer model. We assume
  the node states and messages are the same size.

  Args:
    node_states: <float32[num_nodes, node_embedding_dim]>
    messages: <float32[num_nodes, node_embedding_dim]>

  Returns:
    <float32[num_nodes, node_embedding_dim]> new state.
  """
  # Residual connection, then normalization over the embedding axis, then a
  # learned elementwise scale and shift (the LayerNorm affine parameters).
  normalized = jax.nn.normalize(node_states + messages, axis=-1)
  return model_util.ScaleAndShift(normalized)
  def apply(
      self,
      encoded_graph,
      dynamic_metadata,
      static_metadata,
      builder,
      num_out_edges=1,
      num_intermediate_states=0,
      variant_weights=None,
      steps=None,
      share_states_across_edges=True,
      backtrack_fails_prob=0.001,
      initialization_noise_factor=0.01,
      legacy_initialize=True,
      initialize_smoothing=0.001,
      use_gate_parameterization=False,
      gate_noise=0.2,
      logit_scaling="learned",
      dynamic_scaling_absolute_shift=1.0,
      dynamic_scaling_relative_shift=1.0,
      estimator_type="solver",
      sampling_max_possible_transitions=None,
  ):
    """Apply the graph finite-state automaton layer.

    Args:
      encoded_graph: Graph structure to run the automaton on.
      dynamic_metadata: Metadata about the actual unpadded graph. Unused.
      static_metadata: Statically-known metadata about the graph size. If
        encoded_graph is padded, this should reflect the padded size, not the
        original size.
      builder: Builder object that encoded the graph. Determines the shape of
        the automaton parameters.
      num_out_edges: How many distinct "edge types" to produce. Each of these
        corresponds to a different automaton start state.
      num_intermediate_states: How many intermediate states to include. These
        states are only used inside the automaton, and don't get exposed as
        start states for an edge type. If share_states_across_edges is True,
        all of the edge types will share the same set of intermediate states.
        If share_states_across_edges is False, each edge type will have its
        OWN set of `num_intermediate_states` states.
      variant_weights: Optional <float32[num_nodes, num_nodes, num_variants]>
        or <float32[num_nodes, num_nodes, num_out_edges, num_variants]> array
        that has nonnegative elements that sum to 1 along the last axis. If
        provided, variant_weights[i, j, (e,) v] specifies how much policy
        variant v should be used (for edge type e) when starting from node i
        and arriving at intermediate node j. Variants correspond to the
        start-node-conditioned observations described in Appendix C.2.
      steps: How many steps to use when solving the automaton. If not
        provided, uses the number of nodes in the graph. Smaller numbers of
        steps may make it impossible to reach certain nodes in the graph;
        larger numbers may allow additional backtracking and state changes at
        the expense of additional compute time.
      share_states_across_edges: Whether the different edge types share the
        same set of states. If True, any state can transition to any other
        state (even to a different start state). If False, every edge type
        gets a separate set of `1 + num_intermediate_states` states that are
        not shared; in other words, every output edge uses a distinct finite
        state machine.
      backtrack_fails_prob: Backtracking decay factor; determines how often
        the automaton halts when it tries to take the BACKTRACK action. If
        the automaton attempts to backtrack with close to 100% probability,
        this ensures numerical stability and counteracts noise.
      initialization_noise_factor: How much noise to use when initializing
        the automaton policy (see AutomatonBuilder.initialize_routing_params).
      legacy_initialize: Whether to use legacy initialization, which sets the
        log-space softmax weights as Dirichlet random samples (instead of
        setting the softmax output distribution as Dirichlet random samples).
        Defaults to True so that we can reload old configs; new runs should
        use False.
      initialize_smoothing: Controls how much we smooth the initial
        parameters toward a uniform distribution. For small values, this is
        effectively a lower bound of the probability we take each action. If
        zero, we can sample arbitrarily small starting probabilities from the
        Dirichlet distribution. (Specifically, we adjust the sampled
        probabilities as `p_init(x) = (p_sampled(x) + c)/(1 + n*c)` where `n`
        is the number of possible actions and `c = initialize_smoothing`.)
      use_gate_parameterization: Whether to use gate parameterization instead
        of default parameterization. If so, the other initialization args are
        ignored.
      gate_noise: Logistic noise for gate parameterization.
      logit_scaling: One of "none", "learned", "dynamic".
      dynamic_scaling_absolute_shift: For dynamic scaling, how much extra to
        shift logits, in an absolute sense, after shifting for mean
        magnitude.
      dynamic_scaling_relative_shift: For dynamic scaling, how much extra to
        shift logits, relative to standard deviation of magnitude, after
        shifting for mean magnitude.
      estimator_type: One of "solver", "one_sample".
      sampling_max_possible_transitions: Number of possible outgoing
        transitions for any given start node. Used to accelerate the
        sampling process when estimator is "one_sample".

    Returns:
      <float32[num_out_edges, num_nodes, num_nodes]> weighted adjacency
      matrix for `num_out_edges` new edge types.
    """
    del dynamic_metadata
    num_nodes = static_metadata.num_nodes
    # Default: allow as many automaton steps as there are nodes.
    steps = steps if steps is not None else num_nodes
    assert logit_scaling in ("none", "learned", "dynamic")
    # Normalize variant_weights to either the 3-axis (shared across edge
    # types) or 4-axis (per-edge-type) form, and remember which vmap axis (if
    # any) carries the edge-type dimension.
    if variant_weights is None:
      num_variants = 1
      variant_out_edge_axis = None
      variant_weights = jnp.ones((num_nodes, num_nodes, num_variants))
    elif variant_weights.ndim == 3:
      num_variants = variant_weights.shape[-1]
      variant_out_edge_axis = None
      if variant_weights.shape != (num_nodes, num_nodes, num_variants):
        raise ValueError(
            f"variant_weights shape {variant_weights.shape} doesn't match "
            f"expected shape ({num_nodes}, {num_nodes}, <anything>)")
    elif variant_weights.ndim == 4:
      num_variants = variant_weights.shape[-1]
      variant_out_edge_axis = 2
      if variant_weights.shape != (num_nodes, num_nodes, num_out_edges,
                                   num_variants):
        raise ValueError(
            f"variant_weights shape {variant_weights.shape} doesn't match "
            "expected shape"
            f"({num_nodes}, {num_nodes}, {num_out_edges}, <anything>)")
    else:
      raise ValueError(
          f"Invalid variant_weights shape {variant_weights.shape};"
          " expected 3 or 4 axes")

    if share_states_across_edges:
      # All edge types share one FSM: one start state per edge type plus the
      # shared intermediate states.
      num_fsm_states = num_out_edges + num_intermediate_states

      # Initialize or retrieve the automaton parameters; these parameters are
      # shared across all edge types.
      if use_gate_parameterization:

        def shared_routing_initializer(rng_key, unused_shape):
          return builder.initialize_routing_gates(
              key=rng_key,
              num_fsm_states=num_fsm_states,
              num_variants=num_variants,
              logistic_noise=gate_noise)

        routing_gates = self.param(
            "routing_gate_logits_shared",
            shape=None,
            initializer=shared_routing_initializer)
        # Regularization side output pushing gate logits toward 0/1 behavior.
        routing_gates = side_outputs.encourage_discrete_logits(
            routing_gates,
            distribution_type="binary",
            name="routing_gate_logits_shared")
        routing_params = builder.routing_gates_to_probs(routing_gates)
      else:

        def shared_routing_initializer(rng_key, unused_shape):
          routing_probs = builder.initialize_routing_params(
              key=rng_key,
              num_fsm_states=num_fsm_states,
              num_variants=num_variants,
              noise_factor=initialization_noise_factor)
          if legacy_initialize:
            # Legacy: treat the Dirichlet samples directly as log-space
            # softmax weights.
            return routing_probs
          # New behavior: smooth the sampled distribution toward uniform and
          # store its log as the softmax weights.
          return jax.tree_map(
              lambda x: jnp.log(x + initialize_smoothing), routing_probs)

        log_routing_params = self.param(
            "log_routing_params_shared",
            shape=None,
            initializer=shared_routing_initializer)
        routing_params = builder.routing_softmax(log_routing_params)

      # Don't precompute constants if we are tracing an XLA computation; wait
      # until we know a value for our parameters by adding a fake data
      # dependence.
      # NOTE(review): jax.lax.tie_in is a no-op in (and later removed from)
      # newer JAX releases — confirm the pinned JAX version still needs it.
      trigger = jax.tree_leaves(routing_params)[0]
      variant_weights = jax.lax.tie_in(trigger, variant_weights)

      # Build the automaton on the provided graph.
      transition_matrix = builder.build_transition_matrix(
          routing_params, encoded_graph, static_metadata)

      # Each edge type is a start state: one-hot rows, zero-padded over the
      # intermediate states.
      if num_intermediate_states > 0:
        start_machine_states = jnp.concatenate([
            jax.lax.tie_in(trigger, jnp.eye(num_out_edges)),
            jax.lax.tie_in(
                trigger, jnp.zeros((num_out_edges, num_intermediate_states)))
        ], 1)
      else:
        start_machine_states = jax.lax.tie_in(trigger, jnp.eye(num_out_edges))

      start_machine_states = jnp.broadcast_to(
          start_machine_states, (num_nodes, num_out_edges, num_fsm_states))

      # Solve for absorbing distribution for each of the starting states by
      # vmapping across the dimensions that depend on the start state.
      if estimator_type == "solver":
        absorbing_solution = jax_util.vmap_with_kwargs(
            automaton_builder.all_nodes_absorbing_solve,
            variant_weights_axis=variant_out_edge_axis,
            start_machine_states_axis=1)(
                builder=builder,
                transition_matrix=transition_matrix,
                variant_weights=variant_weights,
                start_machine_states=start_machine_states,
                steps=steps,
                backtrack_fails_prob=backtrack_fails_prob)
      elif estimator_type == "one_sample":
        assert sampling_max_possible_transitions is not None
        # Inner vmap: one rollout per start node; outer vmap: per edge type.
        rollout_each_node_fn = jax_util.vmap_with_kwargs(
            automaton_sampling.roll_out_transitions,
            variant_weights_axis=0,
            start_machine_state_axis=0,
            node_index_axis=0,
            rng_axis=0)
        rollout_each_edgetype_fn = jax_util.vmap_with_kwargs(
            rollout_each_node_fn,
            variant_weights_axis=variant_out_edge_axis,
            start_machine_state_axis=1,
            rng_axis=0)
        all_states = rollout_each_edgetype_fn(
            builder=builder,
            transition_matrix=transition_matrix,
            variant_weights=variant_weights,
            start_machine_state=start_machine_states,
            node_index=jnp.arange(num_nodes),
            steps=steps,
            max_possible_transitions=sampling_max_possible_transitions,
            rng=jax.random.split(flax.nn.make_rng(),
                                 num_out_edges * num_nodes).reshape(
                                     [num_out_edges, num_nodes, -1]))

        def set_absorbing(final_node, succeeded):
          # One-hot row marking where the sampled rollout was absorbed (zero
          # everywhere if the rollout failed).
          return jnp.zeros([num_nodes]).at[final_node].set(succeeded)

        # absorbing_solution is [num_out_edges, num_nodes, num_nodes]
        absorbing_solution = jax.vmap(jax.vmap(set_absorbing))(
            all_states.final_node, all_states.succeeded)

        side_outputs.SideOutput(
            all_states.log_prob, name="one_sample_log_prob_per_edge_per_node")
        # Somewhat of a hack: associate the log prob with its own learned
        # baseline as a side output, so it can be trained alongside the rest
        # of the model, but don't do anything with it until the loss function.
        one_sample_reward_baseline = self.param(
            "one_sample_reward_baseline",
            shape=(),
            initializer=jax.nn.initializers.zeros)
        side_outputs.SideOutput(one_sample_reward_baseline,
                                name="one_sample_reward_baseline")
      else:
        raise ValueError(f"Invalid estimator {estimator_type}")

      # Rescale the logits.
      logits = model_util.safe_logit(absorbing_solution)
      if logit_scaling == "learned":
        # Learned scaling and shifting.
        logits = model_util.ScaleAndShift(logits)
      elif logit_scaling == "dynamic":
        # Dynamic scaling only implemented with gates.
        assert use_gate_parameterization
        # First, quantify how discrete the gates are. Conceptually, we want
        # to quantify how far away from zero the logits are, in a
        # differentiable way. To make it smooth, use logsumexp:
        # logaddexp(g, -g) is a smooth approximation of |g|.
        relevant_gates = [
            routing_gates.move_gates, routing_gates.accept_gates
        ]
        soft_abs_logits = [
            jnp.logaddexp(g, -g) for g in relevant_gates
        ]
        # Take a mean and standard deviation over these logits to summarize.
        logit_mean = (sum(jnp.sum(x) for x in soft_abs_logits) /
                      sum(x.size for x in soft_abs_logits))
        logit_var = (sum(
            jnp.sum(jnp.square(x - logit_mean)) for x in soft_abs_logits) /
                     sum(x.size for x in soft_abs_logits))
        logit_std = jnp.sqrt(logit_var)
        side_outputs.SideOutput(logit_mean, name="gate_logit_abs_mean")
        side_outputs.SideOutput(logit_std, name="gate_logit_abs_std")
        # Now, use these to choose an adjustment factor. Intuitively, the
        # "off" gates should be centered around (-logit_mean). So anything
        # that gets sufficiently more mass than that should be "on". We
        # consider two notions of "sufficiently": either relative to the
        # variance in logits, or absolute.
        shift_threshold = (-logit_mean + dynamic_scaling_absolute_shift +
                           dynamic_scaling_relative_shift * logit_std)
        side_outputs.SideOutput(shift_threshold, name="shift_threshold")
        # Adjust so that values at the shift threshold are mapped to edges of
        # weight 0.5
        logits = logits - shift_threshold
    else:
      # Unshared case: each edge type gets its own FSM with a single start
      # state (state 0) plus its own intermediate states.
      num_fsm_states = 1 + num_intermediate_states
      if estimator_type != "solver":
        raise NotImplementedError(
            "Sampling estimators not implemented for unshared states.")
      # Different automaton parameters for each start state
      if use_gate_parameterization:

        def unshared_routing_initializer(rng_key, unused_shape):
          key_per_edge = jax.random.split(rng_key, num_out_edges)
          return jax_util.vmap_with_kwargs(
              builder.initialize_routing_gates,
              key_axis=0)(key=key_per_edge,
                          num_fsm_states=num_fsm_states,
                          num_variants=num_variants,
                          logistic_noise=gate_noise)

        routing_gates = self.param(
            "routing_gate_logits_unshared",
            shape=None,
            initializer=unshared_routing_initializer)
        routing_gates = side_outputs.encourage_discrete_logits(
            routing_gates,
            distribution_type="binary",
            name="routing_gate_logits_unshared")
        stacked_routing_params = jax.vmap(
            builder.routing_gates_to_probs)(routing_gates)
      else:

        def unshared_routing_initializer(rng_key, unused_shape):
          key_per_edge = jax.random.split(rng_key, num_out_edges)
          routing_probs = jax_util.vmap_with_kwargs(
              builder.initialize_routing_params,
              key_axis=0)(key=key_per_edge,
                          num_fsm_states=num_fsm_states,
                          num_variants=num_variants,
                          noise_factor=initialization_noise_factor)
          if legacy_initialize:
            return routing_probs
          return jax.tree_map(
              lambda x: jnp.log(x + initialize_smoothing), routing_probs)

        log_routing_params = self.param(
            "log_routing_params_unshared",
            shape=None,
            initializer=unshared_routing_initializer)
        stacked_routing_params = jax.vmap(
            builder.routing_softmax)(log_routing_params)

      def solve_one(one_edge_routing_params, one_edge_variant_weights):
        """Run one of the edge-specific automata."""
        # Build the automaton on the provided graph.
        transition_matrix = builder.build_transition_matrix(
            one_edge_routing_params, encoded_graph, static_metadata)

        # Start state is always state 0.
        start_machine_states = jnp.broadcast_to(
            (jnp.arange(num_fsm_states) == 0)[None, :],
            (num_nodes, num_fsm_states))

        return automaton_builder.all_nodes_absorbing_solve(
            builder=builder,
            transition_matrix=transition_matrix,
            variant_weights=one_edge_variant_weights,
            start_machine_states=start_machine_states,
            steps=steps,
            backtrack_fails_prob=backtrack_fails_prob)

      # vmap over edge types; variant weights map only when per-edge (axis 2).
      absorbing_solution = jax.vmap(
          solve_one,
          in_axes=(0, variant_out_edge_axis),
      )(stacked_routing_params, variant_weights)

      # Rescale the logits.
      logits = model_util.safe_logit(absorbing_solution)
      if logit_scaling == "learned":
        # Learned scaling and shifting.
        logits = model_util.ScaleAndShift(logits)
      elif logit_scaling == "dynamic":
        raise NotImplementedError(
            "Dynamic scaling not implemented for unshared")

    # Encourage the final edge logits to be near-discrete, then convert to
    # per-edge probabilities.
    logits = side_outputs.encourage_discrete_logits(
        logits, distribution_type="binary", name="edge_logits")
    result = jax.nn.sigmoid(logits)
    return result
def apply( self, graph_context, node_embeddings, edge_embeddings, forward_edge_types=gin.REQUIRED, reverse_edge_types=gin.REQUIRED, walk_length_log2=gin.REQUIRED, ): """Modifies edge embeddings using a uniform random walk. Uses an efficient repeated-squaring technique to compute the absorbing distribution. Args: graph_context: Input graph for this example. node_embeddings: Current node embeddings, as <float32[num_nodes, node_embedding_dim]> edge_embeddings: Current edge embeddings, as <float32[num_nodes, num_nodes, edge_embedding_dim]> forward_edge_types: Edge types to use in the forward direction. As a list of lists to allow configuring groups of edges in config files; this will be flattened before use. reverse_edge_types: Edge types to use in the reverse direction. Note that reversed edge types are given a separate embedding from forward edge types; undirected edges should be represented by adding two edges in opposite directions and then only using `forward_edge_types`. Also a list of lists, as above. walk_length_log2: Base-2 logarithm of maximum walk length; this determines how many times we will square the transition matrix (doubling the walk length). Returns: New node and edge embeddings. Node embeddings will not be modified. Edge embeddings will be modified by adding a new edge type (either embedded or concatenated based on graph_context.edges_are_embedded). 
""" num_nodes = node_embeddings.shape[0] # pylint: disable=g-complex-comprehension forward_edge_type_indices = [ graph_context.edge_types_to_indices[type_str] for group in forward_edge_types for type_str in group ] reverse_edge_type_indices = [ graph_context.edge_types_to_indices[type_str] for group in reverse_edge_types for type_str in group ] # pylint: enable=g-complex-comprehension adjacency = graph_layers.edge_mask( edges=graph_context.bundle.edges, num_nodes=num_nodes, num_edge_types=len(graph_context.edge_types_to_indices), forward_edge_type_indices=forward_edge_type_indices, reverse_edge_type_indices=reverse_edge_type_indices) adjacency = jnp.maximum(adjacency, jnp.eye(num_nodes)) absorbing_logit = self.param( "absorbing_logit", shape=(), initializer=lambda *_: jax.scipy.special.logit(0.1)) absorbing_prob = jax.nn.sigmoid(absorbing_logit) nonabsorbing_prob = jax.nn.sigmoid(-absorbing_logit) walk_matrix = nonabsorbing_prob * adjacency / jnp.sum( adjacency, axis=1, keepdims=True) # A, I # A^2, A + I # (A^2)^2 = A^4, (A + I)A^2 + (A + I) = A^3 + A^2 + A + I # ... def step(state, _): nth_power, nth_partial_sum = state return (nth_power @ nth_power, nth_power @ nth_partial_sum + nth_partial_sum), None (_, partial_sum), _ = jax.lax.scan(step, (walk_matrix, jnp.eye(num_nodes)), None, length=walk_length_log2) approx_visits = absorbing_prob * partial_sum logits = model_util.safe_logit(approx_visits) logits = model_util.ScaleAndShift(logits) edge_weights = jax.nn.sigmoid(logits) return (node_embeddings, _add_edges(edge_embeddings, edge_weights[:, :, None], graph_context.edges_are_embedded))