def extract_outputs_and_targets(
    model,
    padded_example_and_rng,
    target_edge_index,
    num_edge_types,
):
  """Runs `model` on one padded example and builds matching edge targets.

  Args:
    model: Model to run on the example.
    padded_example_and_rng: Pair of (padded example, PRNG key). The key seeds
      the `flax.nn.stochastic` context the model is called under.
    target_edge_index: Index of the edge type treated as the positive target.
    num_edge_types: Total number of edge types.

  Returns:
    Tuple (output_logits, targets, valid_mask, num_nodes, captured), where
    `valid_mask` is the outer product of the per-node "is a real node"
    indicator (as float32) with itself, and `captured` holds the side outputs
    collected while the model ran.
  """
  example, key = padded_example_and_rng

  # Forward pass, recording any side outputs the model emits.
  with side_outputs.collect_side_outputs() as captured:
    with flax.nn.stochastic(key):
      logits = model(example)

  # One-hot selector over edge types: 1 for the target type, 0 for the rest.
  edge_type_selector = (
      jnp.arange(num_edge_types) == target_edge_index).astype("int32")
  # Accumulate the selector into an int array shaped like the logits via the
  # example's edge operator (presumably one entry per edge endpoint pair).
  target_counts = example.edges.apply_add(
      in_array=edge_type_selector,
      out_array=jnp.zeros(logits.shape, dtype="int32"))
  targets = preprocess_targets(target_counts.astype("bool"))

  # A (row, col) position is valid only when both nodes are real, i.e. their
  # index is below the example's true node count (the rest is padding).
  num_nodes = example.graph_metadata.num_nodes
  node_is_real = (jnp.arange(logits.shape[0]) < num_nodes).astype("float32")
  valid_mask = jnp.outer(node_is_real, node_is_real)
  return logits, targets, valid_mask, num_nodes, captured
# Example No. 2
# 0
    def test_collect_side_outputs(self):
        """Checks that initializing the model records the expected outputs."""
        # Initializing simple_add_model with (3, 4) should record exactly
        # these side outputs.
        expected = {"/a": 3, "/b": 4}
        with side_outputs.collect_side_outputs() as collected:
            simple_add_model.init(jax.random.PRNGKey(0), 3, 4)
        self.assertEqual(collected, expected)
# Example No. 3
# 0
    def test_encourage_discrete_logits(self, distribution_type):
        """Checks the entropy penalty and the logit perturbation separately.

        Args:
          distribution_type: Either "binary" or "categorical".

        Raises:
          ValueError: If `distribution_type` is not a supported value.
        """
        if distribution_type == "binary":
            logits = {
                "a": jnp.array([[0., 1.], [2., 3.]]),
                "b": jnp.array(4.),
            }
            # Every logit (0..4 across "a" and "b") is an independent
            # Bernoulli; the expected penalty is their mean entropy.
            p = jax.nn.sigmoid(jnp.arange(5, dtype=jnp.float32))
            expected_entropy = -jnp.mean(p * jnp.log(p) +
                                         (1 - p) * jnp.log(1 - p))
        elif distribution_type == "categorical":
            logits = {
                "a": jnp.array([[0., 1.], [2., 3.]]),
                "b": jnp.array([4., 5.]),
            }
            # Rows [0, 1], [2, 3] (from "a") and [4, 5] (from "b") are three
            # categorical distributions; the expected penalty is their mean
            # entropy.
            p = jax.nn.softmax(
                jnp.arange(6, dtype=jnp.float32).reshape([3, 2]))
            expected_entropy = -jnp.mean(jnp.sum(p * jnp.log(p), axis=-1))
        else:
            # Without this, the code below would fail with an opaque
            # UnboundLocalError on `logits`.
            raise ValueError(
                f"Unsupported distribution_type: {distribution_type!r}")

        # Penalty only.
        with side_outputs.collect_side_outputs() as penalties:
            out, _ = encourage_discrete_model.init(
                jax.random.PRNGKey(0),
                logits,
                distribution_type=distribution_type,
                regularize=True,
                perturb_scale=None)

        # Without perturbation the logits pass through unchanged.
        self.assertTrue(jnp.all(out["a"] == logits["a"]))
        self.assertTrue(jnp.all(out["b"] == logits["b"]))
        np.testing.assert_allclose(penalties["/foo_entropy"],
                                   expected_entropy,
                                   rtol=1e-6)

        # Perturbed only.
        with flax.nn.stochastic(jax.random.PRNGKey(0)):
            out, _ = encourage_discrete_model.init(
                jax.random.PRNGKey(0),
                logits,
                distribution_type=distribution_type,
                regularize=False,
                perturb_scale=1)

        # Should be modified.
        self.assertTrue(jnp.all(out["a"] != logits["a"]))
        self.assertTrue(jnp.all(out["b"] != logits["b"]))
    def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                           use_gate, estimator_type, **kwargs):
        """Checks the automaton layer initializes when its inputs are traced.

        The encoded graph and variant weights are tied to the model input so
        they are abstract during `init`. The test then checks the layer output
        shape and, for the one-sample estimator, its log-prob side output.

        Args:
          shared: Forwarded as `share_states_across_edges`.
          variant_weights: Zero-argument callable returning the variant weights.
          use_gate: Forwarded as `use_gate_parameterization`.
          estimator_type: Forwarded as `estimator_type`.
          **kwargs: Extra keyword arguments forwarded to the layer.
        """
        # Schema with a single node type that has one in-edge and one out-edge.
        schema = {
            graph_types.NodeType("a"):
                graph_types.NodeSchema(
                    in_edges=[graph_types.InEdgeType("ai_0")],
                    out_edges=[graph_types.OutEdgeType("ao_0")]),
        }
        builder = automaton_builder.AutomatonBuilder(schema)

        def empty_operator():
            # Fresh all-zero sparse operator with 128 entries.
            return sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            )

        # All-zero encoded graph. Its array sizes line up with the
        # 32-node / 64-input-tagged-node metadata used below.
        encoded_graph = automaton_builder.EncodedGraph(
            initial_to_in_tagged=empty_operator(),
            initial_to_special=jnp.zeros((32, ), dtype=jnp.int32),
            in_tagged_to_in_tagged=empty_operator(),
            in_tagged_to_special=jnp.zeros((64, ), dtype=jnp.int32),
            in_tagged_node_indices=jnp.zeros((64, ), dtype=jnp.int32),
        )

        def graph_metadata():
            # Fresh metadata object for each use.
            return automaton_builder.EncodedGraphMetadata(
                num_nodes=32, num_input_tagged_nodes=64)

        # Minimal model that pretends the encoded graph and variant weights
        # depend on its input, so the layer sees traced values during init.
        class TestModel(flax.deprecated.nn.Module):

            def apply(self, dummy_ignored):
                def tie(leaf):
                    return jax.lax.tie_in(dummy_ignored, leaf)

                traced_graph = jax.tree_map(tie, encoded_graph)
                traced_weights = jax.tree_map(tie, variant_weights())
                return automaton_layer.FiniteStateGraphAutomaton(
                    encoded_graph=traced_graph,
                    variant_weights=traced_weights,
                    dynamic_metadata=graph_metadata(),
                    static_metadata=graph_metadata(),
                    builder=builder,
                    num_out_edges=3,
                    num_intermediate_states=4,
                    share_states_across_edges=shared,
                    use_gate_parameterization=use_gate,
                    estimator_type=estimator_type,
                    name="the_layer",
                    **kwargs)

        with side_outputs.collect_side_outputs() as side:
            with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
                # NOTE: `init` is used rather than `init_by_shape`; the
                # original author observed init_by_shape breaking the
                # custom_vjp.
                output, _ = TestModel.init(
                    jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

        self.assertEqual(output.shape, (3, 32, 32))

        if estimator_type != "one_sample":
            return
        key = "/the_layer/one_sample_log_prob_per_edge_per_node"
        self.assertIn(key, side)
        self.assertEqual(side[key].shape, (3, 32))
def _single_side_output(collected_side_outputs, suffix):
  """Returns the unique side output whose key ends with `suffix`.

  Args:
    collected_side_outputs: Mapping from side output key to value.
    suffix: Key suffix to match.

  Returns:
    The value of the single matching side output.

  Raises:
    ValueError: If no side output, or more than one, matches `suffix`.
  """
  matches = [
      v for k, v in collected_side_outputs.items() if k.endswith(suffix)
  ]
  if len(matches) != 1:
    raise ValueError(
        f'Expected exactly one side output ending with {suffix!r}, found '
        f'{len(matches)}. Side outputs were '
        f'{set(collected_side_outputs.keys())}')
  return matches[0]


def loss_fn(
    model,
    padded_example_and_rng,
    static_metadata,
    regularization_weights = None,
    reinforce_weight = 1.0,
    baseline_weight = 0.001,
):
  """Loss function for multi-pointer task.

  The model returns a 2-D matrix of joint log-probabilities, with rows indexed
  by bug location node and columns by repair candidate node. Row 0 is the
  "no bug" location (`bug_node_index == 0`). Both row 0 and column 0 of the
  candidate-aggregated probabilities are excluded when scoring buggy
  predictions.

  The loss is the negative log of the probability mass at the true bug row,
  restricted to `repair_node_mask`. Optional terms are added on top:
    * side-output regularization penalties;
    * a REINFORCE term and a baseline penalty, when the model emits
      one-sample estimator side outputs.

  Args:
    model: The model to evaluate.
    padded_example_and_rng: Padded example to evaluate on, with a PRNGKey.
    static_metadata: Padding configuration for the example, since this may vary
      for different examples.
    regularization_weights: Associates side output key regexes with
      regularization penalties.
    reinforce_weight: Weight to give to the reinforce term.
    baseline_weight: Weight to give to the baseline.

  Returns:
    Tuple of loss and metrics, with every metric leaf cast to float32.

  Raises:
    ValueError: If a regularization query matches no side output. Also raised
      if the one-sample estimator outputs are present but there is not exactly
      one log-prob or exactly one baseline side output.
  """
  padded_example, rng = padded_example_and_rng

  # Run the model.
  with side_outputs.collect_side_outputs() as collected_side_outputs:
    with flax.nn.stochastic(rng):
      joint_log_probs = model(padded_example, static_metadata)

  # Computing the loss:
  # Extract logits for the correct location.
  log_probs_at_bug = joint_log_probs[padded_example.bug_node_index, :]
  # Compute p(repair) = sum[ p(node) p(repair | node) ]
  # -> log p(repair) = logsumexp[ log p(node) + log p (repair | node) ]
  # log(0) = -inf masks out non-repair nodes inside the logsumexp.
  log_prob_joint = jax.scipy.special.logsumexp(
      log_probs_at_bug + jnp.log(padded_example.repair_node_mask))

  # Metrics:
  # Marginal log probabilities:
  log_prob_bug = jax.scipy.special.logsumexp(log_probs_at_bug)
  log_prob_repair = jax.scipy.special.logsumexp(
      jax.scipy.special.logsumexp(joint_log_probs, axis=0) +
      jnp.log(padded_example.repair_node_mask))

  # Conditional log probabilities:
  log_prob_repair_given_bug = log_prob_joint - log_prob_bug
  log_prob_bug_given_repair = log_prob_joint - log_prob_repair

  # Majority accuracy (1 if we assign the correct tuple > 50%):
  # (note that this is easier to compute, since we can't currently aggregate
  # probability separately for each candidate.)
  log_half = jnp.log(0.5)
  majority_acc_joint = log_prob_joint > log_half

  # Probabilities associated with each node.
  node_node_probs = jnp.exp(joint_log_probs)
  # Accumulate across unique candidates by identifier. This has the same shape,
  # but only the first few values will be populated.
  node_candidate_probs = padded_example.unique_candidate_operator.apply_add(
      in_array=node_node_probs,
      out_array=jnp.zeros_like(node_node_probs),
      in_dims=[1],
      out_dims=[1])

  # Classify: 50% decision boundary
  only_buggy_probs = node_candidate_probs.at[0, :].set(0).at[:, 0].set(0)
  p_buggy = jnp.sum(only_buggy_probs)
  pred_nobug = p_buggy <= 0.5

  # Localize/repair: take most likely bug position, conditioned on being buggy
  pred_bug_loc, pred_cand_id = jnp.unravel_index(
      jnp.argmax(only_buggy_probs), only_buggy_probs.shape)

  actual_nobug = jnp.array(padded_example.bug_node_index == 0)

  actual_bug = jnp.logical_not(actual_nobug)
  pred_bug = jnp.logical_not(pred_nobug)

  metrics = {
      'nll/joint':
          -log_prob_joint,
      'nll/marginal_bug':
          -log_prob_bug,
      'nll/marginal_repair':
          -log_prob_repair,
      'nll/repair_given_bug':
          -log_prob_repair_given_bug,
      'nll/bug_given_repair':
          -log_prob_bug_given_repair,
      'inaccuracy/legacy_overall':
          1 - majority_acc_joint,
      'inaccuracy/overall':
          (~((actual_nobug & pred_nobug) |
             (actual_bug & pred_bug &
              (pred_bug_loc == padded_example.bug_node_index) &
              (pred_cand_id == padded_example.repair_id)))),
      'inaccuracy/classification_overall': (actual_nobug != pred_nobug),
      'inaccuracy/classification_given_nobug':
          train_util.RatioMetric(
              numerator=(actual_nobug & ~pred_nobug), denominator=actual_nobug),
      'inaccuracy/classification_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug & ~pred_bug), denominator=actual_bug),
      'inaccuracy/localized_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_bug_loc == padded_example.bug_node_index)),
              denominator=actual_bug),
      'inaccuracy/repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_cand_id == padded_example.repair_id)),
              denominator=actual_bug),
      'inaccuracy/localized_repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~((pred_bug_loc == padded_example.bug_node_index) &
                             (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
      'inaccuracy/overall_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_bug &
                             (pred_bug_loc == padded_example.bug_node_index) &
                             (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
  }

  loss = -log_prob_joint

  for k, v in collected_side_outputs.items():
    # Flax collection keys will start with "/". Only scalar side outputs are
    # reported as metrics; jnp.ndim also accepts plain Python numbers, which
    # have no `.shape` attribute.
    if jnp.ndim(v) == 0:
      metrics['side' + k] = v

  if regularization_weights:
    total_regularization = 0
    for query, weight in regularization_weights.items():
      logging.info('Regularizing side outputs matching query %s', query)
      found = False
      for k, v in collected_side_outputs.items():
        if re.search(query, k):
          found = True
          logging.info('Regularizing %s with weight %f', k, weight)
          total_regularization += weight * v
      if not found:
        raise ValueError(
            f'Regularization query {query} did not match any side output. '
            f'Side outputs were {set(collected_side_outputs.keys())}')

    loss = loss + total_regularization

  is_single_sample = any(
      k.endswith('one_sample_log_prob_per_edge_per_node')
      for k in collected_side_outputs)
  if is_single_sample:
    log_prob = _single_side_output(collected_side_outputs,
                                   'one_sample_log_prob_per_edge_per_node')
    baseline = _single_side_output(collected_side_outputs,
                                   'one_sample_reward_baseline')

    # Zero out log-probs at padding node positions (last axis).
    num_real_nodes = padded_example.input_graph.bundle.graph_metadata.num_nodes
    valid_mask = (
        jnp.arange(static_metadata.bundle_padding.static_max_metadata.num_nodes)
        < num_real_nodes)
    log_prob = jnp.where(valid_mask[None, :], log_prob, 0)
    total_log_prob = jnp.sum(log_prob)

    # Score-function (REINFORCE) surrogate. The reward signal is detached so
    # gradients flow only through total_log_prob.
    reinforce_virtual_cost = (
        total_log_prob * jax.lax.stop_gradient(loss - baseline))
    baseline_penalty = jnp.square(loss - baseline)

    # Subtracting the detached value keeps the gradient but contributes zero
    # to the reported loss value.
    reinforce_virtual_cost_zeroed = reinforce_virtual_cost - jax.lax.stop_gradient(
        reinforce_virtual_cost)

    loss = (
        loss + reinforce_weight * reinforce_virtual_cost_zeroed +
        baseline_weight * baseline_penalty)
    metrics['reinforce_virtual_cost'] = reinforce_virtual_cost
    metrics['baseline_penalty'] = baseline_penalty
    metrics['baseline'] = baseline
    metrics['total_log_prob'] = total_log_prob

  # Cast every metric leaf to float32 (RatioMetric fields included, assuming
  # it is a registered pytree as the previous per-leaf `.astype` implied).
  # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias, and
  # jnp.asarray also accepts plain Python scalars.
  metrics = jax.tree_util.tree_map(
      lambda x: jnp.asarray(x, dtype=jnp.float32), metrics)
  return loss, metrics