Example 1
    def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        node_key, edge_key = hk.next_rng_keys(2)
        nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
        edges = graph.edges
        if not self._disable_edge_updates:
            edges = hk.dropout(edge_key, self._dropout_rate, edges)
        return graph._replace(nodes=nodes, edges=edges)
Example 2
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Compute embeddings for each node in the graphs.

    Args:
      graphs: a set of graphs batched into a single graph.  The nodes and edges
        are represented as feature tensors.

    Returns:
      graphs: new graph with node embeddings updated (shape [n_nodes,
        embed_dim]).
    """
        nodes = hk.Linear(self._embed_dim)(graphs.nodes)
        edges = hk.Linear(self._embed_dim)(graphs.edges)

        nodes = hk.LayerNorm(axis=-1, create_scale=True,
                             create_offset=True)(jax.nn.gelu(nodes))
        edges = hk.LayerNorm(axis=-1, create_scale=True,
                             create_offset=True)(jax.nn.gelu(edges))

        graphs = graphs._replace(nodes=nodes, edges=edges)
        graphs = gn.SimpleGraphNet(
            num_layers=self._num_layers,
            msg_hidden_size_factor=self._msg_hidden_size_factor,
            layer_norm=self._use_layer_norm)(graphs)
        return graphs
Example 3
    def __call__(
        self,
        graph: jraph.GraphsTuple,
        is_training: bool,
        stop_gradient_embedding_to_logits: bool = False,
    ) -> ModelOutput:
        # Note that these update configs may need to change if
        # we switch back to GraphNetwork rather than InteractionNetwork.

        graph = self._encode(graph, is_training)
        graph = self._process(graph, is_training)
        node_embeddings = graph.nodes
        node_projections = self._node_mlp(graph, is_training,
                                          self._latent_size, 'projector')
        node_predictions = self._node_mlp(
            graph._replace(nodes=node_projections),
            is_training,
            self._latent_size,
            'predictor',
        )
        if stop_gradient_embedding_to_logits:
            graph = jax.tree_map(jax.lax.stop_gradient, graph)
        node_logits = self._node_mlp(graph, is_training, self._num_classes,
                                     'logits_decoder')
        return ModelOutput(
            node_embeddings=node_embeddings,
            node_logits=node_logits,
            node_embedding_projections=node_projections,
            node_projection_predictions=node_predictions,
        )
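`ModelOutput` is not defined in the snippet. A minimal container consistent with the fields returned above (the class itself is an assumption, not part of the original model code):

from typing import NamedTuple

import jax.numpy as jnp


class ModelOutput(NamedTuple):
    # Hypothetical output struct; fields mirror the return statement above.
    node_embeddings: jnp.ndarray
    node_logits: jnp.ndarray
    node_embedding_projections: jnp.ndarray
    node_projection_predictions: jnp.ndarray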
Example 4
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """`InteractionNetwork` with an LSTM in the edge update."""

    # LSTM that will keep a memory of the inputs to the edge model.
    edge_fn_lstm = hk.LSTM(hidden_size=HIDDEN_SIZE)

    # MLPs used in the edge and the node model. Note that in this instance
    # the output size matches the input size so the same model can be run
    # iteratively multiple times. In a real model, this would usually be
    # achieved by first using an encoder to map the input data into a common
    # `EMBEDDING_SIZE`.
    edge_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])
    node_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])

    # Initialize the edge features to contain both the input edge embedding
    # and initial LSTM state. Note for the nodes we only have an embedding since
    # in this example nodes do not use a `node_fn_lstm`, but for analogy, we
    # still put it in a `StatefulField`.
    graph = graph._replace(
        edges=StatefulField(embedding=graph.edges,
                            state=edge_fn_lstm.initial_state(
                                graph.edges.shape[0])),
        nodes=StatefulField(embedding=graph.nodes, state=None),
    )

    def update_edge_fn(edges, sender_nodes, receiver_nodes):
        # We will run an LSTM memory on the inputs first, and then
        # process the output of the LSTM with an MLP.
        edge_inputs = jnp.concatenate(
            [edges.embedding, sender_nodes.embedding, receiver_nodes.embedding],
            axis=-1)
        lstm_output, updated_state = edge_fn_lstm(edge_inputs, edges.state)
        updated_edges = StatefulField(
            embedding=edge_fn_mlp(lstm_output),
            state=updated_state,
        )
        return updated_edges

    def update_node_fn(nodes, received_edges):
        # Note `received_edges.state` will also contain the aggregated state for
        # all received edges, which we may choose to use in the node update.
        node_inputs = jnp.concatenate(
            [nodes.embedding, received_edges.embedding], axis=-1)
        updated_nodes = StatefulField(embedding=node_fn_mlp(node_inputs),
                                      state=None)
        return updated_nodes

    recurrent_graph_network = jraph.InteractionNetwork(
        update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)

    # Apply the model recurrently for 10 message passing steps.
    # If instead we intended to use the LSTM to process a sequence of features
    # for each node/edge, here we would select the corresponding inputs from the
    # sequence along the sequence axis of the nodes/edges features to build the
    # correct input graph for each step of the iteration.
    num_message_passing_steps = 10
    for _ in range(num_message_passing_steps):
        graph = recurrent_graph_network(graph)

    return graph
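The snippet relies on `StatefulField`, `HIDDEN_SIZE`, and `EMBEDDING_SIZE`, which are defined outside it. A minimal sketch, assuming `StatefulField` is a plain (embedding, state) container as in the jraph LSTM example; the sizes are illustrative:

import collections

# Pairs a feature embedding with an optional recurrent state.
StatefulField = collections.namedtuple("StatefulField", ["embedding", "state"])

EMBEDDING_SIZE = 32  # Node/edge feature size; the MLPs output this size so
                     # the network can be applied iteratively (illustrative).
HIDDEN_SIZE = 32     # LSTM hidden size (illustrative).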
Example 5
def get_corrupted_view(
    graph: jraph.GraphsTuple,
    feature_drop_prob: float,
    edge_drop_prob: float,
    rng_key: jnp.ndarray,
) -> jraph.GraphsTuple:
    """Returns corrupted graph view."""
    node_key, edge_key = jax.random.split(rng_key)

    def mask_feature(x):
        mask = jax.random.bernoulli(node_key, 1 - feature_drop_prob, x.shape)
        return x * mask

    # Randomly mask features with fixed probability.
    nodes = jax.tree_map(mask_feature, graph.nodes)

    # Simulate dropping of edges by changing genuine edges to self-loops on
    # the padded node.
    num_edges = graph.senders.shape[0]
    last_node_idx = graph.n_node.sum() - 1
    edge_mask = jax.random.bernoulli(edge_key, 1 - edge_drop_prob, [num_edges])
    senders = jnp.where(edge_mask, graph.senders, last_node_idx)
    receivers = jnp.where(edge_mask, graph.receivers, last_node_idx)
    # Note that n_edge is now invalid, since dropped edges in the middle of the
    # edge list now belong to the final (padding) graph. Set n_edge to None to
    # ensure we do not accidentally use it.
    return graph._replace(
        nodes=nodes,
        senders=senders,
        receivers=receivers,
        n_edge=None,
    )
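A hedged usage sketch: the edge-dropping trick assumes the graph has been padded so that the last node is a padding node. Given some `jraph.GraphsTuple` `graph`, with illustrative sizes and rates:

# Pad first so the final node is a padding node, then corrupt.
padded = jraph.pad_with_graphs(graph, n_node=64, n_edge=256)
view = get_corrupted_view(
    padded,
    feature_drop_prob=0.1,
    edge_drop_prob=0.2,
    rng_key=jax.random.PRNGKey(0),
)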
Example 6
def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    # Add a global feature for graph classification.
    graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))
    embedder = jraph.GraphMapFeatures(hk.Linear(128), hk.Linear(128),
                                      hk.Linear(128))
    net = jraph.GraphNetwork(update_node_fn=node_update_fn,
                             update_edge_fn=edge_update_fn,
                             update_global_fn=update_global_fn)
    return net(embedder(graph))
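The three update functions are assumed to exist elsewhere. A minimal sketch using `jraph.concatenated_args`, which concatenates each update's inputs into a single feature array; the layer sizes and the 2-class output are illustrative:

@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
    # Edge update: MLP over concatenated edge, sender, receiver and globals.
    return hk.nets.MLP([128, 128])(feats)


@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
    # Node update: MLP over concatenated node, aggregated edges and globals.
    return hk.nets.MLP([128, 128])(feats)


@jraph.concatenated_args
def update_global_fn(feats: jnp.ndarray) -> jnp.ndarray:
    # Global update: MLP mapping to per-graph logits (2 classes here).
    return hk.nets.MLP([128, 2])(feats)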
Example 7
    def _process(
        self,
        graph: jraph.GraphsTuple,
        is_training: bool,
    ) -> jraph.GraphsTuple:
        for idx in range(self._num_message_passing_steps):
            net = build_gn(output_sizes=self._output_sizes,
                           activation=self._activation,
                           suffix=str(idx),
                           use_sent_edges=self._use_sent_edges,
                           is_training=is_training,
                           dropedge_rate=self._dropedge_rate,
                           normalization_type=self._normalization_type,
                           aggregation_function=self._aggregation_function)
            residual_graph = net(graph)
            graph = graph._replace(nodes=graph.nodes + residual_graph.nodes)
            if not self._disable_edge_updates:
                graph = graph._replace(
                    edges=graph.edges + residual_graph.edges)
            if is_training:
                graph = self._dropout_graph(graph)
        return graph
Example 8
def _mask_out_padding_graph(
        padded_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    return padded_graph._replace(
        nodes=jnp.where(
            jraph.get_node_padding_mask(padded_graph)[:, None],
            padded_graph.nodes, 0.),
        edges=jnp.where(
            jraph.get_edge_padding_mask(padded_graph)[:, None],
            padded_graph.edges, 0.),
        globals=jnp.where(
            jraph.get_graph_padding_mask(padded_graph)[:, None],
            padded_graph.globals, 0.),
    )
Example 9
    def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Pooling operation, taken from Jraph."""

        # Equivalent to jnp.sum(n_node), but JIT-able.
        sum_n_node = graphs.nodes.shape[0]
        # To aggregate nodes from each graph to global features,
        # we first construct tensors that map the node to the corresponding graph.
        # Example: if you have `n_node=[1,2]`, we construct the tensor [0, 1, 1].
        n_graph = graphs.n_node.shape[0]
        node_graph_indices = jnp.repeat(jnp.arange(n_graph),
                                        graphs.n_node,
                                        axis=0,
                                        total_repeat_length=sum_n_node)
        # We use the aggregation function to pool the nodes per graph.
        pooled = self.pooling_fn(graphs.nodes, node_graph_indices,
                                 n_graph)  # type: ignore[call-arg]
        return graphs._replace(globals=pooled)
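`self.pooling_fn` is assumed to be a segment aggregation with the signature `(data, segment_ids, num_segments)`, for example:

# Mean-pool each graph's node features into its global slot (an assumption;
# any of jraph's segment reducers fits the signature).
pooling_fn = jraph.segment_mean  # or jraph.segment_sum / jraph.segment_max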
Example 10
    def _prepare_features(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Prepares features keys into flat node, edge and global features."""

        # Collect edge features.
        edge_features_list = [graph.edges["bond_one_hots"]]
        if (self._config.add_relative_displacement
                or self._config.add_relative_distance):
            (relative_displacement,
             relative_distance) = _compute_relative_displacement_and_distance(
                 graph,
                 self._config.relative_displacement_normalization,
                 use_target=False)

            if self._config.add_relative_displacement:
                edge_features_list.append(relative_displacement)
            if self._config.add_relative_distance:
                edge_features_list.append(relative_distance)
            mask_at_edges = _broadcast_global_to_edges(
                graph.globals["positions_nan_mask"], graph)
            edge_features_list.append(
                mask_at_edges[:, None].astype(jnp.float32))

        edge_features = jnp.concatenate(edge_features_list, axis=-1)

        # Collect node features
        node_features_list = [graph.nodes["atom_one_hots"]]

        if self._config.add_absolute_positions:
            node_features_list.append(graph.nodes["positions"] /
                                      self._config.position_normalization)
            mask_at_nodes = _broadcast_global_to_nodes(
                graph.globals["positions_nan_mask"], graph)
            node_features_list.append(
                mask_at_nodes[:, None].astype(jnp.float32))

        node_features = jnp.concatenate(node_features_list, axis=-1)

        global_features = jnp.zeros(
            (len(graph.n_node), self._config.latent_size))
        chex.assert_tree_shape_prefix(global_features, (len(graph.n_node), ))
        return graph._replace(nodes=node_features,
                              edges=edge_features,
                              globals=global_features)
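The broadcast helpers are not shown. A sketch under the assumption that they simply repeat each graph's global value once per node or edge (the feature keys come from the snippet; the bodies and signatures are guesses):

def _broadcast_global_to_nodes(global_value, graph):
    # Repeat each graph's global value once per node of that graph.
    total_nodes = graph.nodes["atom_one_hots"].shape[0]
    return jnp.repeat(global_value, graph.n_node, axis=0,
                      total_repeat_length=total_nodes)


def _broadcast_global_to_edges(global_value, graph):
    # Repeat each graph's global value once per edge of that graph.
    total_edges = graph.edges["bond_one_hots"].shape[0]
    return jnp.repeat(global_value, graph.n_edge, axis=0,
                      total_repeat_length=total_edges)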
Example 11
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    """Implements the GCN from Kipf et al https://arxiv.org/pdf/1609.02907.pdf.

     A' = D^{-0.5} A D^{-0.5}
     Z = f(X, A') = A' relu(A' X W_0) W_1

  Args:
    graph: GraphsTuple the network processes.

  Returns:
    processed nodes.
  """
    gn = jraph.GraphConvolution(update_node_fn=hk.Linear(5, with_bias=False),
                                add_self_edges=True)
    graph = gn(graph)
    graph = graph._replace(nodes=jax.nn.relu(graph.nodes))
    gn = jraph.GraphConvolution(update_node_fn=hk.Linear(2, with_bias=False))
    graph = gn(graph)
    return graph.nodes
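A hedged usage sketch with Haiku, given some input `graph`:

# Transform the function into an (init, apply) pair and run it once.
network = hk.without_apply_rng(hk.transform(network_definition))
params = network.init(jax.random.PRNGKey(42), graph)
processed_nodes = network.apply(params, graph)  # shape [n_nodes, 2]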
Example 12
    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Applies this layer to the input graph."""
        messages = self._compute_messages(graph)
        updated_nodes = self._update_nodes(graph, messages)
        return graph._replace(nodes=updated_nodes)
Example 13
def set_system_state(static_graph: jraph.GraphsTuple, position: np.ndarray,
                     momentum: np.ndarray) -> jraph.GraphsTuple:
    """Sets the non-static parameters of the graph (momentum, position)."""
    nodes = static_graph.nodes.copy(position=position, momentum=momentum)
    return static_graph._replace(nodes=nodes)
Example 14
def get_static_graph(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Returns the graph with the static parts of a system only."""
    nodes = dict(graph.nodes)
    del nodes["position"], nodes["momentum"]
    return graph._replace(nodes=frozendict(nodes))
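Examples 13 and 14 form a pair. A hedged round trip, assuming node features are stored as a dict of arrays and that `frozendict.copy(**updates)` returns an updated copy:

# Strip the dynamic state, then restore it (position/momentum are arrays
# of shape [n_nodes, dim]; illustrative only).
static_graph = get_static_graph(graph)
graph_at_t = set_system_state(static_graph, position, momentum)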
Example 15
    def _processor(
        self,
        graph: jraph.GraphsTuple,
        is_training: bool,
    ) -> jraph.GraphsTuple:
        """Builds the processor."""
        output_sizes = [self._config.mlp_hidden_size] * self._config.mlp_layers
        output_sizes += [self._config.latent_size]
        build_mlp = functools.partial(
            _build_mlp,
            output_sizes=output_sizes,
            use_layer_norm=self._config.use_layer_norm,
        )

        shared_weights = self._config.shared_message_passing_weights
        node_reducer = _REDUCER_NAMES[self._config.node_reducer]
        global_reducer = _REDUCER_NAMES[self._config.global_reducer]

        def dropout_if_training(fn, dropout_rate: float):
            def wrapped(*args):
                out = fn(*args)
                if is_training:
                    mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                                      jnp.ones([out.shape[0], 1]))
                    out = out * mask
                return out

            return wrapped

        num_mps = self._config.num_message_passing_steps
        for step in range(num_mps):
            if step == 0 or not shared_weights:
                suffix = "shared" if shared_weights else step

                update_edge_fn = dropout_if_training(
                    build_mlp(f"edge_processor_{suffix}"),
                    dropout_rate=self._config.dropedge_rate)

                update_node_fn = dropout_if_training(
                    build_mlp(f"node_processor_{suffix}"),
                    dropout_rate=self._config.dropnode_rate)

                if self._config.ignore_globals:
                    gnn = jraph.InteractionNetwork(
                        update_edge_fn=update_edge_fn,
                        update_node_fn=update_node_fn,
                        aggregate_edges_for_nodes_fn=node_reducer)
                else:
                    gnn = jraph.GraphNetwork(
                        update_edge_fn=update_edge_fn,
                        update_node_fn=update_node_fn,
                        update_global_fn=build_mlp(
                            f"global_processor_{suffix}"),
                        aggregate_edges_for_nodes_fn=node_reducer,
                        aggregate_nodes_for_globals_fn=global_reducer,
                        aggregate_edges_for_globals_fn=global_reducer,
                    )

            mode = self._config.processor_mode

            if mode == "mlp":
                graph = gnn(graph)

            elif mode == "resnet":
                new_graph = gnn(graph)
                graph = graph._replace(
                    nodes=graph.nodes + new_graph.nodes,
                    edges=graph.edges + new_graph.edges,
                    globals=graph.globals + new_graph.globals,
                )
            else:
                raise ValueError(f"Unknown processor_mode `{mode}`")

            if self._config.mask_padding_graph_at_every_step:
                graph = _mask_out_padding_graph(graph)

        return graph
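`_build_mlp` and `_REDUCER_NAMES` are defined outside the snippet. A minimal sketch (the names follow the snippet; the bodies and signatures are assumptions):

_REDUCER_NAMES = {
    "sum": jraph.segment_sum,
    "mean": jraph.segment_mean,
}


def _build_mlp(name, output_sizes, use_layer_norm=False):
    # Hypothetical MLP factory; `name` scopes the Haiku parameters so each
    # message-passing step (or the shared step) gets its own weights.
    @jraph.concatenated_args
    def mlp(features):
        out = hk.nets.MLP(output_sizes, name=name)(features)
        if use_layer_norm:
            out = hk.LayerNorm(axis=-1, create_scale=True,
                               create_offset=True)(out)
        return out
    return mlp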
Example 16
def add_reverse_edges(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Add edges in the reverse direction, copy edge features."""
    senders = np.concatenate([graph.senders, graph.receivers], axis=0)
    receivers = np.concatenate([graph.receivers, graph.senders], axis=0)
    edges = np.concatenate([graph.edges, graph.edges], axis=0)
    return graph._replace(senders=senders, receivers=receivers, edges=edges)
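A minimal illustration; note the helper leaves `n_node` and `n_edge` untouched, so callers that rely on `n_edge` must update it themselves:

# One directed edge 0 -> 1 becomes the pair 0 -> 1 and 1 -> 0.
graph = jraph.GraphsTuple(
    nodes=np.ones((2, 3)),
    edges=np.ones((1, 4)),
    senders=np.array([0]),
    receivers=np.array([1]),
    n_node=np.array([2]),
    n_edge=np.array([1]),  # NB: not doubled by add_reverse_edges.
    globals=None,
)
bidirectional = add_reverse_edges(graph)
assert list(bidirectional.senders) == [0, 1]
assert list(bidirectional.receivers) == [1, 0]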
Example 17
def add_graphs_tuples(graphs: jraph.GraphsTuple,
                      other_graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Adds the nodes, edges and global features from other_graphs to graphs."""
    return graphs._replace(nodes=graphs.nodes + other_graphs.nodes,
                           edges=graphs.edges + other_graphs.edges,
                           globals=graphs.globals + other_graphs.globals)