Example #1
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        # We will first linearly project the original node features as 'embeddings'.
        embedder = jraph.GraphMapFeatures(
            embed_node_fn=nn.Dense(self.latent_size))
        processed_graphs = embedder(graphs)

        # Now, we will apply the GCN once for each message-passing round.
        for _ in range(self.message_passing_steps):
            mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers
            update_node_fn = jraph.concatenated_args(
                MLP(mlp_feature_sizes,
                    dropout_rate=self.dropout_rate,
                    deterministic=self.deterministic))
            graph_conv = jraph.GraphConvolution(update_node_fn=update_node_fn,
                                                add_self_edges=True)

            if self.skip_connections:
                processed_graphs = add_graphs_tuples(
                    graph_conv(processed_graphs), processed_graphs)
            else:
                processed_graphs = graph_conv(processed_graphs)

            if self.layer_norm:
                processed_graphs = processed_graphs._replace(
                    nodes=nn.LayerNorm()(processed_graphs.nodes))

        # We apply the pooling operation to get a 'global' embedding.
        processed_graphs = self.pool(processed_graphs)

        # Now, we decode this to get the required output logits.
        decoder = jraph.GraphMapFeatures(
            embed_global_fn=nn.Dense(self.output_globals_size))
        processed_graphs = decoder(processed_graphs)

        return processed_graphs
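A hedged usage sketch (not part of the original snippet): assuming the method above belongs to a Flax nn.Module, here hypothetically named GCN, with illustrative hyperparameter values and a toy graph. The class name, field values, and graph contents are assumptions for illustration only.

import jax
import jax.numpy as jnp
import jraph

# Hypothetical wrapper; field names mirror the attributes the method uses:
# model = GCN(latent_size=64, num_mlp_layers=2, message_passing_steps=3,
#             output_globals_size=10, dropout_rate=0.1, skip_connections=True,
#             layer_norm=True, deterministic=True)

# A toy graph: 3 nodes with 4 features, 2 directed edges, 1 graph in the batch.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)), edges=None,
    senders=jnp.array([0, 1]), receivers=jnp.array([1, 2]),
    globals=jnp.zeros((1, 10)),
    n_node=jnp.array([3]), n_edge=jnp.array([2]))

# params = model.init(jax.random.PRNGKey(0), graph)
# out = model.apply(params, graph)  # out.globals: shape (1, output_globals_size)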
Example #2
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """Defines a graph neural network.

  Args:
    graph: GraphsTuple the network processes.

  Returns:
    processed nodes.
  """
    gn = jraph.GraphConvolution(
        update_node_fn=lambda n: jax.nn.relu(hk.Linear(5)(n)),
        add_self_edges=True)
    graph = gn(graph)

    gn = jraph.GraphConvolution(update_node_fn=hk.Linear(2))
    graph = gn(graph)
    return graph.nodes
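A usage sketch: Haiku modules must run inside hk.transform, so applying network_definition to a toy graph might look like the following (the graph values are illustrative assumptions):

import haiku as hk
import jax
import jax.numpy as jnp
import jraph

net = hk.without_apply_rng(hk.transform(network_definition))

# Toy graph: 4 nodes with 3 features each, connected in a directed 4-cycle.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((4, 3)), edges=None,
    senders=jnp.array([0, 1, 2, 3]), receivers=jnp.array([1, 2, 3, 0]),
    globals=None, n_node=jnp.array([4]), n_edge=jnp.array([4]))

params = net.init(jax.random.PRNGKey(42), graph)
nodes = net.apply(params, graph)  # shape (4, 2): hk.Linear(2) output per node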
Example #3
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """Implements the GCN from Kipf et al https://arxiv.org/pdf/1609.02907.pdf.

     A' = D^{-0.5} A D^{-0.5}
     Z = f(X, A') = A' relu(A' X W_0) W_1

  Args:
    graph: GraphsTuple the network processes.

  Returns:
    processed nodes.
  """
    gn = jraph.GraphConvolution(update_node_fn=hk.Linear(5, with_bias=False),
                                add_self_edges=True)
    graph = gn(graph)
    graph = graph._replace(nodes=jax.nn.relu(graph.nodes))
    gn = jraph.GraphConvolution(update_node_fn=hk.Linear(2, with_bias=False))
    graph = gn(graph)
    return graph.nodes
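A small sanity-check sketch (illustrative, not from the original): with add_self_edges=True and jraph's default symmetric_normalization=True, GraphConvolution applies the normalization from the docstring. The dense computation below spells it out for a 3-node path graph, with stand-in weight matrices W0 and W1 in place of the learned hk.Linear parameters:

import jax
import jax.numpy as jnp

A = jnp.array([[0., 1., 0.],
               [1., 0., 1.],
               [0., 1., 0.]])                 # path graph 0 - 1 - 2
A_tilde = A + jnp.eye(3)                      # add self-loops: A~ = A + I
d_inv_sqrt = jnp.diag(A_tilde.sum(axis=1) ** -0.5)
A_hat = d_inv_sqrt @ A_tilde @ d_inv_sqrt     # A' = D~^{-1/2} A~ D~^{-1/2}

X = jnp.ones((3, 4))                          # node features
W0 = jnp.ones((4, 5))                         # stand-in for the first Linear
W1 = jnp.ones((5, 2))                         # stand-in for the second Linear
Z = A_hat @ jax.nn.relu(A_hat @ X @ W0) @ W1  # Z = A' relu(A' X W_0) W_1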
Example #4
    def __call__(self, graph, train=True):
        for i, latent_size in enumerate(self.features):
            gc = jraph.GraphConvolution(nn.Dense(latent_size),
                                        add_self_edges=False)
            graph = gc(graph)
            act_fn = layers.Activation(self.activation)
            graph = graph._replace(nodes=act_fn(graph.nodes))
            # The last layer returns node features directly, without dropout.
            if i == len(self.features) - 1:
                return graph.nodes
            dout = nn.Dropout(rate=self.drop_rate)
            graph = graph._replace(
                nodes=dout(graph.nodes, deterministic=not train))
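A hedged usage sketch, assuming the method above sits in a Flax nn.Module (here hypothetically named GraphConvNet, with illustrative field values); note that with train=True, apply needs a 'dropout' RNG:

import jax
import jax.numpy as jnp
import jraph

# Hypothetical instantiation; field names mirror the attributes the method uses:
# model = GraphConvNet(features=[16, 8, 2], activation='relu', drop_rate=0.1)

# Toy graph: 5 nodes with 3 features each, 4 directed edges along a path.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((5, 3)), edges=None,
    senders=jnp.array([0, 1, 2, 3]), receivers=jnp.array([1, 2, 3, 4]),
    globals=None, n_node=jnp.array([5]), n_edge=jnp.array([4]))

# params = model.init(jax.random.PRNGKey(0), graph, train=False)
# nodes = model.apply(params, graph, train=True,
#                     rngs={'dropout': jax.random.PRNGKey(1)})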