Ejemplo n.º 1
0
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Runs the GCN: embed nodes, convolve, pool, then decode globals.

        Args:
          graphs: Input batch of graphs.

        Returns:
          The pooled graphs, whose ``globals`` have been passed through a
          final ``nn.Dense(self.output_globals_size)`` layer.
        """
        # Linear projection of the raw node features into the latent space.
        # Submodules are created in the same order every call so that their
        # auto-generated names stay stable.
        current = jraph.GraphMapFeatures(
            embed_node_fn=nn.Dense(self.latent_size))(graphs)

        # Every round uses the same MLP layer widths.
        hidden_sizes = [self.latent_size] * self.num_mlp_layers

        for _ in range(self.message_passing_steps):
            node_mlp = MLP(hidden_sizes,
                           dropout_rate=self.dropout_rate,
                           deterministic=self.deterministic)
            conv = jraph.GraphConvolution(
                update_node_fn=jraph.concatenated_args(node_mlp),
                add_self_edges=True)

            convolved = conv(current)
            # Optional residual connection around the convolution.
            current = (add_graphs_tuples(convolved, current)
                       if self.skip_connections else convolved)

            if self.layer_norm:
                normed_nodes = nn.LayerNorm()(current.nodes)
                current = current._replace(nodes=normed_nodes)

        # Collapse node embeddings into one embedding per graph.
        current = self.pool(current)

        # Project the pooled embedding to the output logits.
        return jraph.GraphMapFeatures(
            embed_global_fn=nn.Dense(self.output_globals_size))(current)
Ejemplo n.º 2
0
    def __call__(self, graph, train):
        """Predicts per-graph outputs with a stack of GraphNetwork layers.

        Args:
          graph: Input ``jraph.GraphsTuple``; its ``globals`` are replaced by
            zeros before processing.
          train: When True, dropout is active; otherwise it is deterministic.

        Returns:
          The decoded ``globals`` array, with ``self.num_outputs`` features.
        """
        # One Dropout module, shared by every MLP built below.
        drop = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

        # Start each graph with an all-zero global feature vector.
        n_graphs = graph.n_node.shape[0]
        graph = graph._replace(
            globals=jnp.zeros([n_graphs, self.num_outputs]))

        # Embed node and edge features (node embedder is created first).
        node_embed = _make_embed(self.latent_dim)
        edge_embed = _make_embed(self.latent_dim)
        graph = jraph.GraphMapFeatures(embed_node_fn=node_embed,
                                       embed_edge_fn=edge_embed)(graph)

        for _ in range(self.num_message_passing_steps):
            # Update MLPs are created in edge -> node -> global order.
            edge_mlp = _make_mlp(self.hidden_dims, dropout=drop)
            node_mlp = _make_mlp(self.hidden_dims, dropout=drop)
            global_mlp = _make_mlp(self.hidden_dims, dropout=drop)
            layer = jraph.GraphNetwork(update_edge_fn=edge_mlp,
                                       update_node_fn=node_mlp,
                                       update_global_fn=global_mlp)
            graph = layer(graph)

        # Map the final globals to the requested number of outputs.
        readout = jraph.GraphMapFeatures(
            embed_global_fn=nn.Dense(self.num_outputs))
        return readout(graph).globals
Ejemplo n.º 3
0
    def __call__(self, graph):
        """Encodes node features, applies per-hop convolutions, then decodes.

        Args:
          graph: Input graph.

        Returns:
          The graph whose node features are the decoder output; the decoder's
          last layer has ``self.num_classes`` units and no final activation.
        """
        # Encoder: num_encoder_layers layers of width latent_size, with the
        # activation also applied after the last layer.
        encoder_sizes = [self.latent_size] * self.num_encoder_layers
        encoder = MultiLayerPerceptron(encoder_sizes,
                                       self.activation,
                                       skip_connections=False,
                                       activate_final=True,
                                       name='encoder')
        graph = jraph.GraphMapFeatures(embed_node_fn=encoder)(graph)

        # Core: one single-layer MLP with skip connections per hop, named
        # core_<hop>, wrapped in a OneHopGraphConvolution.
        for hop in range(self.num_message_passing_steps):
            hop_mlp = MultiLayerPerceptron([self.latent_size],
                                           self.activation,
                                           skip_connections=True,
                                           activate_final=True,
                                           name=f'core_{hop}')
            graph = OneHopGraphConvolution(update_fn=hop_mlp)(graph)

        # Decoder: (num_decoder_layers - 1) hidden layers of width
        # latent_size, then an output layer of width num_classes.
        decoder_sizes = ([self.latent_size] * (self.num_decoder_layers - 1)
                         + [self.num_classes])
        decoder = MultiLayerPerceptron(decoder_sizes,
                                       self.activation,
                                       skip_connections=False,
                                       activate_final=False,
                                       name='decoder')
        return jraph.GraphMapFeatures(embed_node_fn=decoder)(graph)
Ejemplo n.º 4
0
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Runs embedding, GraphNetwork rounds, and a global decoder.

        Args:
          graphs: Input batch of graphs with node, edge and global features.

        Returns:
          The processed graphs; ``globals`` have been passed through a final
          ``nn.Dense(self.output_globals_size)`` layer.
        """
        # Linear projections of nodes, edges and globals into latent space,
        # created in node, edge, global order.
        node_proj = nn.Dense(self.latent_size)
        edge_proj = nn.Dense(self.latent_size)
        global_proj = nn.Dense(self.latent_size)
        current = jraph.GraphMapFeatures(embed_node_fn=node_proj,
                                         embed_edge_fn=edge_proj,
                                         embed_global_fn=global_proj)(graphs)

        widths = [self.latent_size] * self.num_mlp_layers

        def new_update_fn():
            # Each call builds a fresh MLP applied to the concatenated inputs.
            return jraph.concatenated_args(
                MLP(widths,
                    dropout_rate=self.dropout_rate,
                    deterministic=self.deterministic))

        for _ in range(self.message_passing_steps):
            # Construction order: edge (optional) -> node -> global.
            edge_fn = new_update_fn() if self.use_edge_model else None
            node_fn = new_update_fn()
            global_fn = new_update_fn()
            network = jraph.GraphNetwork(update_node_fn=node_fn,
                                         update_edge_fn=edge_fn,
                                         update_global_fn=global_fn)

            updated = network(current)
            # Optional residual connection around the GraphNetwork.
            current = (add_graphs_tuples(updated, current)
                       if self.skip_connections else updated)

            if self.layer_norm:
                current = current._replace(
                    nodes=nn.LayerNorm()(current.nodes),
                    edges=nn.LayerNorm()(current.edges),
                    globals=nn.LayerNorm()(current.globals),
                )

        # Graph-level predictions live in globals; decode them to logits.
        return jraph.GraphMapFeatures(
            embed_global_fn=nn.Dense(self.output_globals_size))(current)
Ejemplo n.º 5
0
def network_definition(
    graph: jraph.GraphsTuple,
    num_message_passing_steps: int = 5) -> jraph.ArrayTree:
  """Builds and applies a small InteractionNetwork-based GNN.

  Edge and node features are first projected to 16 dimensions. Then
  ``num_message_passing_steps`` InteractionNetwork layers are applied. Finally
  the node features are mapped to 2 outputs.

  Args:
    graph: The GraphsTuple to process.
    num_message_passing_steps: How many InteractionNetwork layers to apply.

  Returns:
    Node features after a final ``hk.Linear(2)`` layer.
  """
  # Keep the construction order: edge projection first, then node projection.
  edge_proj = jax.vmap(hk.Linear(output_size=16))
  node_proj = jax.vmap(hk.Linear(output_size=16))
  graph = jraph.GraphMapFeatures(embed_edge_fn=edge_proj,
                                 embed_node_fn=node_proj)(graph)

  def mlp(features):
    # Three Linear(10) + ReLU stages, constructed anew on every call.
    stages = []
    for _ in range(3):
      stages += [hk.Linear(10), jax.nn.relu]
    return hk.Sequential(stages)(features)

  # Concatenate all update arguments, then vmap over the leading axis.
  update_fn = jax.vmap(jraph.concatenated_args(mlp))

  for _ in range(num_message_passing_steps):
    layer = jraph.InteractionNetwork(
        update_edge_fn=update_fn,
        update_node_fn=update_fn,
        include_sent_messages_in_node_update=True)
    graph = layer(graph)

  return hk.Linear(2)(graph.nodes)
Ejemplo n.º 6
0
    def _encoder(
        self,
        graph: jraph.GraphsTuple,
        is_training: bool,
    ) -> jraph.GraphsTuple:
        """Encodes edge, node and (optionally) global features.

        Args:
          graph: Graph to encode; first passed through
            ``self._prepare_features``.
          is_training: Unused; accepted for interface compatibility.

        Returns:
          The graph with edges and nodes mapped through ``_build_mlp``
          encoders. Globals are encoded too, unless
          ``self._config.ignore_globals`` is set, in which case they are
          passed through unchanged.
        """
        del is_training  # unused
        graph = self._prepare_features(graph)

        cfg = self._config
        # mlp_layers hidden layers of width mlp_hidden_size, followed by a
        # final layer of width latent_size.
        sizes = [cfg.mlp_hidden_size] * cfg.mlp_layers + [cfg.latent_size]
        make_mlp = functools.partial(
            _build_mlp,
            output_sizes=sizes,
            use_layer_norm=cfg.use_layer_norm,
        )

        edge_encoder = make_mlp("edge_encoder")
        node_encoder = make_mlp("node_encoder")
        if cfg.ignore_globals:
            global_encoder = None
        else:
            global_encoder = make_mlp("global_encoder")

        encode = jraph.GraphMapFeatures(
            embed_edge_fn=edge_encoder,
            embed_node_fn=node_encoder,
            embed_global_fn=global_encoder,
        )
        return encode(graph)
Ejemplo n.º 7
0
def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Embeds a graph and applies a single GraphNetwork layer.

    Args:
      graph: Input graph; its ``globals`` are overwritten with zeros.

    Returns:
      The GraphNetwork output.
    """
    # Add a zero-valued global parameter per graph for graph classification.
    num_graphs = graph.n_node.shape[0]
    graph = graph._replace(globals=jnp.zeros([num_graphs, 1]))

    # Linear(128) embeddings, created in edge, node, global order.
    embedder = jraph.GraphMapFeatures(embed_edge_fn=hk.Linear(128),
                                      embed_node_fn=hk.Linear(128),
                                      embed_global_fn=hk.Linear(128))
    net = jraph.GraphNetwork(update_edge_fn=edge_update_fn,
                             update_node_fn=node_update_fn,
                             update_global_fn=update_global_fn)
    embedded = embedder(graph)
    return net(embedded)
Ejemplo n.º 8
0
  def __call__(self, graph):
    """Embeds the graph and applies one GraphNetwork for binary classification.

    Args:
      graph: Input ``jraph.GraphsTuple``; its ``globals`` are overwritten.

    Returns:
      The GraphNetwork output; the global update MLP ends with 2 units.
    """
    # One zero-valued global feature per graph, for graph classification.
    n_graphs = graph.n_node.shape[0]
    graph = graph._replace(globals=jnp.zeros([n_graphs, 1]))

    # Embedders are created in node, edge, global order.
    node_embed = make_embed_fn(self.latent_size)
    edge_embed = make_embed_fn(self.latent_size)
    global_embed = make_embed_fn(self.latent_size)
    embedder = jraph.GraphMapFeatures(embed_node_fn=node_embed,
                                      embed_edge_fn=edge_embed,
                                      embed_global_fn=global_embed)

    net = jraph.GraphNetwork(
        update_node_fn=make_mlp(self.mlp_features),
        update_edge_fn=make_mlp(self.mlp_features),
        # Size-2 global output for binary classification.
        update_global_fn=make_mlp(
            self.mlp_features + (2,)))  # pytype: disable=unsupported-operands
    embedded = embedder(graph)
    return net(embedded)
Ejemplo n.º 9
0
    def __init__(self,
                 n_recurrences: int,
                 mlp_sizes: Tuple[int, ...],
                 mlp_kwargs: Optional[Dict[str, Any]] = None,
                 format: partition.NeighborListFormat = partition.Dense,
                 name: str = 'GraphNetEncoder'):
        """Builds the feature encoder and a propagation-network factory.

        Args:
          n_recurrences: Stored as ``self._n_recurrences``; presumably the
            number of propagation rounds used elsewhere (TODO confirm).
          mlp_sizes: ``output_sizes`` of every MLP (encoders and updates).
          mlp_kwargs: Extra keyword arguments forwarded to ``hk.nets.MLP``.
          format: Neighbor list format. ``partition.Dense`` selects the
            ``GraphMapFeatures``/``GraphNetwork`` names in scope here;
            ``partition.Sparse`` selects the ``jraph`` implementations.
          name: Module name.

        Raises:
          ValueError: If ``format`` is neither ``partition.Dense`` nor
            ``partition.Sparse``.
        """
        super().__init__(name=name)

        if mlp_kwargs is None:
            mlp_kwargs = {}

        self._n_recurrences = n_recurrences

        # Pick the graph implementation that matches the neighbor format.
        # Both branches then build the same modules in the same order.
        if format is partition.Dense:
            map_features, graph_network = GraphMapFeatures, GraphNetwork
        elif format is partition.Sparse:
            map_features = jraph.GraphMapFeatures
            graph_network = jraph.GraphNetwork
        else:
            raise ValueError(
                f'Unsupported neighbor list format {format!r}; expected '
                'partition.Dense or partition.Sparse.')

        def embedding_fn(mlp_name):
            # Encoder MLP applied directly to one feature array.
            return hk.nets.MLP(output_sizes=mlp_sizes,
                               activate_final=True,
                               name=mlp_name,
                               **mlp_kwargs)

        def model_fn(mlp_name):
            # Update function: concatenates its inputs on the last axis and
            # feeds them to an MLP constructed on each call.
            def update(*args):
                mlp = hk.nets.MLP(output_sizes=mlp_sizes,
                                  activate_final=True,
                                  name=mlp_name,
                                  **mlp_kwargs)
                return mlp(jnp.concatenate(args, axis=-1))
            return update

        # Encoders are built now, in edge, node, global order.
        self._encoder = map_features(embedding_fn('EdgeEncoder'),
                                     embedding_fn('NodeEncoder'),
                                     embedding_fn('GlobalEncoder'))
        # The propagation network is built lazily, each time this is called.
        self._propagation_network = lambda: graph_network(
            model_fn('EdgeFunction'), model_fn('NodeFunction'),
            model_fn('GlobalFunction'))
Ejemplo n.º 10
0
 def _encode(
     self,
     graph: jraph.GraphsTuple,
     is_training: bool,
 ) -> jraph.GraphsTuple:
     """Embeds node and edge features, applying graph dropout in training.

     Args:
       graph: Graph whose nodes and edges are encoded; globals pass through.
       is_training: Forwarded to ``build_update_fn``; when True the encoded
         graph is also passed through ``self._dropout_graph``.

     Returns:
       The encoded graph.
     """
     # Settings shared by both encoders. The node encoder is built first.
     common = dict(
         activation=self._activation,
         normalization_type=self._normalization_type,
         is_training=is_training,
     )
     node_fn = build_update_fn('node_encoder', self._output_sizes, **common)
     edge_fn = build_update_fn('edge_encoder', self._output_sizes, **common)
     encoded = jraph.GraphMapFeatures(
         embed_edge_fn=edge_fn, embed_node_fn=node_fn)(graph)
     return self._dropout_graph(encoded) if is_training else encoded