def test_initial_routing_params_with_noise(self):
  """Checks noisy routing-param initialization against the noiseless one."""
  builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())
  shared_kwargs = {
      "num_fsm_states": 3,
      "num_variants": 2,
      "state_change_prob": 0.2,
      "move_prob": 0.9,
  }

  # A tiny amount of noise should barely move the parameters.
  baseline = builder.initialize_routing_params(
      key=None, noise_factor=0, **shared_kwargs)
  perturbed = builder.initialize_routing_params(
      key=jax.random.PRNGKey(1234), noise_factor=1e-6, **shared_kwargs)
  for field in ("move", "special"):
    np.testing.assert_allclose(
        getattr(baseline, field), getattr(perturbed, field), rtol=0.02)

  # With much more noise, each distribution should still sum to one.
  heavily_noised = builder.initialize_routing_params(
      key=jax.random.PRNGKey(1234), noise_factor=0.8, **shared_kwargs)
  totals = builder.routing_reduce(heavily_noised, "sum")
  for field in ("move", "special"):
    np.testing.assert_allclose(getattr(totals, field), 1.0, rtol=1e-6)
def test_constructor_information_removing_mappings(self):
  """Checks the index mappings that drop information from a route type.

  Two mappings are verified against the builder's own index tables:
    * `in_out_route_to_in_route`: in/out route index -> in route index.
    * `in_route_to_node_type`: in route index -> node type index.
  """
  builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

  # Dropping the out edge of an in/out route must give the matching in route.
  for in_out_route_type in builder.in_out_route_types:
    in_route_type = automaton_builder.InRouteType(
        in_out_route_type.node_type, in_out_route_type.in_edge)
    in_out_index = builder.in_out_route_type_to_index[in_out_route_type]
    self.assertEqual(
        builder.in_out_route_to_in_route[in_out_index],
        builder.in_route_type_to_index[in_route_type])

  # Dropping the in edge of an in route must give the matching node type.
  for in_route_type in builder.in_route_types:
    node_type = graph_types.NodeType(in_route_type.node_type)
    in_index = builder.in_route_type_to_index[in_route_type]
    self.assertEqual(
        builder.in_route_to_node_type[in_index],
        builder.node_type_to_index[node_type])
def initialize_routing_gates(self):
  """Just make sure that we can initialize routing gates.

  Initializes gates both without noise and with logistic noise, and checks
  the shapes of the move, accept and backtrack gate arrays.
  """
  # NOTE(review): this method has no `test` prefix, so default unittest
  # discovery will presumably not collect it -- confirm whether that is
  # intentional before renaming.
  builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())
  num_in_out_routes = len(builder.in_out_route_types)
  num_in_routes = len(builder.in_route_types)

  def check_shapes(gates):
    """Asserts the gate shapes for 2 variants and 3 FSM states."""
    self.assertEqual(gates.move_gates.shape, (2, num_in_out_routes, 3, 3))
    self.assertEqual(gates.accept_gates.shape, (2, num_in_routes, 3))
    self.assertEqual(gates.backtrack_gates.shape, (2, num_in_routes, 3))

  # Noiseless
  noiseless_gates = builder.initialize_routing_gates(
      key=None, logistic_noise=0, num_fsm_states=3, num_variants=2)
  check_shapes(noiseless_gates)

  # Perturbed
  noisy_gates = builder.initialize_routing_gates(
      key=jax.random.PRNGKey(0),
      logistic_noise=0.2,
      num_fsm_states=3,
      num_variants=2)
  check_shapes(noisy_gates)
def test_one_node_particle_estimate_padding(self):
  """Checks that the particle estimate puts no mass on padding nodes."""
  padded_size = 64
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  # Pad every component of the encoded graph out to `padded_size`.
  padded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=(
          enc_graph.initial_to_in_tagged.pad_nonzeros(padded_size)),
      initial_to_special=jax_util.pad_to(
          enc_graph.initial_to_special, padded_size),
      in_tagged_to_in_tagged=(
          enc_graph.in_tagged_to_in_tagged.pad_nonzeros(padded_size)),
      in_tagged_to_special=jax_util.pad_to(
          enc_graph.in_tagged_to_special, padded_size),
      in_tagged_node_indices=jax_util.pad_to(
          enc_graph.in_tagged_node_indices, padded_size))
  padded_meta = automaton_builder.EncodedGraphMetadata(
      num_nodes=padded_size, num_input_tagged_nodes=padded_size)

  variant_weights = jnp.full([padded_size, 5], 0.2)
  routing_params = automaton_builder.RoutingParams(
      move=jnp.full([5, 6, 2, 2], 0.2),
      special=jnp.full([5, 3, 2, 3], 0.2))
  tmat = builder.build_transition_matrix(
      routing_params, padded_graph, padded_meta)

  estimate = automaton_sampling.one_node_particle_estimate(
      builder,
      tmat,
      variant_weights,
      start_machine_state=jnp.array([1., 0.]),
      node_index=0,
      steps=100,
      num_rollouts=100,
      max_possible_transitions=2,
      num_valid_nodes=enc_meta.num_nodes,
      rng=jax.random.PRNGKey(0))

  num_real = enc_meta.num_nodes
  self.assertEqual(estimate.shape, (padded_size,))
  self.assertTrue(jnp.all(estimate[:num_real] > 0))
  self.assertTrue(jnp.all(estimate[num_real:] == 0))
def test_initial_routing_params_noiseless(self):
  """Checks the exact values of noiseless initial routing parameters."""
  schema = self.build_simple_schema()
  builder = automaton_builder.AutomatonBuilder(schema)
  routing_params = builder.initialize_routing_params(
      key=None,
      num_fsm_states=3,
      num_variants=2,
      state_change_prob=0.2,
      move_prob=0.9,
      noise_factor=0)

  # Number of out edges for each in/out route's node type, made
  # broadcastable against [variants, in_out_routes, states].
  per_route_out_edges = [
      len(schema[route.node_type].out_edges)
      for route in builder.in_out_route_types
  ]
  outgoing_count = np.array(per_route_out_edges)[None, :, None]

  # Moves that keep the FSM state (the diagonal of the state-by-state block).
  diagonal = np.arange(3)
  same_state_moves = routing_params.move[:, :, diagonal, diagonal]
  np.testing.assert_allclose(
      same_state_moves,
      np.broadcast_to(0.9 * 0.8 / outgoing_count, same_state_moves.shape))

  # Moves that change the FSM state (all off-diagonal entries).
  off_diagonal = [(i, j) for i in range(3) for j in range(3) if i != j]
  from_states = [i for i, _ in off_diagonal]
  to_states = [j for _, j in off_diagonal]
  changed_state_moves = routing_params.move[:, :, from_states, to_states]
  np.testing.assert_allclose(
      changed_state_moves,
      np.broadcast_to(0.9 * 0.2 / (2 * outgoing_count),
                      changed_state_moves.shape))

  np.testing.assert_allclose(routing_params.special, 0.1 / 3)
def test_routing_gates_to_probs(self):
  """Checks conversion of gate values into routing probabilities."""
  builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

  def move_index(node, in_edge, out_edge):
    """Returns the in/out route index for the given names."""
    return builder.in_out_route_type_to_index[
        automaton_builder.InOutRouteType(
            graph_types.NodeType(node), graph_types.InEdgeType(in_edge),
            graph_types.OutEdgeType(out_edge))]

  def special_index(node, in_edge):
    """Returns the in route index for the given names."""
    return builder.in_route_type_to_index[automaton_builder.InRouteType(
        graph_types.NodeType(node), graph_types.InEdgeType(in_edge))]

  num_in_out_routes = len(builder.in_out_route_types)
  num_in_routes = len(builder.in_route_types)
  # move gates: [variants, in_out_routes, fsm_states, fsm_states]
  # accept / backtrack gates: [variants, in_routes, fsm_states]
  move_gates = np.full([3, num_in_out_routes, 2, 2], 0.5)
  accept_gates = np.full([3, num_in_routes, 2], 0.5)
  backtrack_gates = np.full([3, num_in_routes, 2], 0.5)

  # Set one distribution to sum to more than 1.
  idx_d1_move1 = move_index("b", "bi_0", "bo_0")
  idx_d1_move2 = move_index("b", "bi_0", "bo_1")
  idx_d1_special = special_index("b", "bi_0")
  move_gates[0, idx_d1_move1, 0, :] = [.2, .3]
  move_gates[0, idx_d1_move2, 0, :] = [.4, .5]
  accept_gates[0, idx_d1_special, 0] = .6
  backtrack_gates[0, idx_d1_special, 0] = .3

  # Set another to sum to less than 1.
  idx_d2_move = move_index("a", "ai_0", "ao_0")
  idx_d2_special = special_index("a", "ai_0")
  move_gates[2, idx_d2_move, 1, :] = [.1, .2]
  accept_gates[2, idx_d2_special, 1] = .3
  backtrack_gates[2, idx_d2_special, 1] = .75

  logit = jax.scipy.special.logit
  routing_gates = automaton_builder.RoutingGateParams(
      move_gates=logit(move_gates),
      accept_gates=logit(accept_gates),
      backtrack_gates=logit(backtrack_gates))
  routing_probs = builder.routing_gates_to_probs(routing_gates)

  # Check probabilities for first distribution: should divide evenly.
  np.testing.assert_allclose(routing_probs.move[0, idx_d1_move1, 0, :],
                             np.array([.2, .3]) / 2.0)
  np.testing.assert_allclose(routing_probs.move[0, idx_d1_move2, 0, :],
                             np.array([.4, .5]) / 2.0)
  np.testing.assert_allclose(routing_probs.special[0, idx_d1_special, 0, :],
                             np.array([.6, 0, 0]) / 2.0)

  # Check probabilities for second distribution: should assign remainder to
  # backtrack and fail.
  np.testing.assert_allclose(routing_probs.move[2, idx_d2_move, 1, :],
                             np.array([.1, .2]))
  np.testing.assert_allclose(routing_probs.special[2, idx_d2_special, 1, :],
                             np.array([.3, .3, .1]))
def test_constructor_actions_nodes_routes(self):
  """Checks the actions, node types and route types the builder derives."""
  builder = automaton_builder.AutomatonBuilder(
      self.build_simple_schema(), with_backtrack=False, with_fail=True)

  node_a = graph_types.NodeType("a")
  node_b = graph_types.NodeType("b")
  initial = automaton_builder.SOURCE_INITIAL
  in_edge = graph_types.InEdgeType
  out_edge = graph_types.OutEdgeType

  # Ways of arriving at each node type: the initial source or an in edge.
  a_sources = (initial, in_edge("ai_0"), in_edge("ai_1"))
  b_sources = (initial, in_edge("bi_0"))

  self.assertEqual(
      set(builder.special_actions), {
          automaton_builder.SpecialActions.FINISH,
          automaton_builder.SpecialActions.FAIL,
      })
  self.assertEqual(set(builder.node_types), {node_a, node_b})

  expected_in_routes = (
      {automaton_builder.InRouteType(node_a, src) for src in a_sources}
      | {automaton_builder.InRouteType(node_b, src) for src in b_sources})
  self.assertEqual(set(builder.in_route_types), expected_in_routes)

  # "a" nodes have one out edge; "b" nodes have two.
  expected_in_out_routes = (
      {
          automaton_builder.InOutRouteType(node_a, src, out_edge("ao_0"))
          for src in a_sources
      } | {
          automaton_builder.InOutRouteType(node_b, src, out_edge(name))
          for src in b_sources
          for name in ("bo_0", "bo_1")
      })
  self.assertEqual(set(builder.in_out_route_types), expected_in_out_routes)
def test_transition_sentinel_integers(self):
  """Test that the transition matrix puts each element in the right place."""
  schema, loop_graph = self.build_loop_graph()
  builder = automaton_builder.AutomatonBuilder(schema)
  encoded_graph, graph_meta = builder.encode_graph(loop_graph, as_jax=False)

  # Use sentinel integers to check correct indexing (with only one variant
  # and state, since indexing doesn't use those).
  # Each integer is of the form XYZ where
  #   X = {a:1, b:2}[node_type]
  #   Y = {initial:9, i0:0, i1:1}[in_edge]
  #   Z = {o0:0, o1:1, o2:2, finish:3, backtrack:4, fail:5}[action]
  sentinel_moves = jnp.array([
      [100, 101, 110, 111, 190, 191],
      [200, 201, 202, 290, 291, 292],
  ]).reshape((1, 12, 1, 1))
  sentinel_specials = jnp.array([
      [103, 104, 105],
      [113, 114, 115],
      [193, 194, 195],
      [203, 204, 205],
      [293, 294, 295],
  ]).reshape((1, 5, 1, 3))
  sentinel_routing_params = automaton_builder.RoutingParams(
      move=sentinel_moves, special=sentinel_specials)

  transition_matrix = builder.build_transition_matrix(
      sentinel_routing_params,
      encoded_graph,
      graph_meta,
  )
  actual = transition_matrix.concatenated_transitions()
  self.assertEqual(actual.shape, (1, 4 + 6, 1, 6 * 1 + 3))

  # pyformat: disable
  # pylint: disable=bad-continuation,bad-whitespace,g-inline-comment-too-close
  # Columns are the next position: a0i0, a0i1, a1i0, a1i1, b0i0, b1i0,
  # followed by the three special actions. Rows are the current position.
  expected = np.array([
      [
          [[0, 0, 191, 0, 0, 190, 193, 194, 195]],  # initial a0
          [[0, 190, 0, 0, 191, 0, 193, 194, 195]],  # initial a1
          [[0, 0, 0, 290, 292 / 2, 291 + 292 / 2, 293, 294, 295]],  # initial b0
          [[291, 0, 0, 0, 290 + 292 / 2, 292 / 2, 293, 294, 295]],  # initial b1
          [[0, 0, 101, 0, 0, 100, 103, 104, 105]],  # a0i0
          [[0, 0, 111, 0, 0, 110, 113, 114, 115]],  # a0i1
          [[0, 100, 0, 0, 101, 0, 103, 104, 105]],  # a1i0
          [[0, 110, 0, 0, 111, 0, 113, 114, 115]],  # a1i1
          [[0, 0, 0, 200, 202 / 2, 201 + 202 / 2, 203, 204, 205]],  # b0i0
          [[201, 0, 0, 0, 200 + 202 / 2, 202 / 2, 203, 204, 205]],  # b1i0
      ]
  ])
  # pyformat: enable
  # pylint: enable=bad-continuation,bad-whitespace,g-inline-comment-too-close
  np.testing.assert_allclose(actual, expected)
def __post_init__(self):
  """Derives `schema`, `edge_types` and `builder` from `ast_spec`."""
  self.schema = generic_ast_graphs.build_ast_graph_schema(self.ast_spec)

  # Gather every edge type into a set so duplicates collapse, then sort it.
  edge_type_pool = {graph_edge_util.SAME_IDENTIFIER_EDGE_TYPE}
  edge_type_pool.update(graph_edge_util.PROGRAM_GRAPH_EDGE_TYPES)
  edge_type_pool.update(graph_edge_util.schema_edge_types(self.schema))
  edge_type_pool.update(
      graph_edge_util.nth_child_edge_types(EDGE_NTH_CHILD_MAX))
  self.edge_types = sorted(edge_type_pool)

  self.builder = automaton_builder.AutomatonBuilder(self.schema)
def automaton_model(example,
                    graph_metadata,
                    edge_types_to_indices,
                    variant_edge_types=(),
                    platt_scale=False,
                    with_backtrack=True):
  """Automaton-based module for edge supervision task.

  Args:
    example: Example to run the automaton on.
    graph_metadata: Statically-known metadata about the graph size. If the
      encoded graph is padded, this should reflect the padded size, not the
      original size.
    edge_types_to_indices: Mapping from edge type names to edge type indices.
    variant_edge_types: Edge types to use as variants. Assumes without
      checking that the given variants are mutually exclusive (at most one
      edge of one of these types exists between any pair of nodes).
    platt_scale: Whether to scale and shift the logits produced by the
      automaton. This can be viewed as a form of Platt scaling applied to the
      automaton logits. If True, this allows the model's output probabilities
      to sum to more than 1, so that it can express one-to-many relations.
    with_backtrack: Whether the automaton can restart the search as an
      action.

  Returns:
    float32[num_nodes, num_nodes] matrix of binary logits for a weighted
    adjacency matrix corresponding to the predicted output edges.
  """
  # Variants are optional; with no variant edge types, pass None through.
  variant_weights = None
  if variant_edge_types:
    variant_indices = [
        edge_types_to_indices[name] for name in variant_edge_types
    ]
    variant_weights = variants_from_edges(example, graph_metadata,
                                          variant_indices,
                                          len(edge_types_to_indices))

  builder = automaton_builder.AutomatonBuilder(
      py_ast_graphs.SCHEMA, with_backtrack=with_backtrack)
  per_edge_probs = automaton_layer.FiniteStateGraphAutomaton(
      encoded_graph=example.automaton_graph,
      variant_weights=variant_weights,
      static_metadata=graph_metadata,
      dynamic_metadata=example.graph_metadata,
      builder=builder,
      num_out_edges=1,
      share_states_across_edges=True)
  # Only one output edge was requested, so drop that leading axis.
  absorbing_probs = per_edge_probs.squeeze(axis=0)

  logits = model_util.safe_logit(absorbing_probs)
  if not platt_scale:
    return logits
  return model_util.ScaleAndShift(logits)
def test_constructor_inverse_mappings(self):
  """Checks that each `*_to_index` mapping inverts its list."""
  builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

  list_and_mapping_pairs = (
      (builder.node_types, builder.node_type_to_index),
      (builder.in_route_types, builder.in_route_type_to_index),
      (builder.in_out_route_types, builder.in_out_route_type_to_index),
  )
  # Looking an item up in the mapping must give back its list position.
  for items, index_of in list_and_mapping_pairs:
    for position, item in enumerate(items):
      self.assertEqual(index_of[item], position)
def test_transition_all_ones(self):
  """Test the transition matrix of an all-ones routing parameter vector."""
  schema, loop_graph = self.build_loop_graph()
  builder = automaton_builder.AutomatonBuilder(schema)
  encoded_graph, graph_meta = builder.encode_graph(loop_graph, as_jax=False)

  # The transition matrix for an all-ones routing params should be a
  # (weighted) directed adjacency matrix. We use 3 variants, 2 states.
  ones_routing_params = automaton_builder.RoutingParams(
      move=jnp.ones([3, 12, 2, 2]), special=jnp.ones([3, 5, 2, 3]))
  actual = builder.build_transition_matrix(
      ones_routing_params,
      encoded_graph,
      graph_meta,
  ).concatenated_transitions()
  self.assertEqual(actual.shape, (3, 4 + 6, 2, 6 * 2 + 3))

  # pyformat: disable
  # pylint: disable=bad-continuation,bad-whitespace,g-inline-comment-too-close
  # Columns are the next position, two entries (one per state) for each of
  # a0i0, a0i1, a1i0, a1i1, b0i0, b1i0, followed by the three specials.
  from_a0 = [0, 0, 0, 0, 1, 1, 0, 0,   0,   0,   1,   1, 1, 1, 1]
  from_a1 = [0, 0, 1, 1, 0, 0, 0, 0,   1,   1,   0,   0, 1, 1, 1]
  from_b0 = [0, 0, 0, 0, 0, 0, 1, 1, 0.5, 0.5, 1.5, 1.5, 1, 1, 1]
  from_b1 = [1, 1, 0, 0, 0, 0, 0, 0, 1.5, 1.5, 0.5, 0.5, 1, 1, 1]
  # pyformat: enable
  # pylint: enable=bad-continuation,bad-whitespace,g-inline-comment-too-close

  # Rows are the current position: the four initial nodes, then the six
  # in-tagged nodes.
  rows_by_current_position = [
      from_a0,  # initial a0
      from_a1,  # initial a1
      from_b0,  # initial b0
      from_b1,  # initial b1
      from_a0,  # a0i0
      from_a0,  # a0i1
      from_a1,  # a1i0
      from_a1,  # a1i1
      from_b0,  # b0i0
      from_b1,  # b1i0
  ]
  # Both current states give the same row, and all three variants match.
  single_variant = [[row, row] for row in rows_by_current_position]
  expected = np.array([single_variant] * 3)
  np.testing.assert_allclose(actual, expected)
def test_graph_encoding_size(self):
  """Test the size of the encoded graph."""
  schema, loop_graph = self.build_loop_graph()
  builder = automaton_builder.AutomatonBuilder(schema)
  encoded_graph, graph_meta = builder.encode_graph(loop_graph, as_jax=False)

  # Graph metadata should match our graph's actual size.
  self.assertEqual(graph_meta.num_nodes, 4)
  self.assertEqual(graph_meta.num_input_tagged_nodes, 6)

  # Nonzero entries should match the number of possible transitions.
  # Initial transition counts each NODE once, so each A node has 2 and each
  # B node has 4 outgoing transitions.
  num_initial_transitions = encoded_graph.initial_to_in_tagged.values.shape[0]
  self.assertEqual(num_initial_transitions, 12)

  # Normal transitions count each input-tagged node once, so each A node has
  # 2*2=4 and each B node has 1*4=4 outgoing transitions.
  num_tagged_transitions = (
      encoded_graph.in_tagged_to_in_tagged.values.shape[0])
  self.assertEqual(num_tagged_transitions, 16)
def test_all_nodes_absorbing_solve_explicit_conv(self):
  """Checks that `explicit_conv` changes neither values nor gradients."""
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  variant_weights = jax.random.dirichlet(
      jax.random.PRNGKey(0), jnp.ones((4, 4, 5)))
  routing_params = builder.initialize_routing_params(
      jax.random.PRNGKey(1), num_fsm_states=3, num_variants=5)
  start_states = jax.random.dirichlet(
      jax.random.PRNGKey(0), jnp.ones((4, 3)))

  def go(routing_params, variant_weights, start_states, explicit_conv=True):
    """Builds the transition matrix and runs the absorbing solve."""
    tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                           enc_meta)
    return automaton_builder.all_nodes_absorbing_solve(
        builder,
        tmat,
        variant_weights,
        start_states,
        steps=1000,
        backtrack_fails_prob=0.01,
        explicit_conv=explicit_conv)

  # Confirm that the explicit conv doesn't change results.
  solve_inputs = (routing_params, variant_weights, start_states)
  vals, vjpfun = jax.vjp(go, *solve_inputs)
  unopt_vals, unopt_vjpfun = jax.vjp(
      functools.partial(go, explicit_conv=False), *solve_inputs)
  jax.test_util.check_close(vals, unopt_vals)

  # The gradients, probed with a random cotangent, should agree as well.
  some_cotangent = jax.random.normal(jax.random.PRNGKey(0), vals.shape)
  jax.test_util.check_close(
      vjpfun(some_cotangent), unopt_vjpfun(some_cotangent))
def test_component_shapes(self,
                          component,
                          embed_edges,
                          expected_dims,
                          extra_config=None):
  """Checks the output shapes of one end-to-end stack component.

  Args:
    component: Key into `end_to_end_stack.ALL_COMPONENTS`.
    embed_edges: Passed to the graph context as `edges_are_embedded`.
    expected_dims: Mapping with "node" and "edge" entries giving the expected
      trailing dimension of the node and edge outputs.
    extra_config: Optional extra gin config, parsed after `CONFIG`.
  """
  gin.clear_config()
  gin.parse_config(CONFIG)
  if extra_config:
    gin.parse_config(extra_config)

  static_metadata = automaton_builder.EncodedGraphMetadata(
      num_nodes=16, num_input_tagged_nodes=32)
  # Out edges are tagged with OutEdgeType (not InEdgeType), matching how
  # schemas are built elsewhere in these tests.
  builder = automaton_builder.AutomatonBuilder({
      graph_types.NodeType("node"):
          graph_types.NodeSchema(
              in_edges=[graph_types.InEdgeType("in")],
              out_edges=[graph_types.OutEdgeType("out")])
  })

  # Run the computation with placeholder inputs.
  (node_out, edge_out), _ = end_to_end_stack.ALL_COMPONENTS[component].init(
      jax.random.PRNGKey(0),
      graph_context=end_to_end_stack.SharedGraphContext(
          bundle=graph_bundle.zeros_like_padded_example(
              graph_bundle.PaddingConfig(
                  static_max_metadata=static_metadata,
                  max_initial_transitions=11,
                  max_in_tagged_transitions=12,
                  max_edges=13)),
          static_metadata=static_metadata,
          edge_types_to_indices={"foo": 0},
          builder=builder,
          edges_are_embedded=embed_edges),
      node_embeddings=jnp.zeros((16, NODE_DIM)),
      edge_embeddings=jnp.zeros((16, 16, EDGE_DIM)))

  self.assertEqual(node_out.shape, (16, expected_dims["node"]))
  self.assertEqual(edge_out.shape, (16, 16, expected_dims["edge"]))
def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                       use_gate, estimator_type, **kwargs):
  """Checks that the automaton layer initializes inside a model."""
  # Create a simple schema and empty encoded graph.
  schema = {
      graph_types.NodeType("a"):
          graph_types.NodeSchema(
              in_edges=[graph_types.InEdgeType("ai_0")],
              out_edges=[graph_types.OutEdgeType("ao_0")]),
  }
  builder = automaton_builder.AutomatonBuilder(schema)

  def empty_operator():
    """Returns an all-zeros sparse operator with 128 entries."""
    return sparse_operator.SparseCoordOperator(
        input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
        output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
        values=jnp.zeros((128,), dtype=jnp.float32),
    )

  encoded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=empty_operator(),
      initial_to_special=jnp.zeros((32,), dtype=jnp.int32),
      in_tagged_to_in_tagged=empty_operator(),
      in_tagged_to_special=jnp.zeros((64,), dtype=jnp.int32),
      in_tagged_node_indices=jnp.zeros((64,), dtype=jnp.int32),
  )

  # Make sure the layer can be initialized and applied within a model.
  # This model is fairly simple; it just pretends that the encoded graph and
  # variants depend on the input.
  class TestModel(flax.deprecated.nn.Module):

    def apply(self, dummy_ignored):

      def tie_to_input(leaf):
        return jax.lax.tie_in(dummy_ignored, leaf)

      abstract_encoded_graph = jax.tree_map(tie_to_input, encoded_graph)
      abstract_variant_weights = jax.tree_map(tie_to_input,
                                              variant_weights())
      return automaton_layer.FiniteStateGraphAutomaton(
          encoded_graph=abstract_encoded_graph,
          variant_weights=abstract_variant_weights,
          dynamic_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=32, num_input_tagged_nodes=64),
          static_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=32, num_input_tagged_nodes=64),
          builder=builder,
          num_out_edges=3,
          num_intermediate_states=4,
          share_states_across_edges=shared,
          use_gate_parameterization=use_gate,
          estimator_type=estimator_type,
          name="the_layer",
          **kwargs)

  with side_outputs.collect_side_outputs() as side:
    with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
      # For some reason init_by_shape breaks the custom_vjp?
      abstract_out, unused_params = TestModel.init(
          jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

  del unused_params
  self.assertEqual(abstract_out.shape, (3, 32, 32))

  if estimator_type == "one_sample":
    log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
    self.assertIn(log_prob_key, side)
    self.assertEqual(side[log_prob_key].shape, (3, 32))
def test_all_nodes_particle_estimate(self):
  """Compares the particle estimate with hand-computed absorbing probs."""
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  # We set up the automaton with 5 variants and 2 states, but only use the
  # first state, to make sure that variants and states are interleaved
  # correctly.
  # Variant 0: move forward
  # Variant 1: move backward
  # Variant 2: finish
  # Variant 3: restart
  # Variant 4: fail
  variant_weights = jnp.array([
      # From node 0, go forward.
      [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0], [0, 0, 1, 0, 0]],
      # From node 1, go backward with small failure probabilities.
      [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1], [.7, 0, .3, 0, 0],
       [0, 0, 1, 0, 0]],
      # Node 2 bounces around and ultimately accepts on node 0.
      [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0], [0.5, 0.5, 0, 0, 0],
       [0, 1, 0, 0, 0]],
      # Node 3 immediately accepts, or restarts after 0 or 1 steps.
      [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
       [0, 0.1, 0.8, 0.1, 0]],
  ])

  # One row of move weights per variant.
  move_table = jnp.array([
      [1., 0., 1., 0., 1., 0.],
      [0., 1., 0., 1., 0., 1.],
      [0., 0., 0., 0., 0., 0.],
      [0., 0., 0., 0., 0., 0.],
      [0., 0., 0., 0., 0., 0.],
  ]).reshape([5, 6, 1, 1])
  # One block of special-action weights per variant.
  special_table = jnp.array([
      [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
      [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
      [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
      [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
      [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
  ]).reshape([5, 3, 1, 3])
  routing_params = automaton_builder.RoutingParams(
      move=jnp.broadcast_to(
          jnp.pad(move_table, [(0, 0), (0, 0), (0, 0), (0, 1)]),
          [5, 6, 2, 2]),
      special=jnp.broadcast_to(special_table, [5, 3, 2, 3]))

  @jax.jit
  def go(variant_weights, routing_params, eps):

    def mix_with_uniform(x):
      return (1 - eps) * x + eps * jnp.ones_like(x) / 5

    variant_weights = mix_with_uniform(variant_weights)
    routing_params = automaton_builder.RoutingParams(
        move=mix_with_uniform(routing_params.move),
        special=mix_with_uniform(routing_params.special),
    )
    # Renormalize after mixing.
    variant_weights = variant_weights / jnp.sum(
        variant_weights, axis=-1, keepdims=True)
    routing_params_sum = builder.routing_reduce(routing_params, "sum")
    routing_params = jax.tree_multimap(jax.lax.div, routing_params,
                                       routing_params_sum)
    tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                           enc_meta)
    return automaton_sampling.all_nodes_particle_estimate(
        builder,
        tmat,
        variant_weights,
        jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
        steps=100,
        rng=jax.random.PRNGKey(1),
        num_rollouts=10000,
        max_possible_transitions=2,  # only two edges to leave each node
        num_valid_nodes=enc_meta.num_nodes,
    )

  # Absorbing probs follow the paths described above.
  expected_absorbing_probs = jnp.array([
      [0, 0, 0.3, 0.7],
      [0, 0, 0, 0.81],
      [1, 0, 0, 0],
      [0, 0, 0, 1],
  ])

  particle_probs = go(variant_weights, routing_params, eps=0)

  # With 10000 rollouts we expect a standard deviation of up to
  # sqrt(0.3*0.7/10000) ~= 5e-3. Check that we are within 2 sigma.
  np.testing.assert_allclose(
      particle_probs, expected_absorbing_probs, atol=2 * 5e-3)
def test_unroll_and_aggregate(self):
  """Checks unrolled chain steps and their per-node aggregation."""
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  # Three variants, but all the same, just to check shape consistency
  variant_weights = jnp.broadcast_to(jnp.array([0.7, 0.3, 0.]), [4, 3])

  # In state 0: keep moving in the current direction with prob 0.1, swap
  # directions and states with prob 0.9 (except init, which is a special case)
  # In state 1: take a special action
  move_per_route = jnp.array([
      # from next, to next
      [[0.0, 0.9], [0.0, 0.0]],
      # from next, to prev
      [[0.1, 0.0], [0.0, 0.0]],
      # from prev, to next
      [[0.1, 0.0], [0.0, 0.0]],
      # from prev, to prev
      [[0.0, 0.9], [0.0, 0.0]],
      # from init, to next
      [[0.1, 0.0], [0.0, 0.0]],
      # from init, to prev
      [[0.1, 0.0], [0.0, 0.0]],
  ])
  special_per_route = jnp.array([
      # from next
      [[0, 0, 0], [0.2, 0.3, 0.5]],
      # from prev
      [[0, 0, 0], [0.5, 0.2, 0.3]],
      # from init
      [[0.1, 0.3, 0.4], [0.0, 0.0, 1.0]],
  ])
  routing_params = automaton_builder.RoutingParams(
      move=jnp.broadcast_to(move_per_route, [3, 6, 2, 2]),
      special=jnp.broadcast_to(special_per_route, [3, 3, 2, 3]))

  tmat = builder.build_transition_matrix(routing_params, enc_graph, enc_meta)
  unrolled = automaton_builder.unroll_chain_steps(
      builder,
      tmat,
      variant_weights,
      jnp.array([1., 0.]),
      node_index=0,
      steps=6)

  np.testing.assert_allclose(unrolled["initial_special"],
                             np.array([0.1, 0.3, 0.4]))

  # pyformat: disable
  # pylint: disable=bad-whitespace
  expected_in_tagged_states = np.array([
      # In-tagged node key:
      # 0 from 1, 0 from 3, 1 from 2, 1 from 0,
      # 2 from 3, 2 from 1, 3 from 0, 3 from 2
      # State key: [prob of being in state 0, prob of being in state 1]
      # -- First step from initial --
      [[0, 0], [0, 0], [0, 0], [1e-1, 0],
       [0, 0], [0, 0], [1e-1, 0], [0, 0]],
      # -- Second step --
      [[0, 9e-2], [0, 9e-2], [0, 0], [0, 0],
       [1e-2, 0], [1e-2, 0], [0, 0], [0, 0]],
      # -----------------
      [[0, 0], [0, 0], [1e-3, 9e-3], [0, 0],
       [0, 0], [0, 0], [0, 0], [1e-3, 9e-3]],
      # -----------------
      [[1e-4, 0], [1e-4, 0], [0, 0], [0, 0],
       [0, 9e-4], [0, 9e-4], [0, 0], [0, 0]],
      # -----------------
      [[0, 0], [0, 0], [0, 0], [1e-5, 9e-5],
       [0, 0], [0, 0], [1e-5, 9e-5], [0, 0]],
      # -----------------
      [[0, 9e-6], [0, 9e-6], [0, 0], [0, 0],
       [1e-6, 0], [1e-6, 0], [0, 0], [0, 0]],
  ])
  np.testing.assert_allclose(unrolled["in_tagged_states"],
                             expected_in_tagged_states,
                             atol=1e-8)

  none = [0, 0, 0]  # no special action recorded for this in-tagged node
  expected_in_tagged_special = np.array([
      # In-tagged node key:
      # 0 from 1, 0 from 3
      # 1 from 2, 1 from 0,
      # 2 from 3, 2 from 1
      # 3 from 0, 3 from 2
      # Action key: [finish, backtrack, fail] (cumulative)
      # -- First step from initial --
      # (no special actions because we just left the initial node)
      [none] * 8,
      # -- Second step --
      # (no special actions yet because everything was in state 0)
      [none] * 8,
      # -----------------
      [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
       none, none,
       none, none,
       none, none],
      # -----------------
      [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
       [1.8e-3, 2.7e-3, 4.5e-3], none,
       none, none,
       none, [4.5e-3, 1.8e-3, 2.7e-3]],
      # -----------------
      [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
       [1.8e-3, 2.7e-3, 4.5e-3], none,
       [1.8e-4, 2.7e-4, 4.5e-4], [4.5e-4, 1.8e-4, 2.7e-4],
       none, [4.5e-3, 1.8e-3, 2.7e-3]],
      # -----------------
      [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
       [1.8e-3, 2.7e-3, 4.5e-3], [4.5e-5, 1.8e-5, 2.7e-5],
       [1.8e-4, 2.7e-4, 4.5e-4], [4.5e-4, 1.8e-4, 2.7e-4],
       [1.8e-5, 2.7e-5, 4.5e-5], [4.5e-3, 1.8e-3, 2.7e-3]],
  ])
  np.testing.assert_allclose(unrolled["in_tagged_special"],
                             expected_in_tagged_special,
                             atol=1e-8)
  # pyformat: enable
  # pylint: enable=bad-whitespace

  unrolled_combined = automaton_builder.aggregate_unrolled_per_node(
      unrolled, 0, 0, tmat, enc_meta)

  expected_unrolled = np.array([
      # "Zeroth" step: at initial node, no specials have happened yet
      [[1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]],
      # The other steps are either the states from above, or sums of entries
      # in initial_special and in_tagged_special
      [[0, 0, 0.1, 0.3, 0.4],
       [0.1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0.1, 0, 0, 0, 0]],
      # -----------------
      [[0, 0.18, 0.1, 0.3, 0.4],
       [0, 0, 0, 0, 0],
       [0.02, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]],
      # -----------------
      [[0, 0, 0.163, 0.34500003, 0.472],
       [0.001, 0.009, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0.001, 0.009, 0, 0, 0]],
      # -----------------
      [[0.0002, 0, 0.163, 0.34500003, 0.472],
       [0, 0, 0.0018, 0.0027, 0.0045],
       [0, 0.0018, 0, 0, 0],
       [0, 0, 0.0045, 0.0018, 0.0027]],
      # -----------------
      [[0, 0, 0.163, 0.34500003, 0.472],
       [1e-05, 9e-05, 0.0018, 0.0027, 0.0045],
       [0, 0, 0.00063, 0.00045, 0.00072],
       [1e-05, 9e-05, 0.0045, 0.0018, 0.0027]],
      # -----------------
      [[0, 1.8e-05, 0.163, 0.34500003, 0.472],
       [0, 0, 0.001845, 0.002718, 0.004527],
       [2e-06, 0, 0.00063, 0.00045, 0.00072],
       [0, 0, 0.004518, 0.001827, 0.002745]],
  ])
  np.testing.assert_allclose(unrolled_combined, expected_unrolled, atol=1e-8)
def test_all_nodes_absorbing_solve(self):
  """Checks `all_nodes_absorbing_solve` against hand-derived probabilities.

  Uses the graph from `build_doubly_linked_list_graph(4)` together with
  one-hot routing parameters, so that each variant deterministically maps to
  a single move or special action and the expected absorbing probabilities
  can be read off from `variant_weights`.
  """
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  # We set up the automaton with 5 variants and 2 states, but only use the
  # first state, to make sure that variants and states are interleaved
  # correctly.

  # Variant 0: move forward
  # Variant 1: move backward
  # Variant 2: finish
  # Variant 3: restart
  # Variant 4: fail
  # Array shape is (4, 4, 5); the last axis indexes the variants above.
  variant_weights = jnp.array([
      # From node 0, go forward.
      [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0], [0, 0, 1, 0, 0]],
      # From node 1, go backward with small failure probabilities.
      [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1], [.7, 0, .3, 0, 0],
       [0, 0, 1, 0, 0]],
      # Node 2 bounces around and ultimately accepts on node 0.
      [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0], [0.5, 0.5, 0, 0, 0],
       [0, 1, 0, 0, 0]],
      # Node 3 immediately accepts, or restarts after 0 or 1 steps.
      [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
       [0, 0.1, 0.8, 0.1, 0]],
  ])
  routing_params = automaton_builder.RoutingParams(
      # Reshaped to (5, 6, 1, 1), then zero-padded to (5, 6, 2, 2): only the
      # first state carries any probability mass.
      move=jnp.pad(
          jnp.array([
              [1., 0., 1., 0., 1., 0.],
              [0., 1., 0., 1., 0., 1.],
              [0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0.],
          ]).reshape([5, 6, 1, 1]), [(0, 0), (0, 0), (0, 1), (0, 1)]),
      # Reshaped to (5, 3, 1, 3), then zero-padded to (5, 3, 2, 3).
      special=jnp.pad(
          jnp.array([
              [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
              [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
              [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
              [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
              [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
          ]).reshape([5, 3, 1, 3]), [(0, 0), (0, 0), (0, 1), (0, 0)]))
  tmat = builder.build_transition_matrix(routing_params, enc_graph, enc_meta)

  # Absorbing probs follow the paths described above.
  # Note that when starting at node 3, with probability 0.2 the automaton
  # tries to backtrack, but with probability 0.2 * 0.01 backtracking fails
  # (as specified by backtrack_fails_prob) and thus the total absorbing
  # probability is 0.8 / (0.8 + 0.2 * 0.01) = 0.997506
  expected_absorbing_probs = jnp.array([
      [0, 0, 0.3, 0.7],
      [0, 0, 0, 0.81],
      [1, 0, 0, 0],
      [0, 0, 0, 0.997506],
  ])
  absorbing_probs = automaton_builder.all_nodes_absorbing_solve(
      builder,
      tmat,
      variant_weights,
      # Shape (4, 2): ones in the first column, zeros in the padded second
      # column (presumably a per-node initial state distribution -- confirm
      # against `all_nodes_absorbing_solve`).
      jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
      steps=1000,
      backtrack_fails_prob=0.01)

  # Use the same public NumPy assertion as the rest of this file rather than
  # the non-public `jax.test_util.check_close`. The tolerance is explicit
  # because the expected value 0.997506 is itself rounded to six decimals.
  np.testing.assert_allclose(
      absorbing_probs, expected_absorbing_probs, rtol=1e-5, atol=1e-5)
Args: maze_graph: Encoded graph representing the maze. Returns: List of edges corresponding to primitive actions in the maze. """ primitives = [] for node_id, node_info in maze_graph.items(): for i, direction in enumerate(DIRECTION_ORDERING): out_key = graph_types.OutEdgeType(f"{direction}_out") if out_key in node_info.out_edges: dest, = node_info.out_edges[out_key] primitives.append((node_id, dest.node_id, i)) else: primitives.append((node_id, node_id, i)) return primitives SCHEMA = maze_schema.build_maze_schema(2) # Backtracking doesn't make sense for maze environment. BUILDER = automaton_builder.AutomatonBuilder(SCHEMA, with_backtrack=False) PADDING_CONFIG = graph_bundle.PaddingConfig( static_max_metadata=automaton_builder.EncodedGraphMetadata( num_nodes=256, num_input_tagged_nodes=512), max_initial_transitions=512, max_in_tagged_transitions=2048, max_edges=1024)
def test_routing_reduce_correct(self, reduction):
  """Compare JAX implementations to a (slow but correct) iterative one.

  For every (variant, current state, in-route type) triple, the values of
  all move and special entries sharing that triple are gathered by plain
  Python loops, reduced by hand, and compared entry by entry with the
  output of `routing_softmax` / `routing_reduce`.

  Args:
    reduction: Name of the aggregation under test. The reference
      implementation below handles "sum", "max" and "softmax".

  Raises:
    ValueError: If the reference implementation does not know `reduction`.
  """
  n_variants = 2
  n_states = 4

  def make_range_shaped(shape):
    # Every entry gets a distinct value, so an index mix-up changes results.
    return np.arange(np.prod(shape)).reshape(shape).astype("float32")

  schema = self.build_simple_schema()
  builder = automaton_builder.AutomatonBuilder(schema)
  n_actions = len(builder.special_actions)
  routing_params = automaton_builder.RoutingParams(
      move=make_range_shaped([
          n_variants,
          len(builder.in_out_route_types),
          n_states,
          n_states,
      ]),
      special=make_range_shaped([
          n_variants,
          len(builder.in_route_types),
          n_states,
          n_actions,
      ]),
  )
  params = {"move": routing_params.move, "special": routing_params.special}

  # Compute aggregates with JAX
  if reduction == "softmax":
    routing_aggregates = builder.routing_softmax(routing_params)
  else:
    routing_aggregates = builder.routing_reduce(
        routing_params, reduction=reduction)

  # Broadcast each aggregate to the full parameter shape so it can be
  # indexed exactly like the parameters. This is done per field instead of
  # through `jax.tree_multimap`, which newer JAX releases no longer provide;
  # for results that already have the full shape it is a no-op.
  aggregates = {
      name: np.array(
          jnp.broadcast_to(
              getattr(routing_aggregates, name), params[name].shape))
      for name in params
  }

  # Manual looping aggregates
  for variant in range(n_variants):
    for current_state in range(n_states):
      for in_route_type in builder.in_route_types:
        # Enumerate, once, every (field, index) pair belonging to this
        # distribution: all moves over all out edges, then the specials.
        iroute_idx = builder.in_route_type_to_index[in_route_type]
        slots = []
        for out_edge_type in schema[in_route_type.node_type].out_edges:
          ioroute_idx = builder.in_out_route_type_to_index[
              automaton_builder.InOutRouteType(in_route_type.node_type,
                                               in_route_type.in_edge,
                                               out_edge_type)]
          for next_state in range(n_states):
            slots.append(
                ("move", (variant, ioroute_idx, current_state, next_state)))
        for action_idx in range(n_actions):
          slots.append(
              ("special", (variant, iroute_idx, current_state, action_idx)))

        # Compute aggregates
        distn_vals = [params[name][index] for name, index in slots]
        if reduction == "sum":
          distn_aggregate = [sum(distn_vals)] * len(distn_vals)
        elif reduction == "max":
          distn_aggregate = [max(distn_vals)] * len(distn_vals)
        elif reduction == "softmax":
          distn_aggregate = list(jax.nn.softmax(jnp.array(distn_vals)))
        else:
          raise ValueError(f"Invalid reduction {reduction}")

        # Check them with the JAX version
        for (name, index), expected in zip(slots, distn_aggregate):
          np.testing.assert_allclose(
              aggregates[name][index], expected, rtol=1e-6)
if isinstance(value, gast.AST): fields[field_name] = [py_ast_to_generic(value)] elif isinstance(value, list): if value and isinstance(value[0], gast.AST): fields[field_name] = [py_ast_to_generic(child) for child in value] else: # Doesn't contain any AST nodes, so ignore it. pass return generic_ast_graphs.GenericASTNode( node_id=id(tree), node_type=type(tree).__name__, fields=fields) # Default definitions used elsewhere SCHEMA = generic_ast_graphs.build_ast_graph_schema(PY_AST_SPECS) BUILDER = automaton_builder.AutomatonBuilder(SCHEMA) def py_ast_to_graph( tree): """Convert an unsimplified AST into a graph, with a forward mapping. Args: tree: The (unsimplified) AST for the program. Returns: - Graph representing the tree. - Dictionary that maps from AST node ids to graph node ids. """ return generic_ast_graphs.ast_to_graph( root=py_ast_to_generic(tree),