def zeros_like_padded_example(config):
    """Builds a VarMisuseExample containing only zeros.

  This can be useful to initialize model parameters, or do tests.

  Args:
    config: Configuration specifying the desired padding size.

  Returns:
    An "example" filled with zeros of the given size.
  """
    return VarMisuseExample(
        input_graph=GraphBundleWithTokens(
            bundle=graph_bundle.zeros_like_padded_example(
                config.bundle_padding),
            tokens=sparse_operator.SparseCoordOperator(
                input_indices=np.zeros(shape=(config.max_tokens, 1),
                                       dtype=np.int32),
                output_indices=np.zeros(shape=(config.max_tokens, 1),
                                        dtype=np.int32),
                values=np.zeros(shape=(config.max_tokens, ), dtype=np.int32))),
        bug_node_index=-1,
        repair_node_mask=np.zeros(
            shape=(config.bundle_padding.static_max_metadata.num_nodes, ),
            dtype=np.float32),
        candidate_node_mask=np.zeros(
            shape=(config.bundle_padding.static_max_metadata.num_nodes, ),
            dtype=np.float32),
        unique_candidate_operator=sparse_operator.SparseCoordOperator(
            input_indices=np.zeros(shape=(config.max_tokens, 1),
                                   dtype=np.int32),
            output_indices=np.zeros(shape=(config.max_tokens, 1),
                                    dtype=np.int32),
            values=np.zeros(shape=(config.max_tokens, ), dtype=np.float32)),
        repair_id=0)
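
A minimal usage sketch follows (hypothetical: the real config class is not shown on this page, so this assumes only the two attributes the function reads, `config.bundle_padding` and `config.max_tokens`; `PaddingConfig` and `EncodedGraphMetadata` are constructed exactly as in the tests further down):

import types

# Hypothetical stand-in config; the library's real config class may differ.
config = types.SimpleNamespace(
    bundle_padding=graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=5, num_input_tagged_nodes=0),
        max_initial_transitions=0,
        max_in_tagged_transitions=0,
        max_edges=8),
    max_tokens=16)
example = zeros_like_padded_example(config)
assert example.repair_node_mask.shape == (5,)  # All buffers are zero-filled.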
def convert_graph_with_edges(
    graph,
    edges,
    builder,
):
  """Convert a graph with edges into an GraphBundle.

  The order of nodes in the returned example is guaranteed to match the
  order of the keys in `graph`.

  Args:
    graph: Graph to encode.
    edges: List of (source, dest, edge_type) triples to add to the non-automaton
      graph representation (i.e. GNN edges or targets).
    builder: Builder to use to convert the graph.

  Returns:
    Encoded example.
  """
  # Encode the graph.
  encoded_graph, encoded_metadata = builder.encode_graph(graph, as_jax=False)

  # Look up node types.
  node_types = []
  for node_info in graph.values():
    node_types.append(builder.node_type_to_index[node_info.node_type])
  node_types = np.array(node_types, dtype=np.int32)

  # Build the indices for the edges.
  if edges:
    src_dest_pairs = []
    edge_types = []
    id_to_index_map = {node_id: i for i, node_id in enumerate(graph)}
    for src_id, dest_id, edge_type in edges:
      src_idx = id_to_index_map[src_id]
      dest_idx = id_to_index_map[dest_id]
      src_dest_pairs.append((src_idx, dest_idx))
      edge_types.append(edge_type)

    edge_operator = sparse_operator.SparseCoordOperator(
        input_indices=np.array(edge_types, dtype=np.int32).reshape([-1, 1]),
        output_indices=np.array(src_dest_pairs, dtype=np.int32),
        values=np.ones([len(edges)], dtype=np.int32),
    )
  else:
    # Handle case where there are no edges.
    edge_operator = sparse_operator.SparseCoordOperator(
        input_indices=np.empty([0, 1], dtype=np.int32),
        output_indices=np.empty([0, 2], dtype=np.int32),
        values=np.empty([0], dtype=np.int32),
    )

  return GraphBundle(
      automaton_graph=encoded_graph,
      graph_metadata=encoded_metadata,
      node_types=node_types,
      edges=edge_operator)
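
To see what the returned `edges` operator encodes, it can be densified into a per-edge-type adjacency tensor with `apply_add`, following the same in_dims/out_dims convention used by `LearnableEdgeEmbeddings` further down. A small sketch with hypothetical toy numbers:

# Toy graph: 3 nodes, edges (0 -> 1) of type 2 and (1 -> 2) of type 0.
edge_operator = sparse_operator.SparseCoordOperator(
    input_indices=np.array([[2], [0]], dtype=np.int32),
    output_indices=np.array([[0, 1], [1, 2]], dtype=np.int32),
    values=np.ones([2], dtype=np.int32))
# The edge-type index selects a row of the one-hot input; the (src, dst)
# output indices select dims 1 and 2 of the output.
dense = edge_operator.apply_add(
    in_array=jnp.eye(3),             # one-hot over the 3 edge types
    out_array=jnp.zeros([3, 3, 3]),  # [edge_type, src, dst]
    in_dims=[0],
    out_dims=[1, 2])
# dense[2, 0, 1] == 1 and dense[0, 1, 2] == 1; all other entries are 0.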
def zeros_like_padded_example(config):
    """Build an GraphBundle containing only zeros.

  This can be useful to initialize model parameters, or do tests.

  Args:
    config: Configuration specifying the desired padding size.

  Returns:
    An "example" filled with zeros of the given size.
  """
    return GraphBundle(
        automaton_graph=automaton_builder.EncodedGraph(
            initial_to_in_tagged=sparse_operator.SparseCoordOperator(
                input_indices=np.zeros(shape=(config.max_initial_transitions,
                                              1),
                                       dtype=np.int32),
                output_indices=np.zeros(shape=(config.max_initial_transitions,
                                               2),
                                        dtype=np.int32),
                values=np.zeros(shape=(config.max_initial_transitions, ),
                                dtype=np.float32),
            ),
            initial_to_special=np.zeros(
                shape=(config.static_max_metadata.num_nodes, ),
                dtype=np.int32),
            in_tagged_to_in_tagged=sparse_operator.SparseCoordOperator(
                input_indices=np.zeros(shape=(config.max_in_tagged_transitions,
                                              1),
                                       dtype=np.int32),
                output_indices=np.zeros(
                    shape=(config.max_in_tagged_transitions, 2),
                    dtype=np.int32),
                values=np.zeros(shape=(config.max_in_tagged_transitions, ),
                                dtype=np.float32),
            ),
            in_tagged_to_special=np.zeros(
                shape=(config.static_max_metadata.num_input_tagged_nodes, ),
                dtype=np.int32),
            in_tagged_node_indices=np.zeros(
                shape=(config.static_max_metadata.num_input_tagged_nodes, ),
                dtype=np.int32),
        ),
        graph_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=0, num_input_tagged_nodes=0),
        node_types=np.zeros(shape=(config.static_max_metadata.num_nodes, ),
                            dtype=np.int32),
        edges=sparse_operator.SparseCoordOperator(
            input_indices=np.zeros(shape=(config.max_edges, 1),
                                   dtype=np.int32),
            output_indices=np.zeros(shape=(config.max_edges, 2),
                                    dtype=np.int32),
            values=np.zeros(shape=(config.max_edges, ), dtype=np.int32),
        ),
    )
  def test_variants_from_edges(self):
    example = graph_bundle.zeros_like_padded_example(
        graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=5, num_input_tagged_nodes=0),
            max_initial_transitions=0,
            max_in_tagged_transitions=0,
            max_edges=8))
    example = dataclasses.replace(
        example,
        graph_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=4, num_input_tagged_nodes=0),
        edges=sparse_operator.SparseCoordOperator(
            input_indices=jnp.array([[0], [0], [0], [1], [1], [2], [0], [0]]),
            output_indices=jnp.array([[1, 2], [2, 3], [3, 0], [2, 0], [0, 2],
                                      [0, 3], [0, 0], [0, 0]]),
            values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0])))

    weights = edge_supervision_models.variants_from_edges(
        example,
        automaton_builder.EncodedGraphMetadata(
            num_nodes=5, num_input_tagged_nodes=0),
        variant_edge_type_indices=[2, 0],
        num_edge_types=3)
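    # Each (source, dest) slot of `expected` is one-hot over 3 variants:
    # index 0 means no variant edge, and an edge whose type appears in
    # variant_edge_type_indices=[2, 0] selects variant 1 + its position in
    # that list (type 2 -> variant 1, type 0 -> variant 2).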
    expected = np.array([
        [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0]],
        [[1, 0, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]],
        [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
        [[0, 0, 1], [1, 0, 0], [1, 0, 0], [1, 0, 0]],
    ], np.float32)
    # Only assert on the non-padded part.
    np.testing.assert_allclose(weights[:4, :4], expected)
  def test_sparse_coord_operator_is_a_pytree(self):
    """Tests that jax tree operations work on SparseCoordOperators."""
    op_with_zeros = jax.tree_map(jnp.zeros_like, self.operator)
    expected = sparse_operator.SparseCoordOperator(
        input_indices=jnp.zeros([5, 2], dtype=int),
        output_indices=jnp.zeros([5, 1], dtype=int),
        values=jnp.zeros([5], dtype=jnp.float32),
    )
    jax.tree_multimap(np.testing.assert_allclose, op_with_zeros, expected)

  def _setup_edges(self):
    """Set up an edge operator with type indices."""
    edges = sparse_operator.SparseCoordOperator(
        input_indices=jnp.array([[0], [0], [1], [2], [0], [1], [0]]),
        output_indices=jnp.array([[0, 1], [1, 2], [2, 3], [2, 0], [0, 2],
                                  [1, 3], [0, 0]]),
        values=jnp.array([1, 1, 1, 1, 1, 1, 0]))
    forward_edge_type_indices = [2, 0]
    reverse_edge_type_indices = [0]
    return edges, forward_edge_type_indices, reverse_edge_type_indices

  def test_bad_args(self, in_shape, out_shape, in_dims, out_dims,
                    expected_error):
    operator = sparse_operator.SparseCoordOperator(
        input_indices=jnp.zeros([5, 2], dtype=int),
        output_indices=jnp.zeros([5, 2], dtype=int),
        values=jnp.zeros([5], dtype=jnp.float32),
    )
    with self.assertRaisesRegex(ValueError, re.escape(expected_error)):
      operator.apply_add(jnp.zeros(in_shape), jnp.zeros(out_shape),
                         in_dims, out_dims)

  def test_pad_nonzeros(self):
    operator = sparse_operator.SparseCoordOperator(
        input_indices=jnp.arange(10).reshape([5, 2]),
        output_indices=jnp.arange(5).reshape([5, 1]),
        values=jnp.arange(5, dtype=jnp.float32),
    )
    padded_operator = operator.pad_nonzeros(7)
    apply_orig = operator.apply_add(
        jnp.arange(100).reshape([10, 10]), jnp.zeros([5]))
    apply_padded = padded_operator.apply_add(
        jnp.arange(100).reshape([10, 10]), jnp.zeros([5]))
    np.testing.assert_allclose(apply_orig, apply_padded)

  def setUp(self):
    super(SparseOperatorTest, self).setUp()
    self.operator = sparse_operator.SparseCoordOperator(
        input_indices=jnp.array([
            [0, 0],
            [1, 0],
            [1, 1],
            [1, 0],
            [1, 1],
        ]),
        output_indices=jnp.array([[0], [1], [2], [3], [2]]),
        values=jnp.array([1., 10., 100., 1000., 10000.]),
    )
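
The scatter-add semantics these tests rely on can be summarized in plain NumPy. Here is a standalone sketch for the operator defined in setUp above (the explicit loop illustrates the semantics the tests imply; it is not the library's vectorized implementation):

import numpy as np

input_indices = np.array([[0, 0], [1, 0], [1, 1], [1, 0], [1, 1]])
output_indices = np.array([[0], [1], [2], [3], [2]])
values = np.array([1., 10., 100., 1000., 10000.])

in_array = np.array([[1., 2.], [3., 4.]])
out_array = np.zeros([4])
for k in range(len(values)):
  # Entry k adds values[k] * in_array[input_indices[k]] at out_array[output_indices[k]].
  out_array[tuple(output_indices[k])] += values[k] * in_array[tuple(input_indices[k])]
# out_array is now [1., 30., 40400., 3000.]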
    def test_TokenOperatorNodeEmbedding_shapes(self, bottleneck_dim):
        outs, _ = graph_layers.TokenOperatorNodeEmbedding.init(
            jax.random.PRNGKey(0),
            operator=sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((NUM_TOKENS, 1), jnp.int32),
                output_indices=jnp.zeros((NUM_TOKENS, 1), jnp.int32),
                values=jnp.zeros((NUM_TOKENS, ), jnp.int32)),
            vocab_size=VOCAB_SIZE,
            num_nodes=NUM_NODES,
            embedding_dim=NODE_EMBEDDING_DIM,
            bottleneck_dim=bottleneck_dim)

        expected = jax.ShapeDtypeStruct((NUM_NODES, NODE_EMBEDDING_DIM),
                                        jnp.float32)
        self._check_shape_and_dtype(outs, expected)
  def test_LearnableEdgeEmbeddings_shapes(self):
    outs, _ = graph_layers.LearnableEdgeEmbeddings.init(
        jax.random.PRNGKey(0),
        edges=sparse_operator.SparseCoordOperator(
            input_indices=jnp.zeros((NUM_EDGES, 1), jnp.int32),
            output_indices=jnp.zeros((NUM_EDGES, 2), jnp.int32),
            values=jnp.zeros((NUM_EDGES,), jnp.int32)),
        num_nodes=NUM_NODES,
        num_edge_types=NUM_EDGE_TYPES,
        forward_edge_type_indices=[0, 2],
        reverse_edge_type_indices=[3, 1],
        embedding_dim=EDGE_EMBEDDING_DIM)

    expected = jax.ShapeDtypeStruct((NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM),
                                    jnp.float32)
    self._check_shape_and_dtype(outs, expected)
  def _make_example(self):
    example = graph_bundle.zeros_like_padded_example(
        graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=5, num_input_tagged_nodes=0),
            max_initial_transitions=0,
            max_in_tagged_transitions=0,
            max_edges=8))
    example = dataclasses.replace(
        example,
        graph_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=4, num_input_tagged_nodes=0),
        edges=sparse_operator.SparseCoordOperator(
            input_indices=jnp.array([[0], [0], [0], [0], [1], [2], [2], [0]]),
            output_indices=jnp.array([[1, 2], [2, 3], [2, 2], [3, 0], [0, 2],
                                      [0, 3], [0, 0], [0, 0]]),
            values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0])))
    return example
    def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                           use_gate, estimator_type, **kwargs):
        # Create a simple schema and empty encoded graph.
        schema = {
            graph_types.NodeType("a"):
                graph_types.NodeSchema(
                    in_edges=[graph_types.InEdgeType("ai_0")],
                    out_edges=[graph_types.OutEdgeType("ao_0")]),
        }
        builder = automaton_builder.AutomatonBuilder(schema)
        encoded_graph = automaton_builder.EncodedGraph(
            initial_to_in_tagged=sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            ),
            initial_to_special=jnp.zeros((32, ), dtype=jnp.int32),
            in_tagged_to_in_tagged=sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            ),
            in_tagged_to_special=jnp.zeros((64, ), dtype=jnp.int32),
            in_tagged_node_indices=jnp.zeros((64, ), dtype=jnp.int32),
        )

        # Make sure the layer can be initialized and applied within a model.
        # This model is fairly simple; it just pretends that the encoded graph and
        # variants depend on the input.
        class TestModel(flax.deprecated.nn.Module):
            def apply(self, dummy_ignored):
                abstract_encoded_graph = jax.tree_map(
                    lambda y: jax.lax.tie_in(dummy_ignored, y), encoded_graph)
                abstract_variant_weights = jax.tree_map(
                    lambda y: jax.lax.tie_in(dummy_ignored, y),
                    variant_weights())
                return automaton_layer.FiniteStateGraphAutomaton(
                    encoded_graph=abstract_encoded_graph,
                    variant_weights=abstract_variant_weights,
                    dynamic_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    static_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    builder=builder,
                    num_out_edges=3,
                    num_intermediate_states=4,
                    share_states_across_edges=shared,
                    use_gate_parameterization=use_gate,
                    estimator_type=estimator_type,
                    name="the_layer",
                    **kwargs)

        with side_outputs.collect_side_outputs() as side:
            with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
                # For some reason init_by_shape breaks the custom_vjp?
                abstract_out, unused_params = TestModel.init(
                    jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

        del unused_params
        self.assertEqual(abstract_out.shape, (3, 32, 32))

        if estimator_type == "one_sample":
            log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
            self.assertIn(log_prob_key, side)
            self.assertEqual(side[log_prob_key].shape, (3, 32))
    def apply(self,
              edges,
              num_nodes,
              num_edge_types,
              forward_edge_type_indices,
              reverse_edge_type_indices,
              embedding_dim=gin.REQUIRED):
        """Compute multi-hot binary edge embeddings.

    Args:
      edges: Edges, represented as a sparse operator from a vector indexed by
        edge type to an adjacency matrix.
      num_nodes: Number of nodes in the graph.
      num_edge_types: How many total edge types there are.
      forward_edge_type_indices: Indices of the edge types to embed in the
        forward direction.
      reverse_edge_type_indices: Indices of the edge types to embed in the
        reverse direction.
      embedding_dim: Dimension of the learned embedding.

    Returns:
      <float32[num_nodes, num_nodes, embedding_dim]> embedding array
    """
        total_edge_count = (len(forward_edge_type_indices) +
                            len(reverse_edge_type_indices))
        edge_type_embeddings = self.param(
            "edge_type_embeddings",
            shape=(total_edge_count, embedding_dim),
            initializer=initializers.variance_scaling(1.0, "fan_out",
                                                      "truncated_normal"))

        # Build new operators that include only the desired edge types, mapping
        # edge-type indices over `num_edge_types` into indices over
        # `total_edge_count`.
        (forward_index_map, forward_values, reverse_index_map,
         reverse_values) = (_forward_and_reverse_subsets(
             num_edge_types, forward_edge_type_indices,
             reverse_edge_type_indices))

        e_in_flat = edges.input_indices.squeeze(1)

        forward_operator = sparse_operator.SparseCoordOperator(
            input_indices=forward_index_map[edges.input_indices],
            output_indices=edges.output_indices,
            values=edges.values * forward_values[e_in_flat])

        reverse_operator = sparse_operator.SparseCoordOperator(
            input_indices=reverse_index_map[edges.input_indices],
            output_indices=edges.output_indices,
            values=edges.values * reverse_values[e_in_flat])

        # Apply our adjusted operators, gathering from our extended embeddings
        # array.
        result = jnp.zeros([embedding_dim, num_nodes, num_nodes])
        result = forward_operator.apply_add(in_array=edge_type_embeddings,
                                            out_array=result,
                                            in_dims=[0],
                                            out_dims=[1, 2])
        result = reverse_operator.apply_add(in_array=edge_type_embeddings,
                                            out_array=result,
                                            in_dims=[0],
                                            out_dims=[2, 1])
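
        # Note that the reverse operator writes with out_dims=[2, 1]: an edge
        # (src, dst) contributes to result[:, dst, src], the transposed
        # adjacency slot, which is what embeds it in the reverse direction.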

        # Force it to actually be materialized as
        # [(batch,) embedding_dim, num_nodes, num_nodes] to reduce downstream
        # effects of the bad padding required by the above.
        result = jax_util.force_physical_layout(result)

        return result.transpose((1, 2, 0))
    def test_loss_fn(self):
        mock_example = example_definition.VarMisuseExample(
            input_graph=None,
            bug_node_index=2,
            repair_node_mask=jnp.array([0., 1., 1., 0.5, 0.]),
            candidate_node_mask=None,
            unique_candidate_operator=sparse_operator.SparseCoordOperator(
                input_indices=jnp.array([0, 1, 2, 3, 3, 4])[:, None],
                output_indices=jnp.array([0, 1, 1, 1, 2, 3])[:, None],
                values=jnp.array([1, 1, 1, 0.5, 0.5, 1])),
            repair_id=1)

        mock_metadata = object()

        @flax.nn.module
        def mock_model_def(example, metadata):
            # Check that we get the right inputs.
            self.assertIs(example, mock_example)
            self.assertIs(metadata, mock_metadata)

            # Register a side output
            side_outputs.SideOutput(jnp.array(.1234), name="test_penalty")

            # Make sure we can generate an rng key with flax.
            _ = flax.nn.make_rng()

            return jnp.log(
                jnp.array([
                    [.0, .0, .0, .0, .0],
                    [.1, .0, .0, .2, .0],
                    [.0, .1, .2, .2, .1],  # <- This row is the "correct" bug index.
                    [.0, .0, .0, .0, .0],
                    [.1, .0, .0, .0, .0],
                ]))

        with flax.nn.stochastic(jax.random.PRNGKey(0)):
            _, params = mock_model_def.init(jax.random.PRNGKey(0),
                                            mock_example, mock_metadata)

        mock_model = flax.nn.Model(mock_model_def, params)

        loss, metrics = train_var_misuse_lib.loss_fn(
            mock_model, (mock_example, jax.random.PRNGKey(0)),
            mock_metadata,
            regularization_weights={"penalty": 2})

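        # With bug_node_index=2, repair_node_mask=[0, 1, 1, 0.5, 0], and
        # repair_id=1, the mock probabilities above give:
        #   joint (bug and repair):  .1 + .2 + .5 * .2            = 0.4
        #   marginal bug:            sum of row 2                 = 0.6
        #   marginal repair:         .5 * .2 (row 1) + .4 (row 2) = 0.5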
        np.testing.assert_allclose(metrics["nll/joint"],
                                   -np.log(0.4),
                                   atol=1e-7)
        np.testing.assert_allclose(metrics["side/test_penalty"],
                                   .1234,
                                   atol=1e-7)
        np.testing.assert_allclose(loss, -np.log(0.4) + 2 * .1234, atol=1e-7)

        np.testing.assert_allclose(metrics["nll/marginal_bug"],
                                   -np.log(0.6),
                                   atol=1e-7)
        np.testing.assert_allclose(metrics["nll/marginal_repair"],
                                   -np.log(0.5),
                                   atol=1e-7)
        np.testing.assert_allclose(metrics["nll/repair_given_bug"],
                                   -np.log(0.4 / 0.6),
                                   atol=1e-7)
        np.testing.assert_allclose(metrics["nll/bug_given_repair"],
                                   -np.log(0.4 / 0.5),
                                   atol=1e-7)
        np.testing.assert_allclose(
            metrics["inaccuracy/classification_overall"], 0)
        np.testing.assert_allclose(
            metrics["inaccuracy/classification_given_nobug"].numerator, 0)
        np.testing.assert_allclose(
            metrics["inaccuracy/classification_given_nobug"].denominator, 0)
        np.testing.assert_allclose(
            metrics["inaccuracy/classification_given_bug"].numerator, 0)
        np.testing.assert_allclose(
            metrics["inaccuracy/classification_given_bug"].denominator, 1)
    def test_sample_loss_fn(self):
        example = self._make_example()
        example = dataclasses.replace(
            example,
            edges=sparse_operator.SparseCoordOperator(
                input_indices=jnp.array([[0], [0], [0], [0], [1], [2], [0],
                                         [0]]),
                output_indices=jnp.array([[1, 2], [2, 3], [2, 2], [3, 0],
                                          [0, 2], [0, 3], [0, 0], [0, 0]]),
                values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0])))

        @flax.nn.module
        def mock_model_def(example):
            del example
            side_outputs.SideOutput(
                -jnp.arange(5).astype("float32").reshape((1, 5)),
                name="one_sample_log_prob_per_edge_per_node")
            side_outputs.SideOutput(0.3, name="one_sample_reward_baseline")

            return model_util.safe_logit(
                jnp.array([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]))

        _, params = mock_model_def.init(jax.random.PRNGKey(0), example)
        mock_model = flax.nn.Model(mock_model_def, params)

        _, _, _, loss, metrics = train_edge_supervision_lib.sample_loss_fn(
            mock_model, (example, jax.random.PRNGKey(0)),
            target_edge_index=0,
            num_edge_types=3,
            num_rollouts=1,
            leave_one_out_baseline=False)

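        # Under this mock policy, three of the four valid nodes earn reward 1
        # and one earns 0, so reward = 0.75; with the learned baseline of 0.3
        # the per-node shifted rewards are 0.7, 0.7, 0.7, and -0.3, and the
        # REINFORCE term averages (negative log-prob) * shifted reward, as
        # asserted below.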
        np.testing.assert_allclose(metrics["reward"], 0.75, rtol=1e-5)
        np.testing.assert_allclose(metrics["shifted_reward"],
                                   0.75 - 0.3,
                                   rtol=1e-5)
        np.testing.assert_allclose(metrics["policy_log_prob"], -1.5, rtol=1e-5)
        np.testing.assert_allclose(metrics["learned_baseline"], 0.3, rtol=1e-5)
        np.testing.assert_allclose(metrics["baseline_penalty"],
                                   0.001 * (0.75 * (0.7 * 0.7) + 0.25 *
                                            (0.3 * 0.3)),
                                   rtol=1e-5)
        np.testing.assert_allclose(metrics["reinforce_term"],
                                   (0 * 0.7 + 1 * 0.7 + 2 * 0.7 + 3 * -0.3) /
                                   4,
                                   rtol=1e-5)

        np.testing.assert_allclose(loss,
                                   metrics["reinforce_term"] +
                                   metrics["baseline_penalty"],
                                   rtol=1e-5)

        (output_logits, targets, valid_mask, loss,
         metrics) = train_edge_supervision_lib.sample_loss_fn(
             mock_model, (example, jax.random.PRNGKey(0)),
             target_edge_index=0,
             num_edge_types=3,
             num_rollouts=20,
             leave_one_out_baseline=True)

        self.assertEqual(output_logits.shape, (5, 5))
        self.assertEqual(targets.shape, (5, 5))
        self.assertEqual(valid_mask.shape, (5, 5))

        np.testing.assert_allclose(metrics["reward"], 0.75, rtol=1e-5)
        np.testing.assert_allclose(metrics["shifted_reward"], 0, rtol=1e-5)
        np.testing.assert_allclose(metrics["learned_baseline"], 0.3, rtol=1e-5)
        np.testing.assert_allclose(metrics["baseline_penalty"], 0.0, rtol=1e-5)