Code example #1
    def test_binary_index_edge_embeddings(self):
        """Checks one-hot edge embeddings against a hand-built dense tensor."""
        edges, fwd_indices, rev_indices = self._setup_edges()

        actual = graph_layers.binary_index_edge_embeddings(
            edges,
            num_nodes=4,
            num_edge_types=3,
            forward_edge_type_indices=fwd_indices,
            reverse_edge_type_indices=rev_indices)

        # Each entry is (source node, destination node, embedding channel)
        # that should be set to 1 in the dense [num_nodes, num_nodes, 3]
        # output; everything else stays 0.
        hot_entries = [
            (2, 0, 0),
            (0, 1, 1),
            (1, 2, 1),
            (0, 2, 1),
            (1, 0, 2),
            (2, 1, 2),
            (2, 0, 2),
        ]
        expected = np.zeros([4, 4, 3], np.float32)
        for src, dst, channel in hot_entries:
            expected[src, dst, channel] = 1
        np.testing.assert_allclose(actual, expected)
Code example #2
def token_graph_model_core(
    input_graph,
    static_metadata,
    encoding_info,
    node_embedding_dim=gin.REQUIRED,
    edge_embedding_dim=gin.REQUIRED,
    forward_edge_types=gin.REQUIRED,
    reverse_edge_types=gin.REQUIRED,
    components=gin.REQUIRED,
):
    """Transforms an input graph into final node embeddings, with no output head.

  Args:
    input_graph: Input graph for this example.
    static_metadata: Metadata about the padded size of this graph.
    encoding_info: How the example was encoded.
    node_embedding_dim: Dimension of node embedding space.
    edge_embedding_dim: Dimension of edge embedding space. If None, just
      concatenate each edge type's adjacency matrix.
    forward_edge_types: Edge types to use in the forward direction. As a list of
      lists to allow configuring groups of edges in config files; this will be
      flattened before use.
    reverse_edge_types: Edge types to use in the reverse direction. Note that
      reversed edge types are given a separate embedding from forward edge
      types; undirected edges should be represented by adding two edges in
      opposite directions and then only using `forward_edge_types`. Also a list
      of lists, as above.
    components: List of sublayer types. Each element should be the name of one
      of the components defined above.

  Returns:
    Final node embeddings after running all components.
  """
    num_node_types = len(encoding_info.builder.node_types)
    num_edge_types = len(encoding_info.edge_types)
    edge_types_to_indices = {
        name: index for index, name in enumerate(encoding_info.edge_types)
    }

    def _flatten_to_indices(grouped_edge_types):
        # Config files specify edge types as groups (list of lists); flatten
        # the groups and map each type name to its integer index.
        indices = []
        for group in grouped_edge_types:
            for type_str in group:
                indices.append(edge_types_to_indices[type_str])
        return indices

    forward_edge_type_indices = _flatten_to_indices(forward_edge_types)
    reverse_edge_type_indices = _flatten_to_indices(reverse_edge_types)

    # Initial node embeddings: positional/type component plus a token-operator
    # component, summed elementwise.
    positional_part = graph_layers.PositionalAndTypeNodeEmbedding(
        node_types=input_graph.bundle.node_types,
        num_node_types=num_node_types,
        embedding_dim=node_embedding_dim,
    )
    token_part = graph_layers.TokenOperatorNodeEmbedding(
        operator=input_graph.tokens,
        vocab_size=encoding_info.token_encoder.vocab_size,
        num_nodes=static_metadata.num_nodes,
        embedding_dim=node_embedding_dim,
    )
    node_embeddings = positional_part + token_part

    # Initial edge embeddings: learned when an embedding dimension is given,
    # otherwise one-hot (binary) per edge type.
    embed_edges = edge_embedding_dim is not None
    if embed_edges:
        edge_embeddings = graph_layers.LearnableEdgeEmbeddings(
            edges=input_graph.bundle.edges,
            num_nodes=static_metadata.num_nodes,
            num_edge_types=num_edge_types,
            forward_edge_type_indices=forward_edge_type_indices,
            reverse_edge_type_indices=reverse_edge_type_indices,
            embedding_dim=edge_embedding_dim)
    else:
        edge_embeddings = graph_layers.binary_index_edge_embeddings(
            edges=input_graph.bundle.edges,
            num_nodes=static_metadata.num_nodes,
            num_edge_types=num_edge_types,
            forward_edge_type_indices=forward_edge_type_indices,
            reverse_edge_type_indices=reverse_edge_type_indices)

    # Run the core component stack. Depending on whether edge_embedding_dim is
    # provided, components either concatenate new edge types or embed them
    # (see end_to_end_stack for details).
    graph_context = end_to_end_stack.SharedGraphContext(
        bundle=input_graph.bundle,
        static_metadata=static_metadata,
        edge_types_to_indices=edge_types_to_indices,
        builder=encoding_info.builder,
        edges_are_embedded=embed_edges)
    for component_name in components:
        apply_component = end_to_end_stack.ALL_COMPONENTS[component_name]
        node_embeddings, edge_embeddings = apply_component(
            graph_context, node_embeddings, edge_embeddings)

    return node_embeddings
Code example #3
    def apply(
        self,
        example,
        graph_metadata,
        edge_types_to_indices,
        forward_edge_types,
        reverse_edge_types,
        use_position_embeddings=gin.REQUIRED,
        learn_edge_embeddings=gin.REQUIRED,
        model_type=gin.REQUIRED,
        nodewise=False,
        nodewise_loop_chunk_size=None,
    ):
        """Single-forward-pass baseline model for edge-supervision task.

    This model propagates information through the graph, then does a pairwise
    bilinear readout to determine which edges to add.

    Args:
      example: Example to run the automaton on.
      graph_metadata: Statically-known metadata about the graph size. If
        encoded_graph is padded, this should reflect the padded size, not the
        original size.
      edge_types_to_indices: Mapping from edge type names to edge type indices.
      forward_edge_types: Edge types to use in the forward direction. As a list
        of lists to allow configuring groups of edges in config files; this will
        be flattened before use.
      reverse_edge_types: Edge types to use in the reverse direction. Note that
        reversed edge types are given a separate embedding from forward edge
        types; undirected edges should be represented by adding two edges in
        opposite directions and then only using `forward_edge_types`. Also a
        list of lists, as above.
      use_position_embeddings: Whether to add position embeddings to node
        embeddings.
      learn_edge_embeddings: Whether to learn an edge embedding for each edge
        type (instead of using a one-hot embedding).
      model_type: One of {"ggnn", "transformer", "nri_encoder"}
      nodewise: Whether to have separate sets of node embeddings for each
        possible start node.
      nodewise_loop_chunk_size: Optional integer, which must divide the number
        of nodes. Splits the nodes into chunks of this size, and runs the model
        on each of those splits in a loop; this reduces the memory usage. Only
        used when nodewise=True.

    Returns:
      <float32[num_nodes, num_nodes]> matrix of binary logits for a weighted
      adjacency matrix corresponding to the predicted output edges.
    """
        # Node types come directly from the schema (for parity with the automaton).
        num_node_types = len(py_ast_graphs.BUILDER.node_types)
        # Edge types are potentially task-specific.
        num_edge_types = len(edge_types_to_indices)
        # Flatten the config-level groups (lists of lists) into flat lists of
        # integer edge-type indices.
        # pylint: disable=g-complex-comprehension
        forward_edge_type_indices = [
            edge_types_to_indices[type_str] for group in forward_edge_types
            for type_str in group
        ]
        reverse_edge_type_indices = [
            edge_types_to_indices[type_str] for group in reverse_edge_types
            for type_str in group
        ]
        # pylint: enable=g-complex-comprehension

        # Embed the nodes: with position information if requested, otherwise
        # from node types alone.
        if use_position_embeddings:
            node_embeddings = graph_layers.PositionalAndTypeNodeEmbedding(
                node_types=example.node_types, num_node_types=num_node_types)
        else:
            node_embeddings = graph_layers.NodeTypeNodeEmbedding(
                node_types=example.node_types, num_node_types=num_node_types)

        # Embed the edges: learned per-type embeddings, or one-hot indicators.
        if learn_edge_embeddings:
            edge_embeddings = graph_layers.LearnableEdgeEmbeddings(
                edges=example.edges,
                num_nodes=graph_metadata.num_nodes,
                num_edge_types=num_edge_types,
                forward_edge_type_indices=forward_edge_type_indices,
                reverse_edge_type_indices=reverse_edge_type_indices)
        else:
            edge_embeddings = graph_layers.binary_index_edge_embeddings(
                edges=example.edges,
                num_nodes=graph_metadata.num_nodes,
                num_edge_types=num_edge_types,
                forward_edge_type_indices=forward_edge_type_indices,
                reverse_edge_type_indices=reverse_edge_type_indices)

        def run_steps(node_embeddings):
            """Runs propagation steps for the configured model_type and returns final node embeddings."""
            # NOTE(review): there is no `else` branch, so an unrecognized
            # model_type would raise UnboundLocalError on `final_embeddings`
            # rather than a descriptive error — consider adding an explicit
            # `raise ValueError(...)`.
            if model_type == "ggnn":
                final_embeddings = ggnn_steps(node_embeddings, edge_embeddings)
            elif model_type == "transformer":
                # Restrict attention to graph neighbors (in either direction).
                neighbor_mask = graph_layers.edge_mask(
                    edges=example.edges,
                    num_nodes=graph_metadata.num_nodes,
                    num_edge_types=num_edge_types,
                    forward_edge_type_indices=forward_edge_type_indices,
                    reverse_edge_type_indices=reverse_edge_type_indices)
                # Allow nodes to attend to themselves
                neighbor_mask = jnp.maximum(neighbor_mask,
                                            jnp.eye(graph_metadata.num_nodes))
                final_embeddings = transformer_steps(
                    node_embeddings,
                    edge_embeddings,
                    neighbor_mask,
                    num_real_nodes_per_graph=example.graph_metadata.num_nodes)
            elif model_type == "nri_encoder":
                final_embeddings = nri_steps(
                    node_embeddings,
                    edge_embeddings,
                    num_real_nodes_per_graph=example.graph_metadata.num_nodes)
            return final_embeddings

        if nodewise:
            assert model_type != "nri_encoder", "Nodewise NRI model is not defined."
            # Add in a learned start node embedding, and broadcast node states out to
            # <float32[num_nodes, num_nodes, node_embedding_dim]>
            node_embedding_dim = node_embeddings.shape[-1]
            start_node_embedding = self.param(
                "start_node_embedding",
                shape=(node_embedding_dim, ),
                initializer=jax.nn.initializers.normal())

            # Copy the node embeddings once per possible start node, then add
            # the learned start marker on the diagonal: in copy i, only node i
            # receives `start_node_embedding`.
            stacked_node_embeddings = (jnp.broadcast_to(
                node_embeddings[None, :, :],
                (graph_metadata.num_nodes, graph_metadata.num_nodes,
                 node_embedding_dim)).at[
                     jnp.arange(graph_metadata.num_nodes),
                     jnp.arange(graph_metadata.num_nodes)].add(
                         jnp.broadcast_to(
                             start_node_embedding,
                             (graph_metadata.num_nodes, node_embedding_dim))))

            # final_embeddings_from_each_source will be
            # [num_nodes, num_nodes, node_embedding_dim]
            if nodewise_loop_chunk_size and not self.is_initializing():
                # Trade time for memory: split the start-node axis into chunks
                # (requires nodewise_loop_chunk_size to divide num_nodes),
                # vmap within each chunk, and loop sequentially over chunks
                # with lax.map. Skipped during initialization so parameters
                # are created in a single pass.
                grouped_stacked_node_embeddings = stacked_node_embeddings.reshape(
                    (-1, nodewise_loop_chunk_size, graph_metadata.num_nodes,
                     node_embedding_dim))
                grouped_final_embeddings_from_each_source = jax.lax.map(
                    jax.vmap(run_steps), grouped_stacked_node_embeddings)
                # Collapse the (chunk, within-chunk) axes back into one
                # start-node axis.
                final_embeddings_from_each_source = (
                    grouped_final_embeddings_from_each_source.reshape(
                        (graph_metadata.num_nodes, ) +
                        grouped_final_embeddings_from_each_source.shape[2:]))
            else:
                # Run all start nodes in one vectorized pass.
                final_embeddings_from_each_source = jax.vmap(run_steps)(
                    stacked_node_embeddings)

            # Extract predictions with a linear transformation.
            # (Uses the deprecated pre-Linen flax API, matching the rest of
            # this model.)
            logits = flax.deprecated.nn.Dense(
                final_embeddings_from_each_source,
                features=1,
                name="target_readout",
            ).squeeze(-1)

        elif model_type == "nri_encoder":
            # Propagate the node embeddings as-is, then use NRI to construct edges.
            final_embeddings = run_steps(node_embeddings)
            logits = graph_layers.NRIReadout(final_embeddings)

        else:
            # Propagate the node embeddings as-is, then extract edges pairwise.
            final_embeddings = run_steps(node_embeddings)
            logits = graph_layers.BilinearPairwiseReadout(final_embeddings)

        return logits