def zeros_like_padded_example(config): """Builds a VarMisuseExample containing only zeros. This can be useful to initialize model parameters, or do tests. Args: config: Configuration specifying the desired padding size. Returns: An "example" filled with zeros of the given size. """ return VarMisuseExample( input_graph=GraphBundleWithTokens( bundle=graph_bundle.zeros_like_padded_example( config.bundle_padding), tokens=sparse_operator.SparseCoordOperator( input_indices=np.zeros(shape=(config.max_tokens, 1), dtype=np.int32), output_indices=np.zeros(shape=(config.max_tokens, 1), dtype=np.int32), values=np.zeros(shape=(config.max_tokens, ), dtype=np.int32))), bug_node_index=-1, repair_node_mask=np.zeros( shape=(config.bundle_padding.static_max_metadata.num_nodes, ), dtype=np.float32), candidate_node_mask=np.zeros( shape=(config.bundle_padding.static_max_metadata.num_nodes, ), dtype=np.float32), unique_candidate_operator=sparse_operator.SparseCoordOperator( input_indices=np.zeros(shape=(config.max_tokens, 1), dtype=np.int32), output_indices=np.zeros(shape=(config.max_tokens, 1), dtype=np.int32), values=np.zeros(shape=(config.max_tokens, ), dtype=np.float32)), repair_id=0)
def convert_graph_with_edges(
    graph,
    edges,
    builder,
):
  """Converts a graph with edges into a GraphBundle.

  The order of nodes in the returned example is guaranteed to match the order
  of the keys in `graph`.

  Args:
    graph: Graph to encode.
    edges: List of (source, dest, edge_type) triples to add to the
      non-automaton graph representation (i.e. GNN edges or targets).
    builder: Builder to use to convert the graph.

  Returns:
    Encoded example.
  """
  # Encode the graph.
  encoded_graph, encoded_metadata = builder.encode_graph(graph, as_jax=False)

  # Look up node types.
  node_types = []
  for node_info in graph.values():
    node_types.append(builder.node_type_to_index[node_info.node_type])
  node_types = np.array(node_types, dtype=np.int32)

  # Build the indices for the edges.
  if edges:
    src_dest_pairs = []
    edge_types = []
    id_to_index_map = {node_id: i for i, node_id in enumerate(graph)}
    for src_id, dest_id, edge_type in edges:
      src_idx = id_to_index_map[src_id]
      dest_idx = id_to_index_map[dest_id]
      src_dest_pairs.append((src_idx, dest_idx))
      edge_types.append(edge_type)

    edge_operator = sparse_operator.SparseCoordOperator(
        input_indices=np.array(edge_types, dtype=np.int32).reshape([-1, 1]),
        output_indices=np.array(src_dest_pairs, dtype=np.int32),
        values=np.ones([len(edges)], dtype=np.int32),
    )
  else:
    # Handle the case where there are no edges.
    edge_operator = sparse_operator.SparseCoordOperator(
        input_indices=np.empty([0, 1], dtype=np.int32),
        output_indices=np.empty([0, 2], dtype=np.int32),
        values=np.empty([0], dtype=np.int32),
    )

  return GraphBundle(
      automaton_graph=encoded_graph,
      graph_metadata=encoded_metadata,
      node_types=node_types,
      edges=edge_operator)
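# Hedged usage sketch (added for illustration): assuming `builder` is an
# automaton_builder.AutomatonBuilder whose schema covers the node types in
# `graph`, and `n0`/`n1` are keys of `graph`:
#
#   bundle = convert_graph_with_edges(
#       graph=graph,                       # ordered mapping: node id -> info
#       edges=[(n0, n1, 0), (n1, n0, 1)],  # (source, dest, edge_type)
#       builder=builder)
#   # bundle.edges is then a SparseCoordOperator from the edge-type vector
#   # to a [num_nodes, num_nodes] adjacency matrix.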
def zeros_like_padded_example(config): """Build an GraphBundle containing only zeros. This can be useful to initialize model parameters, or do tests. Args: config: Configuration specifying the desired padding size. Returns: An "example" filled with zeros of the given size. """ return GraphBundle( automaton_graph=automaton_builder.EncodedGraph( initial_to_in_tagged=sparse_operator.SparseCoordOperator( input_indices=np.zeros(shape=(config.max_initial_transitions, 1), dtype=np.int32), output_indices=np.zeros(shape=(config.max_initial_transitions, 2), dtype=np.int32), values=np.zeros(shape=(config.max_initial_transitions, ), dtype=np.float32), ), initial_to_special=np.zeros( shape=(config.static_max_metadata.num_nodes, ), dtype=np.int32), in_tagged_to_in_tagged=sparse_operator.SparseCoordOperator( input_indices=np.zeros(shape=(config.max_in_tagged_transitions, 1), dtype=np.int32), output_indices=np.zeros( shape=(config.max_in_tagged_transitions, 2), dtype=np.int32), values=np.zeros(shape=(config.max_in_tagged_transitions, ), dtype=np.float32), ), in_tagged_to_special=np.zeros( shape=(config.static_max_metadata.num_input_tagged_nodes, ), dtype=np.int32), in_tagged_node_indices=np.zeros( shape=(config.static_max_metadata.num_input_tagged_nodes, ), dtype=np.int32), ), graph_metadata=automaton_builder.EncodedGraphMetadata( num_nodes=0, num_input_tagged_nodes=0), node_types=np.zeros(shape=(config.static_max_metadata.num_nodes, ), dtype=np.int32), edges=sparse_operator.SparseCoordOperator( input_indices=np.zeros(shape=(config.max_edges, 1), dtype=np.int32), output_indices=np.zeros(shape=(config.max_edges, 2), dtype=np.int32), values=np.zeros(shape=(config.max_edges, ), dtype=np.int32), ), )
def test_variants_from_edges(self):
  example = graph_bundle.zeros_like_padded_example(
      graph_bundle.PaddingConfig(
          static_max_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=5, num_input_tagged_nodes=0),
          max_initial_transitions=0,
          max_in_tagged_transitions=0,
          max_edges=8))
  example = dataclasses.replace(
      example,
      graph_metadata=automaton_builder.EncodedGraphMetadata(
          num_nodes=4, num_input_tagged_nodes=0),
      edges=sparse_operator.SparseCoordOperator(
          input_indices=jnp.array([[0], [0], [0], [1], [1], [2], [0], [0]]),
          output_indices=jnp.array([[1, 2], [2, 3], [3, 0], [2, 0], [0, 2],
                                    [0, 3], [0, 0], [0, 0]]),
          values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0])))

  weights = edge_supervision_models.variants_from_edges(
      example,
      automaton_builder.EncodedGraphMetadata(
          num_nodes=5, num_input_tagged_nodes=0),
      variant_edge_type_indices=[2, 0],
      num_edge_types=3)

  expected = np.array([
      [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0]],
      [[1, 0, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]],
      [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
      [[0, 0, 1], [1, 0, 0], [1, 0, 0], [1, 0, 0]],
  ], np.float32)

  # Only assert on the non-padded part.
  np.testing.assert_allclose(weights[:4, :4], expected)
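# Note (added, hedged): each entry of `expected` above reads as a one-hot
# over [no variant edge, edge type 2, edge type 0], matching
# variant_edge_type_indices=[2, 0]. For example, the (0, 3) edge has type 2,
# so expected[0][3] == [0, 1, 0], while (1, 2) has type 0, so
# expected[1][2] == [0, 0, 1].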
def test_sparse_coord_operator_is_a_pytree(self):
  """Tests that jax tree operations work on SparseCoordOperators."""
  op_with_zeros = jax.tree_map(jnp.zeros_like, self.operator)
  expected = sparse_operator.SparseCoordOperator(
      input_indices=jnp.zeros([5, 2], dtype=int),
      output_indices=jnp.zeros([5, 1], dtype=int),
      values=jnp.zeros([5], dtype=jnp.float32),
  )
  jax.tree_multimap(np.testing.assert_allclose, op_with_zeros, expected)
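# Note (added): SparseCoordOperator is registered as a JAX pytree, so
# jax.tree_map above applies jnp.zeros_like independently to each of its
# three array fields (input_indices, output_indices, values).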
def _setup_edges(self):
  """Set up an edge operator with type indices."""
  edges = sparse_operator.SparseCoordOperator(
      input_indices=jnp.array([[0], [0], [1], [2], [0], [1], [0]]),
      output_indices=jnp.array([[0, 1], [1, 2], [2, 3], [2, 0], [0, 2],
                                [1, 3], [0, 0]]),
      values=jnp.array([1, 1, 1, 1, 1, 1, 0]))
  forward_edge_type_indices = [2, 0]
  reverse_edge_type_indices = [0]
  return edges, forward_edge_type_indices, reverse_edge_type_indices
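# Note (added, hedged): in this operator, `input_indices` holds edge types,
# `output_indices` holds (source, dest) node pairs, and the final
# zero-valued entry acts as padding that contributes nothing when applied.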
def test_bad_args(self, in_shape, out_shape, in_dims, out_dims,
                  expected_error):
  operator = sparse_operator.SparseCoordOperator(
      input_indices=jnp.zeros([5, 2], dtype=int),
      output_indices=jnp.zeros([5, 2], dtype=int),
      values=jnp.zeros([5], dtype=jnp.float32),
  )
  with self.assertRaisesRegex(ValueError, re.escape(expected_error)):
    operator.apply_add(jnp.zeros(in_shape), jnp.zeros(out_shape), in_dims,
                       out_dims)
def test_pad_nonzeros(self):
  operator = sparse_operator.SparseCoordOperator(
      input_indices=jnp.arange(10).reshape([5, 2]),
      output_indices=jnp.arange(5).reshape([5, 1]),
      values=jnp.arange(5, dtype=jnp.float32),
  )
  padded_operator = operator.pad_nonzeros(7)
  apply_orig = operator.apply_add(
      jnp.arange(100).reshape([10, 10]), jnp.zeros([5]))
  apply_padded = padded_operator.apply_add(
      jnp.arange(100).reshape([10, 10]), jnp.zeros([5]))
  np.testing.assert_allclose(apply_orig, apply_padded)
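# Note (added, hedged): the equality checked above holds because
# `pad_nonzeros(7)` presumably extends the operator to 7 entries using
# zero-valued padding, and zero-valued entries contribute nothing to the
# scatter-add performed by `apply_add`.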
def setUp(self):
  super(SparseOperatorTest, self).setUp()
  self.operator = sparse_operator.SparseCoordOperator(
      input_indices=jnp.array([
          [0, 0],
          [1, 0],
          [1, 1],
          [1, 0],
          [1, 1],
      ]),
      output_indices=jnp.array([[0], [1], [2], [3], [2]]),
      values=jnp.array([1., 10., 100., 1000., 10000.]),
  )
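def test_apply_add_semantics_sketch(self):
  # Hedged example (added for exposition; not from the original suite).
  # Each operator entry gathers in_array[input_indices[i]], scales it by
  # values[i], and scatter-adds it into out_array[output_indices[i]].
  in_array = jnp.array([[1., 2.], [3., 4.]])
  out = self.operator.apply_add(in_array, jnp.zeros([4]))
  # out[0] = 1 * in[0, 0]; out[1] = 10 * in[1, 0];
  # out[2] = 100 * in[1, 1] + 10000 * in[1, 1]; out[3] = 1000 * in[1, 0].
  np.testing.assert_allclose(out, [1., 30., 40400., 3000.])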
def test_TokenOperatorNodeEmbedding_shapes(self, bottleneck_dim):
  outs, _ = graph_layers.TokenOperatorNodeEmbedding.init(
      jax.random.PRNGKey(0),
      operator=sparse_operator.SparseCoordOperator(
          input_indices=jnp.zeros((NUM_TOKENS, 1), jnp.int32),
          output_indices=jnp.zeros((NUM_TOKENS, 1), jnp.int32),
          values=jnp.zeros((NUM_TOKENS,), jnp.int32)),
      vocab_size=VOCAB_SIZE,
      num_nodes=NUM_NODES,
      embedding_dim=NODE_EMBEDDING_DIM,
      bottleneck_dim=bottleneck_dim)

  expected = jax.ShapeDtypeStruct((NUM_NODES, NODE_EMBEDDING_DIM),
                                  jnp.float32)
  self._check_shape_and_dtype(outs, expected)
def test_LearnableEdgeEmbeddings_shapes(self):
  outs, _ = graph_layers.LearnableEdgeEmbeddings.init(
      jax.random.PRNGKey(0),
      edges=sparse_operator.SparseCoordOperator(
          input_indices=jnp.zeros((NUM_EDGES, 1), jnp.int32),
          output_indices=jnp.zeros((NUM_EDGES, 2), jnp.int32),
          values=jnp.zeros((NUM_EDGES,), jnp.int32)),
      num_nodes=NUM_NODES,
      num_edge_types=NUM_EDGE_TYPES,
      forward_edge_type_indices=[0, 2],
      reverse_edge_type_indices=[3, 1],
      embedding_dim=EDGE_EMBEDDING_DIM)

  expected = jax.ShapeDtypeStruct((NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM),
                                  jnp.float32)
  self._check_shape_and_dtype(outs, expected)
def _make_example(self):
  example = graph_bundle.zeros_like_padded_example(
      graph_bundle.PaddingConfig(
          static_max_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=5, num_input_tagged_nodes=0),
          max_initial_transitions=0,
          max_in_tagged_transitions=0,
          max_edges=8))
  example = dataclasses.replace(
      example,
      graph_metadata=automaton_builder.EncodedGraphMetadata(
          num_nodes=4, num_input_tagged_nodes=0),
      edges=sparse_operator.SparseCoordOperator(
          input_indices=jnp.array([[0], [0], [0], [0], [1], [2], [2], [0]]),
          output_indices=jnp.array([[1, 2], [2, 3], [2, 2], [3, 0], [0, 2],
                                    [0, 3], [0, 0], [0, 0]]),
          values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0])))
  return example
def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                       use_gate, estimator_type, **kwargs):
  # Create a simple schema and empty encoded graph.
  schema = {
      graph_types.NodeType("a"):
          graph_types.NodeSchema(
              in_edges=[graph_types.InEdgeType("ai_0")],
              out_edges=[graph_types.OutEdgeType("ao_0")]),
  }
  builder = automaton_builder.AutomatonBuilder(schema)
  encoded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=sparse_operator.SparseCoordOperator(
          input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
          output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
          values=jnp.zeros((128,), dtype=jnp.float32),
      ),
      initial_to_special=jnp.zeros((32,), dtype=jnp.int32),
      in_tagged_to_in_tagged=sparse_operator.SparseCoordOperator(
          input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
          output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
          values=jnp.zeros((128,), dtype=jnp.float32),
      ),
      in_tagged_to_special=jnp.zeros((64,), dtype=jnp.int32),
      in_tagged_node_indices=jnp.zeros((64,), dtype=jnp.int32),
  )

  # Make sure the layer can be initialized and applied within a model.
  # This model is fairly simple; it just pretends that the encoded graph and
  # variants depend on the input.
  class TestModel(flax.deprecated.nn.Module):

    def apply(self, dummy_ignored):
      abstract_encoded_graph = jax.tree_map(
          lambda y: jax.lax.tie_in(dummy_ignored, y), encoded_graph)
      abstract_variant_weights = jax.tree_map(
          lambda y: jax.lax.tie_in(dummy_ignored, y), variant_weights())
      return automaton_layer.FiniteStateGraphAutomaton(
          encoded_graph=abstract_encoded_graph,
          variant_weights=abstract_variant_weights,
          dynamic_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=32, num_input_tagged_nodes=64),
          static_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=32, num_input_tagged_nodes=64),
          builder=builder,
          num_out_edges=3,
          num_intermediate_states=4,
          share_states_across_edges=shared,
          use_gate_parameterization=use_gate,
          estimator_type=estimator_type,
          name="the_layer",
          **kwargs)

  with side_outputs.collect_side_outputs() as side:
    with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
      # For some reason init_by_shape breaks the custom_vjp?
      abstract_out, unused_params = TestModel.init(
          jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))
      del unused_params

  self.assertEqual(abstract_out.shape, (3, 32, 32))

  if estimator_type == "one_sample":
    log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
    self.assertIn(log_prob_key, side)
    self.assertEqual(side[log_prob_key].shape, (3, 32))
def apply(self,
          edges,
          num_nodes,
          num_edge_types,
          forward_edge_type_indices,
          reverse_edge_type_indices,
          embedding_dim=gin.REQUIRED):
  """Computes multi-hot learnable edge embeddings.

  Args:
    edges: Edges, represented as a sparse operator from a vector indexed by
      edge type to an adjacency matrix.
    num_nodes: Number of nodes in the graph.
    num_edge_types: How many total edge types there are.
    forward_edge_type_indices: Indices of the edge types to embed in the
      forward direction.
    reverse_edge_type_indices: Indices of the edge types to embed in the
      reverse direction.
    embedding_dim: Dimension of the learned embedding.

  Returns:
    <float32[num_nodes, num_nodes, embedding_dim]> embedding array.
  """
  total_edge_count = (
      len(forward_edge_type_indices) + len(reverse_edge_type_indices))
  edge_type_embeddings = self.param(
      "edge_type_embeddings",
      shape=(total_edge_count, embedding_dim),
      initializer=initializers.variance_scaling(1.0, "fan_out",
                                                "truncated_normal"))

  # Build new operators that include only our desired edge types by mapping
  # the `num_edge_types` to `total_edge_count`.
  (forward_index_map, forward_values, reverse_index_map,
   reverse_values) = _forward_and_reverse_subsets(num_edge_types,
                                                  forward_edge_type_indices,
                                                  reverse_edge_type_indices)

  e_in_flat = edges.input_indices.squeeze(1)
  forward_operator = sparse_operator.SparseCoordOperator(
      input_indices=forward_index_map[edges.input_indices],
      output_indices=edges.output_indices,
      values=edges.values * forward_values[e_in_flat])
  reverse_operator = sparse_operator.SparseCoordOperator(
      input_indices=reverse_index_map[edges.input_indices],
      output_indices=edges.output_indices,
      values=edges.values * reverse_values[e_in_flat])

  # Apply our adjusted operators, gathering from our extended embeddings
  # array.
  result = jnp.zeros([embedding_dim, num_nodes, num_nodes])
  result = forward_operator.apply_add(
      in_array=edge_type_embeddings,
      out_array=result,
      in_dims=[0],
      out_dims=[1, 2])
  result = reverse_operator.apply_add(
      in_array=edge_type_embeddings,
      out_array=result,
      in_dims=[0],
      out_dims=[2, 1])

  # Force it to actually be materialized as
  # [(batch,) embedding_dim, num_nodes, num_nodes] to reduce downstream
  # effects of the bad padding required by the above.
  result = jax_util.force_physical_layout(result)
  return result.transpose((1, 2, 0))
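# Worked note (added, hedged): with num_edge_types=3,
# forward_edge_type_indices=[2, 0], and reverse_edge_type_indices=[0], the
# subset helper is presumably expected to send edge type 2 to row 0 and
# type 0 to row 1 of `edge_type_embeddings` for the forward operator, and
# type 0 to row 2 for the reverse operator, while zeroing the `values` of
# any edge whose type is not selected so it contributes nothing to the
# scatter-add.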
def test_loss_fn(self):
  mock_example = example_definition.VarMisuseExample(
      input_graph=None,
      bug_node_index=2,
      repair_node_mask=jnp.array([0., 1., 1., 0.5, 0.]),
      candidate_node_mask=None,
      unique_candidate_operator=sparse_operator.SparseCoordOperator(
          input_indices=jnp.array([0, 1, 2, 3, 3, 4])[:, None],
          output_indices=jnp.array([0, 1, 1, 1, 2, 3])[:, None],
          values=jnp.array([1, 1, 1, 0.5, 0.5, 1])),
      repair_id=1)
  mock_metadata = object()

  @flax.nn.module
  def mock_model_def(example, metadata):
    # Check that we get the right inputs.
    self.assertIs(example, mock_example)
    self.assertIs(metadata, mock_metadata)
    # Register a side output.
    side_outputs.SideOutput(jnp.array(.1234), name="test_penalty")
    # Make sure we can generate an rng key with flax.
    _ = flax.nn.make_rng()
    return jnp.log(
        jnp.array([
            [.0, .0, .0, .0, .0],
            [.1, .0, .0, .2, .0],
            [.0, .1, .2, .2, .1],  # <- This row is the "correct" bug index.
            [.0, .0, .0, .0, .0],
            [.1, .0, .0, .0, .0],
        ]))

  with flax.nn.stochastic(jax.random.PRNGKey(0)):
    _, params = mock_model_def.init(jax.random.PRNGKey(0), mock_example,
                                    mock_metadata)
    mock_model = flax.nn.Model(mock_model_def, params)
    loss, metrics = train_var_misuse_lib.loss_fn(
        mock_model, (mock_example, jax.random.PRNGKey(0)),
        mock_metadata,
        regularization_weights={"penalty": 2})

  np.testing.assert_allclose(metrics["nll/joint"], -np.log(0.4), atol=1e-7)
  np.testing.assert_allclose(metrics["side/test_penalty"], .1234, atol=1e-7)
  np.testing.assert_allclose(loss, -np.log(0.4) + 2 * .1234, atol=1e-7)
  np.testing.assert_allclose(
      metrics["nll/marginal_bug"], -np.log(0.6), atol=1e-7)
  np.testing.assert_allclose(
      metrics["nll/marginal_repair"], -np.log(0.5), atol=1e-7)
  np.testing.assert_allclose(
      metrics["nll/repair_given_bug"], -np.log(0.4 / 0.6), atol=1e-7)
  np.testing.assert_allclose(
      metrics["nll/bug_given_repair"], -np.log(0.4 / 0.5), atol=1e-7)
  np.testing.assert_allclose(metrics["inaccuracy/classification_overall"], 0)
  np.testing.assert_allclose(
      metrics["inaccuracy/classification_given_nobug"].numerator, 0)
  np.testing.assert_allclose(
      metrics["inaccuracy/classification_given_nobug"].denominator, 0)
  np.testing.assert_allclose(
      metrics["inaccuracy/classification_given_bug"].numerator, 0)
  np.testing.assert_allclose(
      metrics["inaccuracy/classification_given_bug"].denominator, 1)
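# Note (added): the 0.4 joint probability asserted above is row 2 of the
# model output (the true bug node) pushed through unique_candidate_operator
# for repair candidate 1: 0.1 * 1 + 0.2 * 1 + 0.2 * 0.5 = 0.4. Summing all
# of row 2 gives the 0.6 bug marginal, and summing candidate-1 mass over
# all rows gives the 0.5 repair marginal.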
def test_sample_loss_fn(self):
  example = self._make_example()
  example = dataclasses.replace(
      example,
      edges=sparse_operator.SparseCoordOperator(
          input_indices=jnp.array([[0], [0], [0], [0], [1], [2], [0], [0]]),
          output_indices=jnp.array([[1, 2], [2, 3], [2, 2], [3, 0], [0, 2],
                                    [0, 3], [0, 0], [0, 0]]),
          values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0])))

  @flax.nn.module
  def mock_model_def(example):
    del example
    side_outputs.SideOutput(
        -jnp.arange(5).astype("float32").reshape((1, 5)),
        name="one_sample_log_prob_per_edge_per_node")
    side_outputs.SideOutput(0.3, name="one_sample_reward_baseline")
    return model_util.safe_logit(
        jnp.array([
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
        ]))

  _, params = mock_model_def.init(jax.random.PRNGKey(0), example)
  mock_model = flax.nn.Model(mock_model_def, params)

  _, _, _, loss, metrics = train_edge_supervision_lib.sample_loss_fn(
      mock_model, (example, jax.random.PRNGKey(0)),
      target_edge_index=0,
      num_edge_types=3,
      num_rollouts=1,
      leave_one_out_baseline=False)

  np.testing.assert_allclose(metrics["reward"], 0.75, rtol=1e-5)
  np.testing.assert_allclose(metrics["shifted_reward"], 0.75 - 0.3,
                             rtol=1e-5)
  np.testing.assert_allclose(metrics["policy_log_prob"], -1.5, rtol=1e-5)
  np.testing.assert_allclose(metrics["learned_baseline"], 0.3, rtol=1e-5)
  np.testing.assert_allclose(
      metrics["baseline_penalty"],
      0.001 * (0.75 * (0.7 * 0.7) + 0.25 * (0.3 * 0.3)),
      rtol=1e-5)
  np.testing.assert_allclose(
      metrics["reinforce_term"],
      (0 * 0.7 + 1 * 0.7 + 2 * 0.7 + 3 * -0.3) / 4,
      rtol=1e-5)
  np.testing.assert_allclose(
      loss, metrics["reinforce_term"] + metrics["baseline_penalty"],
      rtol=1e-5)

  (output_logits, targets, valid_mask, loss,
   metrics) = train_edge_supervision_lib.sample_loss_fn(
       mock_model, (example, jax.random.PRNGKey(0)),
       target_edge_index=0,
       num_edge_types=3,
       num_rollouts=20,
       leave_one_out_baseline=True)

  self.assertEqual(output_logits.shape, (5, 5))
  self.assertEqual(targets.shape, (5, 5))
  self.assertEqual(valid_mask.shape, (5, 5))

  np.testing.assert_allclose(metrics["reward"], 0.75, rtol=1e-5)
  np.testing.assert_allclose(metrics["shifted_reward"], 0, rtol=1e-5)
  np.testing.assert_allclose(metrics["learned_baseline"], 0.3, rtol=1e-5)
  np.testing.assert_allclose(metrics["baseline_penalty"], 0.0, rtol=1e-5)
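# Note (added, hedged): with num_rollouts=1 the metrics follow from the
# side outputs above. policy_log_prob is the mean of [0, -1, -2, -3] over
# the four valid nodes (-1.5); a per-node reward of [1, 1, 1, 0] yields the
# 0.75 mean reward and, against the 0.3 learned baseline, the shifted
# rewards [0.7, 0.7, 0.7, -0.3] that appear in the REINFORCE and baseline
# penalty terms.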