def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Run the graph-convolution stack and decode the pooled globals.

    Node features are first linearly embedded to ``self.latent_size``,
    then refined by ``self.message_passing_steps`` rounds of graph
    convolution, pooled with ``self.pool`` and finally passed through a
    dense layer applied to the globals.

    Args:
      graphs: Input graphs to process.

    Returns:
      The processed graphs after pooling and decoding of the globals.
    """
    # Linear projection of the original node features into 'embeddings'.
    node_embedding = nn.Dense(self.latent_size)
    hidden = jraph.GraphMapFeatures(embed_node_fn=node_embedding)(graphs)

    # One graph convolution per message-passing round.
    for _ in range(self.message_passing_steps):
        layer_widths = [self.latent_size] * self.num_mlp_layers
        node_mlp = MLP(
            layer_widths,
            dropout_rate=self.dropout_rate,
            deterministic=self.deterministic,
        )
        conv = jraph.GraphConvolution(
            update_node_fn=jraph.concatenated_args(node_mlp),
            add_self_edges=True,
        )

        convolved = conv(hidden)
        if self.skip_connections:
            # Combine the convolution output with its own input.
            hidden = add_graphs_tuples(convolved, hidden)
        else:
            hidden = convolved

        if self.layer_norm:
            # Only the node features are normalised; the rest of the
            # graph is carried over untouched.
            normalized_nodes = nn.LayerNorm()(hidden.nodes)
            hidden = hidden._replace(nodes=normalized_nodes)

    # Pooling produces the 'global' embedding.
    pooled = self.pool(hidden)

    # Decode the globals to get the required output logits.
    readout = jraph.GraphMapFeatures(
        embed_global_fn=nn.Dense(self.output_globals_size)
    )
    return readout(pooled)
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """Apply a two-layer graph convolutional network to ``graph``.

    The first convolution adds self edges and updates nodes with a
    5-unit linear layer followed by a ReLU. The second convolution uses
    a 2-unit linear layer and leaves every other ``GraphConvolution``
    option at its default.

    Args:
      graph: GraphsTuple the network processes.

    Returns:
      The node features produced by the second convolution.
    """

    def hidden_update(nodes):
        # The linear layer is built when the convolution invokes this
        # function, i.e. before the output layer below is constructed.
        return jax.nn.relu(hk.Linear(5)(nodes))

    hidden_conv = jraph.GraphConvolution(
        update_node_fn=hidden_update,
        add_self_edges=True,
    )
    hidden_graph = hidden_conv(graph)

    output_conv = jraph.GraphConvolution(update_node_fn=hk.Linear(2))
    return output_conv(hidden_graph).nodes
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """Implements the GCN from Kipf et al https://arxiv.org/pdf/1609.02907.pdf.

      A' = D^{-0.5} A D^{-0.5}
      Z = f(X, A') = A' relu(A' X W_0) W_1

    Both weight matrices are bias-free linear layers (5 and 2 output
    units respectively) and the ReLU is applied to the node features
    between the two convolutions.

    NOTE(review): only the first convolution passes
    ``add_self_edges=True``; the second relies on the library default.
    Confirm this asymmetry is intended.

    Args:
      graph: GraphsTuple the network processes.

    Returns:
      The node features produced by the second convolution.
    """
    # First layer: A' X W_0, with self edges added.
    first_layer = jraph.GraphConvolution(
        update_node_fn=hk.Linear(5, with_bias=False),
        add_self_edges=True,
    )
    hidden_graph = first_layer(graph)

    # Non-linearity on the hidden node features only.
    activated_nodes = jax.nn.relu(hidden_graph.nodes)
    hidden_graph = hidden_graph._replace(nodes=activated_nodes)

    # Second layer: A' H W_1.
    second_layer = jraph.GraphConvolution(
        update_node_fn=hk.Linear(2, with_bias=False)
    )
    return second_layer(hidden_graph).nodes
def __call__(self, graph, train=True):
    """Apply one graph convolution per entry of ``self.features``.

    Every layer is a graph convolution (without added self edges) whose
    node update is a dense layer of the given width, followed by the
    configured activation. Dropout is applied after every layer except
    the last one.

    Args:
      graph: Graph to process; its ``nodes`` are replaced layer by layer.
      train: When true, dropout is active (``deterministic=False``).

    Returns:
      The node features after the final layer, or ``None`` when
      ``self.features`` is empty (the loop body never runs).
    """
    for layer_index, width in enumerate(self.features):
        conv = jraph.GraphConvolution(nn.Dense(width), add_self_edges=False)
        graph = conv(graph)

        activation = layers.Activation(self.activation)
        activated_nodes = activation(graph.nodes)
        graph = graph._replace(nodes=activated_nodes)

        # The last layer's activations are the output; no dropout there.
        is_final_layer = layer_index + 1 == len(self.features)
        if is_final_layer:
            return graph.nodes

        dropout = nn.Dropout(rate=self.drop_rate)
        dropped_nodes = dropout(graph.nodes, deterministic=not train)
        graph = graph._replace(nodes=dropped_nodes)

    # Reached only when there are no layers to apply.
    return None