def extract_outputs_and_targets(
    model,
    padded_example_and_rng,
    target_edge_index,
    num_edge_types,
):
  """Runs the model on one example and builds its targets and validity mask.

  Args:
    model: Callable applied to the padded example to produce output logits.
    padded_example_and_rng: Pair of (padded example, RNG); the RNG seeds the
      stochastic scope the model runs in.
    target_edge_index: Index of the edge type used as the target.
    num_edge_types: Total number of edge types.

  Returns:
    Tuple (output_logits, targets, valid_mask, num_nodes, captured), where
    `targets` is the boolean target array after `preprocess_targets`,
    `valid_mask` is a float32 outer product of the per-node validity vector
    with itself, and `captured` holds the side outputs collected while the
    model ran.
  """
  example, rng = padded_example_and_rng

  # Forward pass, recording any side outputs the model emits.
  with side_outputs.collect_side_outputs() as captured, flax.nn.stochastic(rng):
    logits = model(example)

  # One-hot selector over edge types, added into a zero buffer shaped like the
  # logits; nonzero entries become True targets.
  edge_type_selector = (
      jnp.arange(num_edge_types) == target_edge_index).astype("int32")
  zero_buffer = jnp.zeros(logits.shape, dtype="int32")
  raw_targets = example.edges.apply_add(
      in_array=edge_type_selector, out_array=zero_buffer)
  targets = preprocess_targets(raw_targets.astype("bool"))

  # A node is valid when its index is below the real node count; a pair is
  # valid only when both of its nodes are.
  num_nodes = example.graph_metadata.num_nodes
  padded_size = logits.shape[0]
  is_valid_node = (jnp.arange(padded_size) < num_nodes).astype("float32")
  valid_mask = jnp.einsum("i,j->ij", is_valid_node, is_valid_node)

  return logits, targets, valid_mask, num_nodes, captured
def test_collect_side_outputs(self):
  """Checks the side outputs collected while initializing simple_add_model."""
  expected_penalties = {"/a": 3, "/b": 4}

  with side_outputs.collect_side_outputs() as penalties:
    # Only the collected side outputs matter; the init result is discarded.
    simple_add_model.init(jax.random.PRNGKey(0), 3, 4)

  self.assertEqual(penalties, expected_penalties)
def test_encourage_discrete_logits(self, distribution_type):
  """Checks the entropy penalty and perturbation of encourage_discrete_model.

  Two configurations are exercised:
    * regularize=True, perturb_scale=None: the outputs must equal the input
      logits and the "/foo_entropy" side output must match the expected mean
      entropy.
    * regularize=False, perturb_scale=1: every output element must differ
      from the corresponding input logit.

  Args:
    distribution_type: Either "binary" or "categorical".

  Raises:
    ValueError: If `distribution_type` is not one of the supported values.
  """
  if distribution_type == "binary":
    logits = {
        "a": jnp.array([[0., 1.], [2., 3.]]),
        "b": jnp.array(4.),
    }
    # The five logit values 0..4, each passed through a sigmoid.
    p = jax.nn.sigmoid(jnp.arange(5, dtype=jnp.float32))
    expected_entropy = -jnp.mean(p * jnp.log(p) + (1 - p) * jnp.log(1 - p))
  elif distribution_type == "categorical":
    logits = {
        "a": jnp.array([[0., 1.], [2., 3.]]),
        "b": jnp.array([4., 5.]),
    }
    # The six logit values 0..5 as three rows of two, softmaxed per row.
    p = jax.nn.softmax(jnp.arange(6, dtype=jnp.float32).reshape([3, 2]))
    expected_entropy = -jnp.mean(jnp.sum(p * jnp.log(p), axis=-1))
  else:
    # Fail with a clear message instead of an UnboundLocalError on `logits`.
    raise ValueError(
        f"Unsupported distribution_type: {distribution_type!r}; expected "
        "'binary' or 'categorical'.")

  # Penalty only.
  with side_outputs.collect_side_outputs() as penalties:
    out, _ = encourage_discrete_model.init(
        jax.random.PRNGKey(0),
        logits,
        distribution_type=distribution_type,
        regularize=True,
        perturb_scale=None)

  self.assertTrue(jnp.all(out["a"] == logits["a"]))
  self.assertTrue(jnp.all(out["b"] == logits["b"]))
  np.testing.assert_allclose(
      penalties["/foo_entropy"], expected_entropy, rtol=1e-6)

  # Perturbed only.
  with flax.nn.stochastic(jax.random.PRNGKey(0)):
    out, _ = encourage_discrete_model.init(
        jax.random.PRNGKey(0),
        logits,
        distribution_type=distribution_type,
        regularize=False,
        perturb_scale=1)

  # Should be modified.
  self.assertTrue(jnp.all(out["a"] != logits["a"]))
  self.assertTrue(jnp.all(out["b"] != logits["b"]))
def test_automaton_layer_abstract_init(self, shared, variant_weights, use_gate,
                                       estimator_type, **kwargs):
  """Initializes FiniteStateGraphAutomaton inside a model and checks shapes.

  Args:
    shared: Forwarded to the layer as `share_states_across_edges`.
    variant_weights: Zero-argument callable whose result is forwarded to the
      layer as `variant_weights`.
    use_gate: Forwarded to the layer as `use_gate_parameterization`.
    estimator_type: Forwarded to the layer as `estimator_type`; when it is
      "one_sample", the per-edge, per-node log-prob side output is checked too.
    **kwargs: Extra keyword arguments forwarded to the layer.
  """

  def zero_operator():
    # All-zero sparse operator with 128 entries.
    return sparse_operator.SparseCoordOperator(
        input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
        output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
        values=jnp.zeros((128,), dtype=jnp.float32),
    )

  def zero_indices(size):
    return jnp.zeros((size,), dtype=jnp.int32)

  # Create a simple schema and empty encoded graph.
  node_schema = graph_types.NodeSchema(
      in_edges=[graph_types.InEdgeType("ai_0")],
      out_edges=[graph_types.OutEdgeType("ao_0")])
  schema = {graph_types.NodeType("a"): node_schema}
  builder = automaton_builder.AutomatonBuilder(schema)
  encoded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=zero_operator(),
      initial_to_special=zero_indices(32),
      in_tagged_to_in_tagged=zero_operator(),
      in_tagged_to_special=zero_indices(64),
      in_tagged_node_indices=zero_indices(64),
  )
  metadata_fields = dict(num_nodes=32, num_input_tagged_nodes=64)

  # Make sure the layer can be initialized and applied within a model.
  # This model is fairly simple; it just pretends that the encoded graph and
  # variants depend on the input.
  class TestModel(flax.deprecated.nn.Module):

    def apply(self, dummy_ignored):

      def tie_to_input(leaf):
        return jax.lax.tie_in(dummy_ignored, leaf)

      abstract_encoded_graph = jax.tree_map(tie_to_input, encoded_graph)
      abstract_variant_weights = jax.tree_map(tie_to_input, variant_weights())
      return automaton_layer.FiniteStateGraphAutomaton(
          encoded_graph=abstract_encoded_graph,
          variant_weights=abstract_variant_weights,
          dynamic_metadata=automaton_builder.EncodedGraphMetadata(
              **metadata_fields),
          static_metadata=automaton_builder.EncodedGraphMetadata(
              **metadata_fields),
          builder=builder,
          num_out_edges=3,
          num_intermediate_states=4,
          share_states_across_edges=shared,
          use_gate_parameterization=use_gate,
          estimator_type=estimator_type,
          name="the_layer",
          **kwargs)

  model_rng = jax.random.PRNGKey(0)
  with side_outputs.collect_side_outputs() as side:
    with flax.deprecated.nn.stochastic(model_rng):
      # For some reason init_by_shape breaks the custom_vjp?
      abstract_out, _ = TestModel.init(
          jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

  self.assertEqual(abstract_out.shape, (3, 32, 32))

  if estimator_type == "one_sample":
    log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
    self.assertIn(log_prob_key, side)
    self.assertEqual(side[log_prob_key].shape, (3, 32))
def loss_fn(
    model,
    padded_example_and_rng,
    static_metadata,
    regularization_weights=None,
    reinforce_weight=1.0,
    baseline_weight=0.001,
):
  """Loss function for multi-pointer task.

  The base loss is the negative joint log-probability of the true
  (bug location, repair) pair. Optional regularization penalties taken from
  collected side outputs are added to it. If a side output whose key ends in
  'one_sample_log_prob_per_edge_per_node' is present, a REINFORCE term and a
  baseline penalty are added as well.

  Args:
    model: The model to evaluate.
    padded_example_and_rng: Padded example to evaluate on, with a PRNGKey.
    static_metadata: Padding configuration for the example, since this may vary
      for different examples.
    regularization_weights: Associates side output key regexes with
      regularization penalties.
    reinforce_weight: Weight to give to the reinforce term.
    baseline_weight: Weight to give to the baseline.

  Returns:
    Tuple of loss and metrics. All metric leaves are cast to float32.

  Raises:
    ValueError: If a regularization query matches no side output key.
  """
  padded_example, rng = padded_example_and_rng
  logsumexp = jax.scipy.special.logsumexp

  # Run the model.
  with side_outputs.collect_side_outputs() as side, flax.nn.stochastic(rng):
    joint_log_probs = model(padded_example, static_metadata)

  true_bug_index = padded_example.bug_node_index
  true_repair_id = padded_example.repair_id
  log_repair_mask = jnp.log(padded_example.repair_node_mask)

  # Computing the loss:
  # Extract logits for the correct location.
  log_probs_at_bug = joint_log_probs[true_bug_index, :]
  # Compute p(repair) = sum[ p(node) p(repair | node) ]
  # -> log p(repair) = logsumexp[ log p(node) + log p (repair | node) ]
  log_prob_joint = logsumexp(log_probs_at_bug + log_repair_mask)

  # Metrics:
  # Marginal log probabilities:
  log_prob_bug = logsumexp(log_probs_at_bug)
  log_prob_repair = logsumexp(
      logsumexp(joint_log_probs, axis=0) + log_repair_mask)

  # Conditional log probabilities:
  log_prob_repair_given_bug = log_prob_joint - log_prob_bug
  log_prob_bug_given_repair = log_prob_joint - log_prob_repair

  # Majority accuracy (1 if we assign the correct tuple > 50%):
  # (note that this is easier to compute, since we can't currently aggregate
  # probability separately for each candidate.)
  majority_acc_joint = log_prob_joint > jnp.log(0.5)

  # Probabilities associated with each node.
  node_node_probs = jnp.exp(joint_log_probs)
  # Accumulate across unique candidates by identifier. This has the same
  # shape, but only the first few values will be populated.
  node_candidate_probs = padded_example.unique_candidate_operator.apply_add(
      in_array=node_node_probs,
      out_array=jnp.zeros_like(node_node_probs),
      in_dims=[1],
      out_dims=[1])

  # Classify: 50% decision boundary. Row 0 and column 0 are zeroed so that
  # only the remaining entries count toward the "buggy" probability mass.
  only_buggy_probs = node_candidate_probs.at[0, :].set(0).at[:, 0].set(0)
  p_buggy = jnp.sum(only_buggy_probs)
  pred_nobug = p_buggy <= 0.5

  # Localize/repair: take most likely bug position, conditioned on being buggy
  pred_bug_loc, pred_cand_id = jnp.unravel_index(
      jnp.argmax(only_buggy_probs), only_buggy_probs.shape)

  actual_nobug = jnp.array(true_bug_index == 0)
  actual_bug = jnp.logical_not(actual_nobug)
  pred_bug = jnp.logical_not(pred_nobug)

  location_matches = pred_bug_loc == true_bug_index
  candidate_matches = pred_cand_id == true_repair_id

  def ratio_given_bug(is_wrong):
    # Error ratio restricted to examples that actually contain a bug.
    return train_util.RatioMetric(
        numerator=(actual_bug & is_wrong), denominator=actual_bug)

  metrics = {
      'nll/joint':
          -log_prob_joint,
      'nll/marginal_bug':
          -log_prob_bug,
      'nll/marginal_repair':
          -log_prob_repair,
      'nll/repair_given_bug':
          -log_prob_repair_given_bug,
      'nll/bug_given_repair':
          -log_prob_bug_given_repair,
      'inaccuracy/legacy_overall':
          1 - majority_acc_joint,
      'inaccuracy/overall':
          (~((actual_nobug & pred_nobug) |
             (actual_bug & pred_bug & location_matches & candidate_matches))),
      'inaccuracy/classification_overall': (actual_nobug != pred_nobug),
      'inaccuracy/classification_given_nobug':
          train_util.RatioMetric(
              numerator=(actual_nobug & ~pred_nobug),
              denominator=actual_nobug),
      'inaccuracy/classification_given_bug':
          ratio_given_bug(~pred_bug),
      'inaccuracy/localized_given_bug':
          ratio_given_bug(~location_matches),
      'inaccuracy/repaired_given_bug':
          ratio_given_bug(~candidate_matches),
      'inaccuracy/localized_repaired_given_bug':
          ratio_given_bug(~(location_matches & candidate_matches)),
      'inaccuracy/overall_given_bug':
          ratio_given_bug(~(pred_bug & location_matches & candidate_matches)),
  }

  loss = -log_prob_joint

  # Report scalar side outputs as metrics.
  # Flax collection keys will start with "/".
  metrics.update({
      'side' + key: value
      for key, value in side.items()
      if value.shape == ()  # pylint: disable=g-explicit-bool-comparison
  })

  if regularization_weights:
    total_regularization = 0
    for query, weight in regularization_weights.items():
      logging.info('Regularizing side outputs matching query %s', query)
      matches = [(key, value)
                 for key, value in side.items()
                 if re.search(query, key)]
      if not matches:
        raise ValueError(
            f'Regularization query {query} did not match any side output. '
            f'Side outputs were {set(side.keys())}')
      for key, value in matches:
        logging.info('Regularizing %s with weight %f', key, weight)
        total_regularization += weight * value

    loss = loss + total_regularization

  log_prob_suffix = 'one_sample_log_prob_per_edge_per_node'
  baseline_suffix = 'one_sample_reward_baseline'
  is_single_sample = any(key.endswith(log_prob_suffix) for key in side)
  if is_single_sample:
    # Exactly one side output must match each suffix; the single-element
    # unpacking raises otherwise.
    (log_prob,) = [
        value for key, value in side.items() if key.endswith(log_prob_suffix)
    ]
    (baseline,) = [
        value for key, value in side.items() if key.endswith(baseline_suffix)
    ]

    # Zero out log-probs at padded node positions before summing.
    num_real_nodes = padded_example.input_graph.bundle.graph_metadata.num_nodes
    max_nodes = static_metadata.bundle_padding.static_max_metadata.num_nodes
    valid_mask = jnp.arange(max_nodes) < num_real_nodes
    log_prob = jnp.where(valid_mask[None, :], log_prob, 0)
    total_log_prob = jnp.sum(log_prob)

    advantage = loss - baseline
    reinforce_virtual_cost = total_log_prob * jax.lax.stop_gradient(advantage)
    baseline_penalty = jnp.square(advantage)
    # Subtracting the stop-gradient copy leaves a term whose value is zero but
    # whose gradient is that of the virtual cost.
    reinforce_virtual_cost_zeroed = (
        reinforce_virtual_cost - jax.lax.stop_gradient(reinforce_virtual_cost))

    loss = (
        loss + reinforce_weight * reinforce_virtual_cost_zeroed +
        baseline_weight * baseline_penalty)

    metrics.update({
        'reinforce_virtual_cost': reinforce_virtual_cost,
        'baseline_penalty': baseline_penalty,
        'baseline': baseline,
        'total_log_prob': total_log_prob,
    })

  metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
  return loss, metrics