def mock_model_def(example):
  del example
  side_outputs.SideOutput(
      -jnp.arange(5).astype("float32").reshape((1, 5)),
      name="one_sample_log_prob_per_edge_per_node")
  side_outputs.SideOutput(0.3, name="one_sample_reward_baseline")
  return model_util.safe_logit(
      jnp.array([
          [0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 1.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 1.0, 0.0],
          [0.0, 0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.0, 0.0],
      ]))
def mock_model_def(example, metadata):
  # Check that we get the right inputs.
  self.assertIs(example, mock_example)
  self.assertIs(metadata, mock_metadata)
  # Register a side output.
  side_outputs.SideOutput(jnp.array(.1234), name="test_penalty")
  # Make sure we can generate an rng key with flax.
  _ = flax.deprecated.nn.make_rng()
  return jnp.log(
      jnp.array([
          [.0, .0, .0, .0, .0],
          [.1, .0, .0, .2, .0],
          [.0, .1, .2, .2, .1],  # <- This row is the "correct" bug index.
          [.0, .0, .0, .0, .0],
          [.1, .0, .0, .0, .0],
      ]))
def simple_add_model(a, b):
  side_outputs.SideOutput(a, name="a")
  side_outputs.SideOutput(b, name="b")
  return a + b
def apply(
    self,
    encoded_graph,
    dynamic_metadata,
    static_metadata,
    builder,
    num_out_edges=1,
    num_intermediate_states=0,
    variant_weights=None,
    steps=None,
    share_states_across_edges=True,
    backtrack_fails_prob=0.001,
    initialization_noise_factor=0.01,
    legacy_initialize=True,
    initialize_smoothing=0.001,
    use_gate_parameterization=False,
    gate_noise=0.2,
    logit_scaling="learned",
    dynamic_scaling_absolute_shift=1.0,
    dynamic_scaling_relative_shift=1.0,
    estimator_type="solver",
    sampling_max_possible_transitions=None,
):
  """Apply the graph finite-state automaton layer.

  Args:
    encoded_graph: Graph structure to run the automaton on.
    dynamic_metadata: Metadata about the actual unpadded graph. Unused.
    static_metadata: Statically-known metadata about the graph size. If
      encoded_graph is padded, this should reflect the padded size, not the
      original size.
    builder: Builder object that encoded the graph. Determines the shape of
      the automaton parameters.
    num_out_edges: How many distinct "edge types" to produce. Each of these
      corresponds to a different automaton start state.
    num_intermediate_states: How many intermediate states to include. These
      states are only used inside the automaton, and don't get exposed as
      start states for an edge type. If share_states_across_edges is True,
      all of the edge types will share the same set of intermediate states.
      If share_states_across_edges is False, each edge type will have its
      OWN set of `num_intermediate_states` states.
    variant_weights: Optional <float32[num_nodes, num_nodes, num_variants]>
      or <float32[num_nodes, num_nodes, num_out_edges, num_variants]> array
      with nonnegative elements that sum to 1 along the last axis. If
      provided, variant_weights[i, j, (e,) v] specifies how much policy
      variant v should be used (for edge type e) when starting from node i
      and arriving at intermediate node j. Variants correspond to the
      start-node-conditioned observations described in Appendix C.2.
    steps: How many steps to use when solving the automaton. If not
      provided, uses the number of nodes in the graph. Smaller numbers of
      steps may make it impossible to reach certain nodes in the graph;
      larger numbers may allow additional backtracking and state changes at
      the expense of additional compute time.
    share_states_across_edges: Whether the different edge types share the
      same set of states. If True, any state can transition to any other
      state (even to a different start state). If False, every edge type
      gets a separate set of `1 + num_intermediate_states` states that are
      not shared; in other words, every output edge uses a distinct
      finite-state machine.
    backtrack_fails_prob: Backtracking decay factor; determines how often
      the automaton halts when it tries to take the BACKTRACK action. This
      ensures numerical stability and counteracts noise if the automaton
      attempts to backtrack with close to 100% probability.
    initialization_noise_factor: How much noise to use when initializing the
      automaton policy (see AutomatonBuilder.initialize_routing_params).
    legacy_initialize: Whether to use legacy initialization, which sets the
      log-space softmax weights as Dirichlet random samples (instead of
      setting the softmax output distribution as Dirichlet random samples).
      Defaults to True so that we can reload old configs; new runs should
      use False.
    initialize_smoothing: Controls how much we smooth the initial parameters
      toward a uniform distribution. For small values, this is effectively a
      lower bound on the probability of taking each action.
      If zero, we can sample arbitrarily small starting probabilities from
      the Dirichlet distribution. (Specifically, we adjust the sampled
      probabilities as `p_init(x) = (p_sampled(x) + c) / (1 + n*c)`, where
      `n` is the number of possible actions and `c = initialize_smoothing`.)
    use_gate_parameterization: Whether to use gate parameterization instead
      of the default parameterization. If so, the other initialization args
      are ignored.
    gate_noise: Logistic noise for gate parameterization.
    logit_scaling: One of "none", "learned", "dynamic".
    dynamic_scaling_absolute_shift: For dynamic scaling, how much extra to
      shift logits, in an absolute sense, after shifting for mean magnitude.
    dynamic_scaling_relative_shift: For dynamic scaling, how much extra to
      shift logits, relative to the standard deviation of the magnitude,
      after shifting for mean magnitude.
    estimator_type: One of "solver", "one_sample".
    sampling_max_possible_transitions: Number of possible outgoing
      transitions for any given start node. Used to accelerate the sampling
      process when the estimator is "one_sample".

  Returns:
    <float32[num_out_edges, num_nodes, num_nodes]> weighted adjacency matrix
    for `num_out_edges` new edge types.
  """
  del dynamic_metadata
  num_nodes = static_metadata.num_nodes
  steps = steps if steps is not None else num_nodes

  assert logit_scaling in ("none", "learned", "dynamic")

  if variant_weights is None:
    num_variants = 1
    variant_out_edge_axis = None
    variant_weights = jnp.ones((num_nodes, num_nodes, num_variants))
  elif variant_weights.ndim == 3:
    num_variants = variant_weights.shape[-1]
    variant_out_edge_axis = None
    if variant_weights.shape != (num_nodes, num_nodes, num_variants):
      raise ValueError(
          f"variant_weights shape {variant_weights.shape} doesn't match "
          f"expected shape ({num_nodes}, {num_nodes}, <anything>)")
  elif variant_weights.ndim == 4:
    num_variants = variant_weights.shape[-1]
    variant_out_edge_axis = 2
    if variant_weights.shape != (num_nodes, num_nodes, num_out_edges,
                                 num_variants):
      raise ValueError(
          f"variant_weights shape {variant_weights.shape} doesn't match "
          "expected shape "
          f"({num_nodes}, {num_nodes}, {num_out_edges}, <anything>)")
  else:
    raise ValueError(
        f"Invalid variant_weights shape {variant_weights.shape};"
        " expected 3 or 4 axes")

  if share_states_across_edges:
    num_fsm_states = num_out_edges + num_intermediate_states

    # Initialize or retrieve the automaton parameters; these parameters are
    # shared across all edge types.
    if use_gate_parameterization:

      def shared_routing_initializer(rng_key, unused_shape):
        return builder.initialize_routing_gates(
            key=rng_key,
            num_fsm_states=num_fsm_states,
            num_variants=num_variants,
            logistic_noise=gate_noise)

      routing_gates = self.param(
          "routing_gate_logits_shared",
          shape=None,
          initializer=shared_routing_initializer)
      routing_gates = side_outputs.encourage_discrete_logits(
          routing_gates,
          distribution_type="binary",
          name="routing_gate_logits_shared")
      routing_params = builder.routing_gates_to_probs(routing_gates)
    else:

      def shared_routing_initializer(rng_key, unused_shape):
        routing_probs = builder.initialize_routing_params(
            key=rng_key,
            num_fsm_states=num_fsm_states,
            num_variants=num_variants,
            noise_factor=initialization_noise_factor)
        if legacy_initialize:
          return routing_probs
        return jax.tree_map(
            lambda x: jnp.log(x + initialize_smoothing), routing_probs)

      log_routing_params = self.param(
          "log_routing_params_shared",
          shape=None,
          initializer=shared_routing_initializer)
      routing_params = builder.routing_softmax(log_routing_params)

    # Don't precompute constants if we are tracing an XLA computation; wait
    # until we know a value for our parameters by adding a fake data
    # dependence.
    trigger = jax.tree_leaves(routing_params)[0]
    variant_weights = jax.lax.tie_in(trigger, variant_weights)

    # Build the automaton on the provided graph.
    transition_matrix = builder.build_transition_matrix(
        routing_params, encoded_graph, static_metadata)

    # Each edge type is a start state.
    if num_intermediate_states > 0:
      start_machine_states = jnp.concatenate([
          jax.lax.tie_in(trigger, jnp.eye(num_out_edges)),
          jax.lax.tie_in(
              trigger, jnp.zeros((num_out_edges, num_intermediate_states)))
      ], 1)
    else:
      start_machine_states = jax.lax.tie_in(trigger, jnp.eye(num_out_edges))

    start_machine_states = jnp.broadcast_to(
        start_machine_states, (num_nodes, num_out_edges, num_fsm_states))

    # Solve for the absorbing distribution for each of the starting states by
    # vmapping across the dimensions that depend on the start state.
    if estimator_type == "solver":
      absorbing_solution = jax_util.vmap_with_kwargs(
          automaton_builder.all_nodes_absorbing_solve,
          variant_weights_axis=variant_out_edge_axis,
          start_machine_states_axis=1)(
              builder=builder,
              transition_matrix=transition_matrix,
              variant_weights=variant_weights,
              start_machine_states=start_machine_states,
              steps=steps,
              backtrack_fails_prob=backtrack_fails_prob)
    elif estimator_type == "one_sample":
      assert sampling_max_possible_transitions is not None
      rollout_each_node_fn = jax_util.vmap_with_kwargs(
          automaton_sampling.roll_out_transitions,
          variant_weights_axis=0,
          start_machine_state_axis=0,
          node_index_axis=0,
          rng_axis=0)
      rollout_each_edgetype_fn = jax_util.vmap_with_kwargs(
          rollout_each_node_fn,
          variant_weights_axis=variant_out_edge_axis,
          start_machine_state_axis=1,
          rng_axis=0)
      all_states = rollout_each_edgetype_fn(
          builder=builder,
          transition_matrix=transition_matrix,
          variant_weights=variant_weights,
          start_machine_state=start_machine_states,
          node_index=jnp.arange(num_nodes),
          steps=steps,
          max_possible_transitions=sampling_max_possible_transitions,
          rng=jax.random.split(flax.nn.make_rng(),
                               num_out_edges * num_nodes).reshape(
                                   [num_out_edges, num_nodes, -1]))

      def set_absorbing(final_node, succeeded):
        return jnp.zeros([num_nodes]).at[final_node].set(succeeded)

      # absorbing_solution is [num_out_edges, num_nodes, num_nodes].
      absorbing_solution = jax.vmap(jax.vmap(set_absorbing))(
          all_states.final_node, all_states.succeeded)

      side_outputs.SideOutput(
          all_states.log_prob, name="one_sample_log_prob_per_edge_per_node")

      # Somewhat of a hack: associate the log prob with its own learned
      # baseline as a side output, so it can be trained alongside the rest
      # of the model, but don't do anything with it until the loss function.
      one_sample_reward_baseline = self.param(
          "one_sample_reward_baseline",
          shape=(),
          initializer=jax.nn.initializers.zeros)
      side_outputs.SideOutput(
          one_sample_reward_baseline, name="one_sample_reward_baseline")
    else:
      raise ValueError(f"Invalid estimator {estimator_type}")

    # Rescale the logits.
    logits = model_util.safe_logit(absorbing_solution)

    if logit_scaling == "learned":
      # Learned scaling and shifting.
      logits = model_util.ScaleAndShift(logits)
    elif logit_scaling == "dynamic":
      # Dynamic scaling is only implemented with gates.
      assert use_gate_parameterization
      # First, quantify how discrete the gates are. Conceptually, we want to
      # quantify how far away from zero the logits are, in a differentiable
      # way. To make it smooth, use logsumexp:
      relevant_gates = [routing_gates.move_gates, routing_gates.accept_gates]
      soft_abs_logits = [jnp.logaddexp(g, -g) for g in relevant_gates]
      # Take a mean and standard deviation over these logits to summarize.
      logit_mean = (sum(jnp.sum(x) for x in soft_abs_logits) /
                    sum(x.size for x in soft_abs_logits))
      logit_var = (sum(jnp.sum(jnp.square(x - logit_mean))
                       for x in soft_abs_logits) /
                   sum(x.size for x in soft_abs_logits))
      logit_std = jnp.sqrt(logit_var)
      side_outputs.SideOutput(logit_mean, name="gate_logit_abs_mean")
      side_outputs.SideOutput(logit_std, name="gate_logit_abs_std")
      # Now, use these to choose an adjustment factor. Intuitively, the "off"
      # gates should be centered around (-logit_mean), so anything that gets
      # sufficiently more mass than that should be "on". We consider two
      # notions of "sufficiently": either relative to the variance in logits,
      # or absolute.
      shift_threshold = (-logit_mean + dynamic_scaling_absolute_shift +
                         dynamic_scaling_relative_shift * logit_std)
      side_outputs.SideOutput(shift_threshold, name="shift_threshold")
      # Adjust so that values at the shift threshold are mapped to edges of
      # weight 0.5.
      logits = logits - shift_threshold

  else:
    num_fsm_states = 1 + num_intermediate_states

    if estimator_type != "solver":
      raise NotImplementedError(
          "Sampling estimators not implemented for unshared states.")

    # Different automaton parameters for each start state.
    if use_gate_parameterization:

      def unshared_routing_initializer(rng_key, unused_shape):
        key_per_edge = jax.random.split(rng_key, num_out_edges)
        return jax_util.vmap_with_kwargs(
            builder.initialize_routing_gates, key_axis=0)(
                key=key_per_edge,
                num_fsm_states=num_fsm_states,
                num_variants=num_variants,
                logistic_noise=gate_noise)

      routing_gates = self.param(
          "routing_gate_logits_unshared",
          shape=None,
          initializer=unshared_routing_initializer)
      routing_gates = side_outputs.encourage_discrete_logits(
          routing_gates,
          distribution_type="binary",
          name="routing_gate_logits_unshared")
      stacked_routing_params = jax.vmap(
          builder.routing_gates_to_probs)(routing_gates)
    else:

      def unshared_routing_initializer(rng_key, unused_shape):
        key_per_edge = jax.random.split(rng_key, num_out_edges)
        routing_probs = jax_util.vmap_with_kwargs(
            builder.initialize_routing_params, key_axis=0)(
                key=key_per_edge,
                num_fsm_states=num_fsm_states,
                num_variants=num_variants,
                noise_factor=initialization_noise_factor)
        if legacy_initialize:
          return routing_probs
        return jax.tree_map(
            lambda x: jnp.log(x + initialize_smoothing), routing_probs)

      log_routing_params = self.param(
          "log_routing_params_unshared",
          shape=None,
          initializer=unshared_routing_initializer)
      stacked_routing_params = jax.vmap(
          builder.routing_softmax)(log_routing_params)

    def solve_one(one_edge_routing_params, one_edge_variant_weights):
      """Run one of the edge-specific automata."""
      # Build the automaton on the provided graph.
      transition_matrix = builder.build_transition_matrix(
          one_edge_routing_params, encoded_graph, static_metadata)

      # Start state is always state 0.
      start_machine_states = jnp.broadcast_to(
          (jnp.arange(num_fsm_states) == 0)[None, :],
          (num_nodes, num_fsm_states))

      return automaton_builder.all_nodes_absorbing_solve(
          builder=builder,
          transition_matrix=transition_matrix,
          variant_weights=one_edge_variant_weights,
          start_machine_states=start_machine_states,
          steps=steps,
          backtrack_fails_prob=backtrack_fails_prob)

    absorbing_solution = jax.vmap(
        solve_one,
        in_axes=(0, variant_out_edge_axis),
    )(stacked_routing_params, variant_weights)

    # Rescale the logits.
    logits = model_util.safe_logit(absorbing_solution)

    if logit_scaling == "learned":
      # Learned scaling and shifting.
      logits = model_util.ScaleAndShift(logits)
    elif logit_scaling == "dynamic":
      raise NotImplementedError(
          "Dynamic scaling not implemented for unshared")

  logits = side_outputs.encourage_discrete_logits(
      logits, distribution_type="binary", name="edge_logits")
  result = jax.nn.sigmoid(logits)
  return result
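# --- Illustrative sketch (not part of the layer above) ---
# The `initialize_smoothing` docstring describes the adjustment
# p_init(x) = (p_sampled(x) + c) / (1 + n*c). The snippet below checks that
# formula in isolation with plain JAX; `smooth_toward_uniform` is a
# hypothetical helper introduced only for this example.
import jax
import jax.numpy as jnp


def smooth_toward_uniform(p_sampled, c):
  """Smooths a sampled distribution toward uniform along the last axis."""
  n = p_sampled.shape[-1]
  return (p_sampled + c) / (1 + n * c)


key = jax.random.PRNGKey(0)
# Sample an initial distribution over 4 actions from a symmetric Dirichlet.
p_sampled = jax.random.dirichlet(key, jnp.ones(4))
p_init = smooth_toward_uniform(p_sampled, c=0.001)
# Every entry of p_init is at least c / (1 + n*c) ~= 0.000996, so no action
# starts with an arbitrarily small probability.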
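# --- Illustrative sketch (not part of the layer above) ---
# The dynamic-scaling branch measures how discrete the gates are with a
# "soft absolute value": logaddexp(g, -g) = log(exp(g) + exp(-g)) is smooth
# at zero and approaches |g| for large |g|, so its mean and standard
# deviation summarize logit magnitudes differentiably. Values shown in the
# comment below are approximate.
import jax.numpy as jnp

gates = jnp.array([-6.0, -0.5, 0.0, 0.5, 6.0])
soft_abs = jnp.logaddexp(gates, -gates)
# soft_abs ~= [6.0, 0.81, 0.69, 0.81, 6.0], compared to
# jnp.abs(gates) == [6.0, 0.5, 0.0, 0.5, 6.0].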