def test_schema_edges(self):
        """Schema edge names and edges are checked on a two-node graph.

        Both `schema_edge_types` and `compute_schema_edges` are exercised
        twice: once with `with_node_types=False` and once with
        `with_node_types=True`.
        """

        def tagged(target):
            # The "ignored" in-edge tag never appears in an expected value.
            return graph_types.InputTaggedNode(target, "ignored")

        schema = {
            "foo": graph_types.NodeSchema(
                in_edges=["ignored"], out_edges=["a", "b"]),
            "bar": graph_types.NodeSchema(
                in_edges=["ignored"], out_edges=["b", "c"]),
        }
        graph = {
            "foo_node": graph_types.GraphNode("foo", {
                "a": [tagged("bar_node")],
                "b": [tagged("foo_node")],
            }),
            "bar_node": graph_types.GraphNode("bar", {
                "b": [tagged("bar_node")],
                "c": [tagged("foo_node")],
            }),
        }

        # Without node types: names are "SCHEMA_" plus the out-edge name.
        plain_names = graph_edge_util.schema_edge_types(
            schema, with_node_types=False)
        self.assertEqual(plain_names, {"SCHEMA_a", "SCHEMA_b", "SCHEMA_c"})

        plain_edges = graph_edge_util.compute_schema_edges(
            graph, with_node_types=False)
        expected_plain_edges = [
            ("foo_node", "bar_node", "SCHEMA_a"),
            ("foo_node", "foo_node", "SCHEMA_b"),
            ("bar_node", "bar_node", "SCHEMA_b"),
            ("bar_node", "foo_node", "SCHEMA_c"),
        ]
        self.assertEqual(plain_edges, expected_plain_edges)

        # With node types: names also carry "_FROM_" plus the source type.
        typed_names = graph_edge_util.schema_edge_types(
            schema, with_node_types=True)
        expected_typed_names = {
            "SCHEMA_a_FROM_foo",
            "SCHEMA_b_FROM_foo",
            "SCHEMA_b_FROM_bar",
            "SCHEMA_c_FROM_bar",
        }
        self.assertEqual(typed_names, expected_typed_names)

        typed_edges = graph_edge_util.compute_schema_edges(
            graph, with_node_types=True)
        expected_typed_edges = [
            ("foo_node", "bar_node", "SCHEMA_a_FROM_foo"),
            ("foo_node", "foo_node", "SCHEMA_b_FROM_foo"),
            ("bar_node", "bar_node", "SCHEMA_b_FROM_bar"),
            ("bar_node", "foo_node", "SCHEMA_c_FROM_bar"),
        ]
        self.assertEqual(typed_edges, expected_typed_edges)
 def build_doubly_linked_list_graph(self, length):
     """Return a (schema, graph) pair for a circular doubly-linked list.

     Node ids are the decimal strings of `range(length)`. Node `i` has a
     "next_out" edge to node `(i + 1) % length`, arriving there as "prev_in",
     and a "prev_out" edge to node `(i - 1) % length`, arriving there as
     "next_in".

     Args:
       length: Number of nodes to create.

     Returns:
       Tuple (schema, graph); the schema contains the single node type "node".
     """
     list_node = graph_types.NodeType("node")
     next_out = graph_types.OutEdgeType("next_out")
     prev_out = graph_types.OutEdgeType("prev_out")

     def link(target_index, in_edge_name):
         # Indices wrap around, so the list is circular in both directions.
         return graph_types.InputTaggedNode(
             node_id=graph_types.NodeId(str(target_index % length)),
             in_edge=graph_types.InEdgeType(in_edge_name))

     schema = {
         list_node: graph_types.NodeSchema(
             in_edges=[
                 graph_types.InEdgeType("next_in"),
                 graph_types.InEdgeType("prev_in"),
             ],
             out_edges=[next_out, prev_out]),
     }
     graph = {
         graph_types.NodeId(str(index)): graph_types.GraphNode(
             list_node, {
                 next_out: [link(index + 1, "prev_in")],
                 prev_out: [link(index - 1, "next_in")],
             })
         for index in range(length)
     }
     return schema, graph
예제 #3
0
 def build_simple_schema(self):
     """Return a schema with two node types, "a" and "b".

     Type "a" has in-edges "ai_0", "ai_1" and out-edge "ao_0"; type "b" has
     in-edge "bi_0" and out-edges "bo_0", "bo_1".
     """
     in_edge = graph_types.InEdgeType
     out_edge = graph_types.OutEdgeType

     schema_a = graph_types.NodeSchema(
         in_edges=[in_edge("ai_0"), in_edge("ai_1")],
         out_edges=[out_edge("ao_0")])
     schema_b = graph_types.NodeSchema(
         in_edges=[in_edge("bi_0")],
         out_edges=[out_edge("bo_0"), out_edge("bo_1")])
     return {
         graph_types.NodeType("a"): schema_a,
         graph_types.NodeType("b"): schema_b,
     }
예제 #4
0
    def test_conforms_to_schema(self):
        """A graph that matches its schema is accepted without raising."""
        in_edge = graph_types.InEdgeType
        out_edge = graph_types.OutEdgeType
        node_type = graph_types.NodeType

        def arriving(node_name, in_edge_name):
            # Destination node, tagged with the in-edge it is entered through.
            return graph_types.InputTaggedNode(
                graph_types.NodeId(node_name), in_edge(in_edge_name))

        schema = {
            node_type("a"): graph_types.NodeSchema(
                in_edges=[in_edge("ai_0"), in_edge("ai_1")],
                out_edges=[out_edge("ao_0")]),
            node_type("b"): graph_types.NodeSchema(
                in_edges=[in_edge("bi_0")],
                out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
        }

        # Valid graph: every out-edge of each node's type is present, and
        # every destination uses an in-edge declared by the target's type.
        node_a = graph_types.GraphNode(node_type("a"), {
            out_edge("ao_0"): [
                arriving("B", "bi_0"),
                arriving("A", "ai_1"),
            ],
        })
        node_b = graph_types.GraphNode(node_type("b"), {
            out_edge("bo_0"): [arriving("A", "ai_1")],
            out_edge("bo_1"): [arriving("B", "bi_0")],
        })
        valid_graph = {
            graph_types.NodeId("A"): node_a,
            graph_types.NodeId("B"): node_b,
        }
        schema_util.assert_conforms_to_schema(valid_graph, schema)
예제 #5
0
 def test_does_not_conform_to_schema(self, graph, expected_error):
     """Validating `graph` raises a ValueError matching `expected_error`.

     Args:
       graph: Graph to validate against the fixed two-type schema below.
       expected_error: Regex the raised ValueError's message must match.
     """
     in_edge = graph_types.InEdgeType
     out_edge = graph_types.OutEdgeType

     schema_a = graph_types.NodeSchema(
         in_edges=[in_edge("ai_0"), in_edge("ai_1")],
         out_edges=[out_edge("ao_0")])
     schema_b = graph_types.NodeSchema(
         in_edges=[in_edge("bi_0")],
         out_edges=[out_edge("bo_0"), out_edge("bo_1")])
     schema = {
         graph_types.NodeType("a"): schema_a,
         graph_types.NodeType("b"): schema_b,
     }
     # pylint: disable=g-error-prone-assert-raises
     with self.assertRaisesRegex(ValueError, expected_error):
         schema_util.assert_conforms_to_schema(graph, schema)
예제 #6
0
def build_maze_schema(min_neighbors):
    """Build a schema for a 2d gridworld maze.

    Each node type is named "cell_" followed by four characters, one per
    direction in the order L, R, U, D; a direction without a neighbor is
    written as "x". For example "cell_LxUD" is a cell with neighbors to the
    left, up and down, and its edges are "L_in", "U_in", "D_in" and "L_out",
    "U_out", "D_out".

    Args:
      min_neighbors: Minimum number of neighbors a grid cell will have. For
        instance, if min_neighbors=2 then the schema will only have types for
        cells that are connected to at least 2 other cells.

    Returns:
      Schema describing the maze.
    """
    directions = ("L", "R", "U", "D")
    schema = {}
    # Enumerate every present/absent combination of the four directions.
    for present in itertools.product([False, True], repeat=4):
        if np.count_nonzero(present) < min_neighbors:
            continue

        letters = [
            direction if used else "x"
            for direction, used in zip(directions, present)
        ]
        cell_name = "cell_" + "".join(letters)

        active = [
            direction for direction, used in zip(directions, present) if used
        ]
        schema[graph_types.NodeType(cell_name)] = graph_types.NodeSchema(
            in_edges=[graph_types.InEdgeType(d + "_in") for d in active],
            out_edges=[graph_types.OutEdgeType(d + "_out") for d in active])
    return schema
    def test_component_shapes(self,
                              component,
                              embed_edges,
                              expected_dims,
                              extra_config=None):
        """Checks the output shapes produced by one end-to-end component.

        Args:
          component: Key into `end_to_end_stack.ALL_COMPONENTS`.
          embed_edges: Passed as `edges_are_embedded` in the shared context.
          expected_dims: Mapping with "node" and "edge" entries giving the
            expected last dimension of the node and edge outputs.
          extra_config: Optional extra gin config, parsed after CONFIG.
        """
        num_nodes = 16
        num_input_tagged_nodes = 32

        gin.clear_config()
        gin.parse_config(CONFIG)
        if extra_config:
            gin.parse_config(extra_config)

        padding_config = graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=num_nodes,
                num_input_tagged_nodes=num_input_tagged_nodes),
            max_initial_transitions=11,
            max_in_tagged_transitions=12,
            max_edges=13)
        bundle = graph_bundle.zeros_like_padded_example(padding_config)
        static_metadata = automaton_builder.EncodedGraphMetadata(
            num_nodes=num_nodes,
            num_input_tagged_nodes=num_input_tagged_nodes)
        builder = automaton_builder.AutomatonBuilder({
            graph_types.NodeType("node"):
            graph_types.NodeSchema(
                in_edges=[graph_types.InEdgeType("in")],
                # Out-edge names are wrapped in OutEdgeType, not InEdgeType.
                out_edges=[graph_types.OutEdgeType("out")])
        })
        graph_context = end_to_end_stack.SharedGraphContext(
            bundle=bundle,
            static_metadata=static_metadata,
            edge_types_to_indices={"foo": 0},
            builder=builder,
            edges_are_embedded=embed_edges)

        # Run the computation with placeholder (all-zero) inputs.
        component_module = end_to_end_stack.ALL_COMPONENTS[component]
        (node_out, edge_out), _ = component_module.init(
            jax.random.PRNGKey(0),
            graph_context=graph_context,
            node_embeddings=jnp.zeros((num_nodes, NODE_DIM)),
            edge_embeddings=jnp.zeros((num_nodes, num_nodes, EDGE_DIM)))

        self.assertEqual(node_out.shape, (num_nodes, expected_dims["node"]))
        self.assertEqual(edge_out.shape,
                         (num_nodes, num_nodes, expected_dims["edge"]))
    def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                           use_gate, estimator_type, **kwargs):
        """Checks that the automaton layer initializes inside a flax model.

        The layer is built on an all-zeros encoded graph, and the test asserts
        the shape of the output and, for the "one_sample" estimator, the
        presence and shape of a per-edge, per-node log-prob side output.

        Args:
          shared: Forwarded as `share_states_across_edges`.
          variant_weights: Zero-argument callable; its result is passed to the
            layer as `variant_weights`.
          use_gate: Forwarded as `use_gate_parameterization`.
          estimator_type: Forwarded as `estimator_type`; the value
            "one_sample" enables the extra side-output checks.
          **kwargs: Extra keyword arguments forwarded to the layer.
        """
        # Create a simple schema and empty encoded graph.
        schema = {
            graph_types.NodeType("a"):
            graph_types.NodeSchema(in_edges=[graph_types.InEdgeType("ai_0")],
                                   out_edges=[graph_types.OutEdgeType("ao_0")
                                              ]),
        }
        builder = automaton_builder.AutomatonBuilder(schema)
        # All arrays are zeros. The 32- and 64-length arrays match the
        # num_nodes=32 / num_input_tagged_nodes=64 metadata used below.
        # NOTE(review): 128 is presumably the number of sparse entries in each
        # operator - confirm against SparseCoordOperator.
        encoded_graph = automaton_builder.EncodedGraph(
            initial_to_in_tagged=sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            ),
            initial_to_special=jnp.zeros((32, ), dtype=jnp.int32),
            in_tagged_to_in_tagged=sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            ),
            in_tagged_to_special=jnp.zeros((64, ), dtype=jnp.int32),
            in_tagged_node_indices=jnp.zeros((64, ), dtype=jnp.int32),
        )

        # Make sure the layer can be initialized and applied within a model.
        # This model is fairly simple; it just pretends that the encoded graph and
        # variants depend on the input.
        class TestModel(flax.deprecated.nn.Module):
            def apply(self, dummy_ignored):
                # Tie every leaf of the graph and of the variant weights to
                # the model input, so they are no longer independent constants.
                abstract_encoded_graph = jax.tree_map(
                    lambda y: jax.lax.tie_in(dummy_ignored, y), encoded_graph)
                abstract_variant_weights = jax.tree_map(
                    lambda y: jax.lax.tie_in(dummy_ignored, y),
                    variant_weights())
                return automaton_layer.FiniteStateGraphAutomaton(
                    encoded_graph=abstract_encoded_graph,
                    variant_weights=abstract_variant_weights,
                    dynamic_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    static_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    builder=builder,
                    num_out_edges=3,
                    num_intermediate_states=4,
                    share_states_across_edges=shared,
                    use_gate_parameterization=use_gate,
                    estimator_type=estimator_type,
                    name="the_layer",
                    **kwargs)

        # Collect side outputs during init so they can be inspected below.
        with side_outputs.collect_side_outputs() as side:
            with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
                # For some reason init_by_shape breaks the custom_vjp?
                abstract_out, unused_params = TestModel.init(
                    jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

        del unused_params
        # Expected shape is (num_out_edges, num_nodes, num_nodes).
        self.assertEqual(abstract_out.shape, (3, 32, 32))

        if estimator_type == "one_sample":
            # The key is prefixed with the layer's name ("the_layer").
            log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
            self.assertIn(log_prob_key, side)
            self.assertEqual(side[log_prob_key].shape, (3, 32))
예제 #9
0
def build_ast_graph_schema(ast_spec):
    """Builds a graph schema for an AST.

    This logic is described in Appendix B.1 of the paper.

    Each AST node becomes a new node type, whose edges are determined by its
    fields (along with whether it has a parent). Additionally, we generate
    helper node types for each grammar category that appears as a sequence
    (e.g. one for sequences of statements, another for sequences of
    expressions).

    Nodes with a parent get two edge types for that parent:
    - (in) "parent_in": the edge from the parent to this node
    - (out) "parent_out": the edge from this node to its parent

    ONE_CHILD fields become two edge types:
    - (in) "{field_name}_in": the edge from the child to this node
    - (out) "{field_name}_out": the edge from this node to the child

    OPTIONAL_CHILD fields become three edge types (to account for the
    invariant where every outgoing edge type must have at least one edge):
    - (in) "{field_name}_in": the edge from the child to this node (if it
        exists)
    - (in) "{field_name}_missing": a loopback edge from this node to itself,
        used as a sentinel value for when the child is missing
    - (out) "{field_name}_out": if there is a child of this type, this edge
        connects to that child; otherwise, it is a loopback edge to this node
        with incoming type "{field_name}_missing".

    NONEMPTY_SEQUENCE or SEQUENCE fields become 4 or 5 edge types:
    - (in) "{field_name}_in": used for EVERY edge from a child-item-helper to
        this node
    - (in) "{field_name}_missing": loopback edges for missing outgoing edges
        (only for SEQUENCE)
    - (out) "{field_name}_out_all": edges to every child-item-helper, or a
        loop back to "{field_name}_missing" if the sequence is empty
    - (out) "{field_name}_out_first": edges to the first child-item-helper, or
        a loop back to "{field_name}_missing" if the sequence is empty
    - (out) "{field_name}_out_last": edges to the last child-item-helper, or a
        loop back to "{field_name}_missing" if the sequence is empty

    NO_CHILDREN and IGNORE fields contribute no edge types to the schema.

    Child item helpers are added between any AST node with a sequence of
    children and the children in that sequence. Each helper node type is named
    "{category}-seq-helper" and has the following edge types:
    - "item_{in/out}": between the helper and the child
    - "parent_{in/out}": between the helper and the parent
    - "next_{in/out/missing}": between this helper and the next one in the
        sequence, or a loop from "out" to "missing" if this is the last element
    - "prev_{in/out/missing}": between this helper and the previous one in the
        sequence, or a loop from "out" to "missing" if this is the first
        element
    There is one helper type for each unique value in `sequence_item_types`
    that is referenced by a sequence field, across all node types.

    Args:
      ast_spec: AST spec to generate a schema for.

    Returns:
      GraphSchema with nodes as described above.

    Raises:
      ValueError: If a field has a field type not listed above.
    """
    result = {}
    seen_sequence_categories = set()

    # Build node schemas for each AST node.
    for node_type, node_spec in ast_spec.items():
        in_edges = []
        out_edges = []

        if node_spec.has_parent:
            in_edges.append(graph_types.InEdgeType("parent_in"))
            out_edges.append(graph_types.OutEdgeType("parent_out"))

        for field, field_type in node_spec.fields.items():
            if field_type in {FieldType.ONE_CHILD, FieldType.OPTIONAL_CHILD}:
                in_edges.append(graph_types.InEdgeType(f"{field}_in"))
                out_edges.append(graph_types.OutEdgeType(f"{field}_out"))
                if field_type == FieldType.OPTIONAL_CHILD:
                    # Loopback target used when the optional child is absent.
                    in_edges.append(graph_types.InEdgeType(f"{field}_missing"))
            elif field_type in {
                    FieldType.SEQUENCE, FieldType.NONEMPTY_SEQUENCE
            }:
                # Remember the category so a helper type is generated below.
                seen_sequence_categories.add(
                    node_spec.sequence_item_types[field])
                in_edges.append(graph_types.InEdgeType(f"{field}_in"))
                out_edges.extend(
                    graph_types.OutEdgeType(f"{field}_out_{which}")
                    for which in ("all", "first", "last"))
                if field_type == FieldType.SEQUENCE:
                    # Only possibly-empty sequences need the loopback target.
                    in_edges.append(graph_types.InEdgeType(f"{field}_missing"))
            elif field_type in {FieldType.NO_CHILDREN, FieldType.IGNORE}:
                # No edges for these fields.
                pass
            else:
                raise ValueError(f"Unexpected field type {field_type}")

        result[graph_types.NodeType(node_type)] = graph_types.NodeSchema(
            in_edges=in_edges, out_edges=out_edges)

    # Build node schemas for each category helper. Sorting keeps the order of
    # the helper entries independent of set iteration order.
    for category in sorted(seen_sequence_categories):
        helper_type = graph_types.NodeType(f"{category}-seq-helper")
        assert helper_type not in result, (
            f"Helper node type {helper_type} collides with an AST node type")
        result[helper_type] = graph_types.NodeSchema(
            in_edges=[
                graph_types.InEdgeType("parent_in"),
                graph_types.InEdgeType("item_in"),
                graph_types.InEdgeType("next_in"),
                graph_types.InEdgeType("next_missing"),
                graph_types.InEdgeType("prev_in"),
                graph_types.InEdgeType("prev_missing"),
            ],
            out_edges=[
                graph_types.OutEdgeType("parent_out"),
                graph_types.OutEdgeType("item_out"),
                graph_types.OutEdgeType("next_out"),
                graph_types.OutEdgeType("prev_out"),
            ])

    return result
예제 #10
0
    def build_loop_graph(self):
        """Helper method to build this complex graph, to test graph encodings.

        The graph has two nodes of type "a" (a0, a1) and two of type "b"
        (b0, b1), connected as drawn below. Labels in brackets are the edge
        types at each end of an edge.

          ┌───────<──┐  ┌───────<─────────<─────┐
          │          │  │                       │
          │    [ao_0]│  ↓[ai_0]                 │
          │          (a0)                       │
          │    [ai_1]↑  │[ao_1]                 │
          │          │  │                       │
          │    [ao_0]│  ↓[ai_0]                 │
          ↓          (a1)                       ↑
          │    [ai_1]↑  │[ao_1]                 │
          │          │  │                       │
          │    [bo_0]│  ↓[bi_0]                 │
          │          ╭──╮───────>[bo_2]────┐    │
          │          │b0│───────>[bo_2]──┐ │    │
          │          │  │<──[bi_0]─────<─┘ │    │
          │          ╰──╯<──[bi_0]─────<─┐ │    │
          │    [bi_0]↑  │[bo_1]          │ │    ↑
          │          │  │                │ │    │
          │    [bo_0]│  ↓[bi_0]          │ │    │
          │          ╭──╮───────>[bo_2]──┘ │    │
          │          │b1│───────>[bo_2]──┐ │    │
          ↓          │  │<──[bi_0]─────<─┘ │    │
          │          ╰──╯<──[bi_0]─────<───┘    │
          │    [bi_0]↑  │[bo_1]                 │
          │          │  │                       │
          └───────>──┘  └───────>─────────>─────┘

        Returns:
          Tuple (schema, graph) for the above structure.
        """
        a = graph_types.NodeType("a")
        b = graph_types.NodeType("b")

        ai_0 = graph_types.InEdgeType("ai_0")
        ai_1 = graph_types.InEdgeType("ai_1")
        # Type "b" has a single in-edge type, shared by all its incoming edges.
        bi_0 = graph_types.InEdgeType("bi_0")

        ao_0 = graph_types.OutEdgeType("ao_0")
        ao_1 = graph_types.OutEdgeType("ao_1")
        bo_0 = graph_types.OutEdgeType("bo_0")
        bo_1 = graph_types.OutEdgeType("bo_1")
        bo_2 = graph_types.OutEdgeType("bo_2")

        a0 = graph_types.NodeId("a0")
        a1 = graph_types.NodeId("a1")
        b0 = graph_types.NodeId("b0")
        b1 = graph_types.NodeId("b1")

        schema = {
            a:
            graph_types.NodeSchema(in_edges=[ai_0, ai_1],
                                   out_edges=[ao_0, ao_1]),
            b:
            graph_types.NodeSchema(in_edges=[bi_0],
                                   out_edges=[bo_0, bo_1, bo_2]),
        }
        test_graph = {
            a0:
            graph_types.GraphNode(
                a, {
                    ao_0: [graph_types.InputTaggedNode(b1, bi_0)],
                    ao_1: [graph_types.InputTaggedNode(a1, ai_0)]
                }),
            a1:
            graph_types.GraphNode(
                a, {
                    ao_0: [graph_types.InputTaggedNode(a0, ai_1)],
                    ao_1: [graph_types.InputTaggedNode(b0, bi_0)]
                }),
            b0:
            graph_types.GraphNode(
                b, {
                    bo_0: [graph_types.InputTaggedNode(a1, ai_1)],
                    bo_1: [graph_types.InputTaggedNode(b1, bi_0)],
                    # "bo_2" fans out to both b nodes, including a self-loop.
                    bo_2: [
                        graph_types.InputTaggedNode(b0, bi_0),
                        graph_types.InputTaggedNode(b1, bi_0)
                    ]
                }),
            b1:
            graph_types.GraphNode(
                b, {
                    bo_0: [graph_types.InputTaggedNode(b0, bi_0)],
                    bo_1: [graph_types.InputTaggedNode(a0, ai_0)],
                    bo_2: [
                        graph_types.InputTaggedNode(b0, bi_0),
                        graph_types.InputTaggedNode(b1, bi_0)
                    ]
                }),
        }
        return schema, test_graph