Example #1
def network_definition(
    graph: jraph.GraphsTuple,
    num_message_passing_steps: int = 5) -> jraph.ArrayTree:
  """Defines a graph neural network.

  Args:
    graph: GraphsTuple the network processes.
    num_message_passing_steps: number of message passing steps.

  Returns:
    Decoded nodes.
  """
  embedding = jraph.GraphMapFeatures(
      embed_edge_fn=jax.vmap(hk.Linear(output_size=16)),
      embed_node_fn=jax.vmap(hk.Linear(output_size=16)))
  graph = embedding(graph)

  @jax.vmap
  @jraph.concatenated_args
  def update_fn(features):
    net = hk.Sequential([
        hk.Linear(10), jax.nn.relu,
        hk.Linear(10), jax.nn.relu,
        hk.Linear(10), jax.nn.relu])
    return net(features)

  for _ in range(num_message_passing_steps):
    gn = jraph.InteractionNetwork(
        update_edge_fn=update_fn,
        update_node_fn=update_fn,
        include_sent_messages_in_node_update=True)
    graph = gn(graph)

  return hk.Linear(2)(graph.nodes)
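
A minimal usage sketch, not part of the original example: the toy `GraphsTuple` and the transform boilerplate below are illustrative assumptions showing how such a Haiku-style definition is typically initialized and applied.

import haiku as hk
import jax
import jax.numpy as jnp
import jraph

# Hypothetical toy graph: 3 nodes with 4 features, 2 directed edges with 3 features.
toy_graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),
    edges=jnp.ones((2, 3)),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=None)

# Haiku modules must be created inside a transformed function.
net = hk.without_apply_rng(hk.transform(network_definition))
params = net.init(jax.random.PRNGKey(42), toy_graph)
decoded_nodes = net.apply(params, toy_graph)  # Shape [3, 2].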
Example #2
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """`InteractionNetwork` with an LSTM in the edge update."""

    # LSTM that will keep a memory of the inputs to the edge model.
    edge_fn_lstm = hk.LSTM(hidden_size=HIDDEN_SIZE)

    # MLPs used in the edge and the node model. Note that in this instance
    # the output size matches the input size, so the same model can be run
    # iteratively multiple times. In a real model, this would usually be
    # achieved by first using an encoder to map the input data into a common
    # `EMBEDDING_SIZE`.
    edge_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])
    node_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])

    # Initialize the edge features to contain both the input edge embedding
    # and the initial LSTM state. Note that the nodes only carry an embedding,
    # since in this example nodes do not use a `node_fn_lstm`; for symmetry we
    # still wrap them in a `StatefulField`.
    graph = graph._replace(
        edges=StatefulField(embedding=graph.edges,
                            state=edge_fn_lstm.initial_state(
                                graph.edges.shape[0])),
        nodes=StatefulField(embedding=graph.nodes, state=None),
    )

    def update_edge_fn(edges, sender_nodes, receiver_nodes):
        # We will run an LSTM memory on the inputs first, and then
        # process the output of the LSTM with an MLP.
        edge_inputs = jnp.concatenate(
            [edges.embedding, sender_nodes.embedding, receiver_nodes.embedding],
            axis=-1)
        lstm_output, updated_state = edge_fn_lstm(edge_inputs, edges.state)
        updated_edges = StatefulField(
            embedding=edge_fn_mlp(lstm_output),
            state=updated_state,
        )
        return updated_edges

    def update_node_fn(nodes, received_edges):
        # Note `received_edges.state` will also contain the aggregated state for
        # all received edges, which we may choose to use in the node update.
        node_inputs = jnp.concatenate(
            [nodes.embedding, received_edges.embedding], axis=-1)
        updated_nodes = StatefulField(embedding=node_fn_mlp(node_inputs),
                                      state=None)
        return updated_nodes

    recurrent_graph_network = jraph.InteractionNetwork(
        update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)

    # Apply the model recurrently for 10 message passing steps.
    # If instead we intended to use the LSTM to process a sequence of features
    # for each node/edge, here we would select the corresponding inputs from the
    # sequence along the sequence axis of the nodes/edges features to build the
    # correct input graph for each step of the iteration.
    num_message_passing_steps = 10
    for _ in range(num_message_passing_steps):
        graph = recurrent_graph_network(graph)

    return graph
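
`StatefulField`, `HIDDEN_SIZE` and `EMBEDDING_SIZE` are defined outside this snippet. A plausible minimal sketch, assuming `StatefulField` is simply a NamedTuple pairing an embedding with an optional recurrent state (the sizes below are arbitrary):

from typing import Any, NamedTuple, Optional

import jax.numpy as jnp

# Assumed constants; the original values are not shown in the snippet.
HIDDEN_SIZE = 32
EMBEDDING_SIZE = 16


class StatefulField(NamedTuple):
    """Graph field carrying an embedding plus an optional recurrent state."""
    embedding: jnp.ndarray
    state: Optional[Any]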
Example #3
def build_gn(
    output_sizes: Sequence[int],
    activation: Callable[[jnp.ndarray], jnp.ndarray],
    suffix: str,
    use_sent_edges: bool,
    is_training: bool,
    dropedge_rate: float,
    normalization_type: str,
    aggregation_function: str,
):
    """Builds an InteractionNetwork with MLP update functions."""
    node_update_fn = build_update_fn(
        f'node_processor_{suffix}',
        output_sizes,
        activation=activation,
        normalization_type=normalization_type,
        is_training=is_training,
    )
    edge_update_fn = build_update_fn(
        f'edge_processor_{suffix}',
        output_sizes,
        activation=activation,
        normalization_type=normalization_type,
        is_training=is_training,
    )

    def maybe_dropedge(x):
        """Dropout on edge messages."""
        if not is_training:
            return x
        return x * hk.dropout(
            hk.next_rng_key(),
            dropedge_rate,
            jnp.ones([x.shape[0], 1]),
        )

    def dropped_edge_update_fn(*args):
        return maybe_dropedge(edge_update_fn(*args))

    return jraph.InteractionNetwork(
        update_edge_fn=dropped_edge_update_fn,
        update_node_fn=node_update_fn,
        aggregate_edges_for_nodes_fn=_REDUCER_NAMES[aggregation_function],
        include_sent_messages_in_node_update=use_sent_edges,
    )
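
`build_update_fn` and `_REDUCER_NAMES` come from elsewhere in this codebase. A rough sketch, assuming the update functions are Haiku MLPs over concatenated features with optional layer norm and that the reducer names map to jraph segment reducers:

import haiku as hk
import jraph

# Assumed mapping from config strings to jraph segment reducers.
_REDUCER_NAMES = {
    'sum': jraph.segment_sum,
    'mean': jraph.segment_mean,
}


def build_update_fn(name, output_sizes, activation, normalization_type,
                    is_training):
    """Returns an MLP update function operating on concatenated features."""
    del is_training  # Would matter for e.g. batch norm; unused in this sketch.

    @jraph.concatenated_args
    def update_fn(features):
        out = hk.nets.MLP(output_sizes, activation=activation, name=name)(features)
        if normalization_type == 'layer_norm':
            out = hk.LayerNorm(axis=-1, create_scale=True,
                               create_offset=True)(out)
        return out

    return update_fn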
Example #4
def network_definition(graph):
    """Defines a graph neural network.

  Args:
    graph: Graphstuple the network processes.

  Returns:
    Decoded nodes.
  """
    model_fn = functools.partial(hk.nets.MLP,
                                 w_init=hk.initializers.VarianceScaling(1.0),
                                 b_init=hk.initializers.VarianceScaling(1.0))
    mlp_sizes = (64, 64)
    num_message_passing_steps = 7

    node_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True)
    edge_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True)
    node_decoder = model_fn(output_sizes=mlp_sizes + (1, ),
                            activate_final=False)

    node_encoding = node_encoder(graph.nodes)
    edge_encoding = edge_encoder(graph.edges)
    graph = graph._replace(nodes=node_encoding, edges=edge_encoding)

    update_edge_fn = jraph.concatenated_args(
        model_fn(output_sizes=mlp_sizes, activate_final=True))
    update_node_fn = jraph.concatenated_args(
        model_fn(output_sizes=mlp_sizes, activate_final=True))
    gn = jraph.InteractionNetwork(update_edge_fn=update_edge_fn,
                                  update_node_fn=update_node_fn,
                                  include_sent_messages_in_node_update=True)

    for _ in range(num_message_passing_steps):
        graph = graph._replace(
            nodes=jnp.concatenate([graph.nodes, node_encoding], axis=-1),
            edges=jnp.concatenate([graph.edges, edge_encoding], axis=-1))
        graph = gn(graph)

    return jnp.squeeze(node_decoder(graph.nodes), axis=-1)
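
For context, a hedged sketch of how a node-level decoder output like this might be trained; the loss, target array and transform wrapper are illustrative assumptions, not part of the original code:

import haiku as hk
import jax
import jax.numpy as jnp

network = hk.without_apply_rng(hk.transform(network_definition))


def mse_loss(params, graph, node_targets):
    """Per-node regression loss on the decoded node values."""
    predictions = network.apply(params, graph)  # Shape [num_nodes].
    return jnp.mean((predictions - node_targets) ** 2)

# Illustrative usage, assuming `graph` and `node_targets` are available:
# params = network.init(jax.random.PRNGKey(0), graph)
# loss, grads = jax.value_and_grad(mse_loss)(params, graph, node_targets)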
Example #5
def net_fn(graph: jraph.GraphsTuple):
    unf = jraph.concatenated_args(conway_mlp)
    net = jraph.InteractionNetwork(update_edge_fn=lambda e, n_s, n_r: n_s,
                                   update_node_fn=jax.vmap(unf))
    return net(graph)
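
`conway_mlp` is not shown in this example. With the `update_edge_fn` above, each node update receives the concatenation of the node's own features and the sum of its senders' features, so any per-node function of that vector fits. A trivial, parameter-free placeholder (an assumption, not the original, which presumably encodes Game-of-Life style rules):

import jax.numpy as jnp


def conway_mlp(x):
    """Hypothetical stand-in: maps [own features, summed neighbour features] to new features."""
    own_state, neighbour_sum = jnp.split(x, 2, axis=-1)
    return jnp.tanh(neighbour_sum - own_state)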
Example #6
    def _processor(
        self,
        graph: jraph.GraphsTuple,
        is_training: bool,
    ) -> jraph.GraphsTuple:
        """Builds the processor."""
        output_sizes = [self._config.mlp_hidden_size] * self._config.mlp_layers
        output_sizes += [self._config.latent_size]
        build_mlp = functools.partial(
            _build_mlp,
            output_sizes=output_sizes,
            use_layer_norm=self._config.use_layer_norm,
        )

        shared_weights = self._config.shared_message_passing_weights
        node_reducer = _REDUCER_NAMES[self._config.node_reducer]
        global_reducer = _REDUCER_NAMES[self._config.global_reducer]

        def dropout_if_training(fn, dropout_rate: float):
            def wrapped(*args):
                out = fn(*args)
                if is_training:
                    mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                                      jnp.ones([out.shape[0], 1]))
                    out = out * mask
                return out

            return wrapped

        num_mps = self._config.num_message_passing_steps
        for step in range(num_mps):
            if step == 0 or not shared_weights:
                suffix = "shared" if shared_weights else step

                update_edge_fn = dropout_if_training(
                    build_mlp(f"edge_processor_{suffix}"),
                    dropout_rate=self._config.dropedge_rate)

                update_node_fn = dropout_if_training(
                    build_mlp(f"node_processor_{suffix}"),
                    dropout_rate=self._config.dropnode_rate)

                if self._config.ignore_globals:
                    gnn = jraph.InteractionNetwork(
                        update_edge_fn=update_edge_fn,
                        update_node_fn=update_node_fn,
                        aggregate_edges_for_nodes_fn=node_reducer)
                else:
                    gnn = jraph.GraphNetwork(
                        update_edge_fn=update_edge_fn,
                        update_node_fn=update_node_fn,
                        update_global_fn=build_mlp(
                            f"global_processor_{suffix}"),
                        aggregate_edges_for_nodes_fn=node_reducer,
                        aggregate_nodes_for_globals_fn=global_reducer,
                        aggregate_edges_for_globals_fn=global_reducer,
                    )

            mode = self._config.processor_mode

            if mode == "mlp":
                graph = gnn(graph)

            elif mode == "resnet":
                new_graph = gnn(graph)
                graph = graph._replace(
                    nodes=graph.nodes + new_graph.nodes,
                    edges=graph.edges + new_graph.edges,
                    globals=graph.globals + new_graph.globals,
                )
            else:
                raise ValueError(f"Unknown processor_mode `{mode}`")

            if self._config.mask_padding_graph_at_every_step:
                graph = _mask_out_padding_graph(graph)

        return graph
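
`_build_mlp`, `_REDUCER_NAMES` and `_mask_out_padding_graph` are defined elsewhere in this codebase. Assuming the graphs are padded with `jraph.pad_with_graphs`, a plausible sketch of the masking helper using jraph's padding-mask utilities:

import jax.numpy as jnp
import jraph


def _mask_out_padding_graph(
        padded_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Zeroes features that belong to the padding graph(s)."""
    return padded_graph._replace(
        nodes=jnp.where(
            jraph.get_node_padding_mask(padded_graph)[:, None],
            padded_graph.nodes, 0.0),
        edges=jnp.where(
            jraph.get_edge_padding_mask(padded_graph)[:, None],
            padded_graph.edges, 0.0),
        globals=jnp.where(
            jraph.get_graph_padding_mask(padded_graph)[:, None],
            padded_graph.globals, 0.0),
    )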