def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  node_key, edge_key = hk.next_rng_keys(2)
  nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
  edges = graph.edges
  if not self._disable_edge_updates:
    edges = hk.dropout(edge_key, self._dropout_rate, edges)
  return graph._replace(nodes=nodes, edges=edges)
def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Compute embeddings for each node in the graphs.

  Args:
    graphs: a set of graphs batched into a single graph. The nodes and edges
      are represented as feature tensors.

  Returns:
    graphs: new graph with node embeddings updated (shape [n_nodes,
      embed_dim]).
  """
  nodes = hk.Linear(self._embed_dim)(graphs.nodes)
  edges = hk.Linear(self._embed_dim)(graphs.edges)
  nodes = hk.LayerNorm(
      axis=-1, create_scale=True, create_offset=True)(jax.nn.gelu(nodes))
  edges = hk.LayerNorm(
      axis=-1, create_scale=True, create_offset=True)(jax.nn.gelu(edges))
  graphs = graphs._replace(nodes=nodes, edges=edges)
  graphs = gn.SimpleGraphNet(
      num_layers=self._num_layers,
      msg_hidden_size_factor=self._msg_hidden_size_factor,
      layer_norm=self._use_layer_norm)(graphs)
  return graphs
def __call__(
    self,
    graph: jraph.GraphsTuple,
    is_training: bool,
    stop_gradient_embedding_to_logits: bool = False,
) -> ModelOutput:
  # Note that these update configs may need to change if
  # we switch back to GraphNetwork rather than InteractionNetwork.
  graph = self._encode(graph, is_training)
  graph = self._process(graph, is_training)

  node_embeddings = graph.nodes
  node_projections = self._node_mlp(graph, is_training, self._latent_size,
                                    'projector')
  node_predictions = self._node_mlp(
      graph._replace(nodes=node_projections),
      is_training,
      self._latent_size,
      'predictor',
  )

  if stop_gradient_embedding_to_logits:
    graph = jax.tree_map(jax.lax.stop_gradient, graph)
  node_logits = self._node_mlp(graph, is_training, self._num_classes,
                               'logits_decoder')

  return ModelOutput(
      node_embeddings=node_embeddings,
      node_logits=node_logits,
      node_embedding_projections=node_projections,
      node_projection_predictions=node_predictions,
  )
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
  """`InteractionNetwork` with an LSTM in the edge update."""

  # LSTM that will keep a memory of the inputs to the edge model.
  edge_fn_lstm = hk.LSTM(hidden_size=HIDDEN_SIZE)

  # MLPs used in the edge and the node model. Note that in this instance
  # the output size matches the input size so the same model can be run
  # iteratively multiple times. In a real model, this would usually be
  # achieved by first using an encoder to map the input data into a common
  # `EMBEDDING_SIZE`.
  edge_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])
  node_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])

  # Initialize the edge features to contain both the input edge embedding
  # and the initial LSTM state. Note for the nodes we only have an embedding
  # since in this example nodes do not use a `node_fn_lstm`, but for analogy,
  # we still put it in a `StatefulField`.
  graph = graph._replace(
      edges=StatefulField(
          embedding=graph.edges,
          state=edge_fn_lstm.initial_state(graph.edges.shape[0])),
      nodes=StatefulField(embedding=graph.nodes, state=None),
  )

  def update_edge_fn(edges, sender_nodes, receiver_nodes):
    # We will run an LSTM memory on the inputs first, and then
    # process the output of the LSTM with an MLP.
    edge_inputs = jnp.concatenate([
        edges.embedding, sender_nodes.embedding, receiver_nodes.embedding
    ], axis=-1)
    lstm_output, updated_state = edge_fn_lstm(edge_inputs, edges.state)
    updated_edges = StatefulField(
        embedding=edge_fn_mlp(lstm_output),
        state=updated_state,
    )
    return updated_edges

  def update_node_fn(nodes, received_edges):
    # Note `received_edges.state` will also contain the aggregated state for
    # all received edges, which we may choose to use in the node update.
    node_inputs = jnp.concatenate(
        [nodes.embedding, received_edges.embedding], axis=-1)
    updated_nodes = StatefulField(
        embedding=node_fn_mlp(node_inputs), state=None)
    return updated_nodes

  recurrent_graph_network = jraph.InteractionNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn)

  # Apply the model recurrently for 10 message passing steps.
  # If instead we intended to use the LSTM to process a sequence of features
  # for each node/edge, here we would select the corresponding inputs from the
  # sequence along the sequence axis of the nodes/edges features to build the
  # correct input graph for each step of the iteration.
  num_message_passing_steps = 10
  for _ in range(num_message_passing_steps):
    graph = recurrent_graph_network(graph)

  return graph
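# The snippet above assumes a `StatefulField` container plus two size
# constants. A minimal sketch of those assumptions (the names and sizes below
# are illustrative, not taken from the snippet itself):
from typing import Any, NamedTuple

HIDDEN_SIZE = 256     # Hidden size of the LSTM and the MLPs (assumed).
EMBEDDING_SIZE = 128  # Size of the node/edge embeddings (assumed).


class StatefulField(NamedTuple):
  """Pairs a feature embedding with an optional recurrent state."""
  embedding: Any
  state: Any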
def get_corrupted_view(
    graph: jraph.GraphsTuple,
    feature_drop_prob: float,
    edge_drop_prob: float,
    rng_key: jnp.ndarray,
) -> jraph.GraphsTuple:
  """Returns corrupted graph view."""
  node_key, edge_key = jax.random.split(rng_key)

  def mask_feature(x):
    mask = jax.random.bernoulli(node_key, 1 - feature_drop_prob, x.shape)
    return x * mask

  # Randomly mask features with fixed probability.
  nodes = jax.tree_map(mask_feature, graph.nodes)

  # Simulate dropping of edges by changing genuine edges to self-loops on
  # the padded node.
  num_edges = graph.senders.shape[0]
  last_node_idx = graph.n_node.sum() - 1
  edge_mask = jax.random.bernoulli(edge_key, 1 - edge_drop_prob, [num_edges])
  senders = jnp.where(edge_mask, graph.senders, last_node_idx)
  receivers = jnp.where(edge_mask, graph.receivers, last_node_idx)

  # Note that n_edge will now be invalid since edges in the middle of the list
  # will correspond to the final graph. Set n_edge to None to ensure we do not
  # accidentally use this.
  return graph._replace(
      nodes=nodes,
      senders=senders,
      receivers=receivers,
      n_edge=None,
  )
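# A minimal usage sketch (the toy graph and drop probabilities are assumed):
# the helper expects a padded graph, since dropped edges are rerouted to
# self-loops on the final (padding) node.
toy = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)), edges=jnp.ones((2, 4)),
    senders=jnp.array([0, 1]), receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]), n_edge=jnp.array([2]), globals=jnp.zeros((1, 2)))
padded = jraph.pad_with_graphs(toy, n_node=8, n_edge=4, n_graph=2)
corrupted = get_corrupted_view(padded, feature_drop_prob=0.2,
                               edge_drop_prob=0.2,
                               rng_key=jax.random.PRNGKey(0))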
def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  # Add a global parameter for graph classification.
  graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))
  embedder = jraph.GraphMapFeatures(
      hk.Linear(128), hk.Linear(128), hk.Linear(128))
  net = jraph.GraphNetwork(
      update_node_fn=node_update_fn,
      update_edge_fn=edge_update_fn,
      update_global_fn=update_global_fn)
  return net(embedder(graph))
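# A sketch of update functions like those referenced by `net_fn` above (their
# exact architectures are assumptions). `jraph.concatenated_args` concatenates
# each update's edge/node/global inputs into a single feature array before
# calling the wrapped MLP.
@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  return hk.nets.MLP([128, 128])(feats)


@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  return hk.nets.MLP([128, 128])(feats)


@jraph.concatenated_args
def update_global_fn(feats: jnp.ndarray) -> jnp.ndarray:
  return hk.nets.MLP([128, 2])(feats)  # Two output classes assumed.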
def _process(
    self,
    graph: jraph.GraphsTuple,
    is_training: bool,
) -> jraph.GraphsTuple:
  for idx in range(self._num_message_passing_steps):
    net = build_gn(
        output_sizes=self._output_sizes,
        activation=self._activation,
        suffix=str(idx),
        use_sent_edges=self._use_sent_edges,
        is_training=is_training,
        dropedge_rate=self._dropedge_rate,
        normalization_type=self._normalization_type,
        aggregation_function=self._aggregation_function)
    residual_graph = net(graph)
    graph = graph._replace(nodes=graph.nodes + residual_graph.nodes)
    if not self._disable_edge_updates:
      graph = graph._replace(edges=graph.edges + residual_graph.edges)
    if is_training:
      graph = self._dropout_graph(graph)
  return graph
def _mask_out_padding_graph(
    padded_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  return padded_graph._replace(
      nodes=jnp.where(
          jraph.get_node_padding_mask(padded_graph)[:, None],
          padded_graph.nodes, 0.),
      edges=jnp.where(
          jraph.get_edge_padding_mask(padded_graph)[:, None],
          padded_graph.edges, 0.),
      globals=jnp.where(
          jraph.get_graph_padding_mask(padded_graph)[:, None],
          padded_graph.globals, 0.),
  )
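# A minimal usage sketch (toy sizes are assumptions): padding added by
# `jraph.pad_with_graphs` is zeroed out so that padded nodes, edges and
# globals cannot leak into later computations.
toy = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)), edges=jnp.ones((2, 4)),
    senders=jnp.array([0, 1]), receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]), n_edge=jnp.array([2]), globals=jnp.ones((1, 4)))
padded = jraph.pad_with_graphs(toy, n_node=8, n_edge=4, n_graph=2)
masked = _mask_out_padding_graph(padded)  # Padded entries are now all zeros.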
def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Pooling operation, taken from Jraph."""
  # Equivalent to jnp.sum(n_node), but JIT-able.
  sum_n_node = graphs.nodes.shape[0]
  # To aggregate nodes from each graph to global features,
  # we first construct tensors that map the node to the corresponding graph.
  # Example: if you have `n_node=[1,2]`, we construct the tensor [0, 1, 1].
  n_graph = graphs.n_node.shape[0]
  node_graph_indices = jnp.repeat(
      jnp.arange(n_graph),
      graphs.n_node,
      axis=0,
      total_repeat_length=sum_n_node)
  # We use the aggregation function to pool the nodes per graph.
  pooled = self.pooling_fn(graphs.nodes, node_graph_indices,
                           n_graph)  # type: ignore[call-arg]
  return graphs._replace(globals=pooled)
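# A sketch of how `pooling_fn` might be configured (an assumption, not part of
# the method above): any segment reducer with the signature
# (data, segment_ids, num_segments) works, e.g. sum or mean pooling.
pooling_fn = jraph.segment_sum  # Alternatively jraph.segment_mean.
pooled = pooling_fn(jnp.ones((3, 4)), jnp.array([0, 1, 1]), 2)  # Shape [2, 4].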
def _prepare_features(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Prepares feature keys into flat node, edge and global features."""
  # Collect edge features.
  edge_features_list = [graph.edges["bond_one_hots"]]
  if (self._config.add_relative_displacement or
      self._config.add_relative_distance):
    (relative_displacement,
     relative_distance) = _compute_relative_displacement_and_distance(
         graph, self._config.relative_displacement_normalization,
         use_target=False)
    if self._config.add_relative_displacement:
      edge_features_list.append(relative_displacement)
    if self._config.add_relative_distance:
      edge_features_list.append(relative_distance)

  mask_at_edges = _broadcast_global_to_edges(
      graph.globals["positions_nan_mask"], graph)
  edge_features_list.append(mask_at_edges[:, None].astype(jnp.float32))
  edge_features = jnp.concatenate(edge_features_list, axis=-1)

  # Collect node features.
  node_features_list = [graph.nodes["atom_one_hots"]]
  if self._config.add_absolute_positions:
    node_features_list.append(
        graph.nodes["positions"] / self._config.position_normalization)

  mask_at_nodes = _broadcast_global_to_nodes(
      graph.globals["positions_nan_mask"], graph)
  node_features_list.append(mask_at_nodes[:, None].astype(jnp.float32))
  node_features = jnp.concatenate(node_features_list, axis=-1)

  global_features = jnp.zeros((len(graph.n_node), self._config.latent_size))
  chex.assert_tree_shape_prefix(global_features, (len(graph.n_node),))
  return graph._replace(
      nodes=node_features, edges=edge_features, globals=global_features)
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
  """Implements the GCN from Kipf et al https://arxiv.org/pdf/1609.02907.pdf.

    A' = D^{-0.5} A D^{-0.5}
    Z = f(X, A') = A' relu(A' X W_0) W_1

  Args:
    graph: GraphsTuple the network processes.

  Returns:
    processed nodes.
  """
  gn = jraph.GraphConvolution(
      update_node_fn=hk.Linear(5, with_bias=False),
      add_self_edges=True)
  graph = gn(graph)
  graph = graph._replace(nodes=jax.nn.relu(graph.nodes))
  gn = jraph.GraphConvolution(update_node_fn=hk.Linear(2, with_bias=False))
  graph = gn(graph)
  return graph.nodes
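# A minimal usage sketch (the toy graph and RNG key are assumptions): the GCN
# above is a Haiku function, so it is transformed before being initialised and
# applied.
net = hk.without_apply_rng(hk.transform(network_definition))
toy = jraph.GraphsTuple(
    nodes=jnp.ones((4, 3)), edges=None,
    senders=jnp.array([0, 1, 2]), receivers=jnp.array([1, 2, 3]),
    n_node=jnp.array([4]), n_edge=jnp.array([3]), globals=None)
params = net.init(jax.random.PRNGKey(0), toy)
out = net.apply(params, toy)  # Shape [4, 2].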
def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Apply this layer on the input graph."""
  messages = self._compute_messages(graph)
  updated_nodes = self._update_nodes(graph, messages)
  return graph._replace(nodes=updated_nodes)
def set_system_state(static_graph: jraph.GraphsTuple,
                     position: np.ndarray,
                     momentum: np.ndarray) -> jraph.GraphsTuple:
  """Sets the non-static parameters of the graph (momentum, position)."""
  nodes = static_graph.nodes.copy(position=position, momentum=momentum)
  return static_graph._replace(nodes=nodes)
def get_static_graph(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Returns the graph with the static parts of a system only."""
  nodes = dict(graph.nodes)
  del nodes["position"], nodes["momentum"]
  return graph._replace(nodes=frozendict(nodes))
def _processor(
    self,
    graph: jraph.GraphsTuple,
    is_training: bool,
) -> jraph.GraphsTuple:
  """Builds the processor."""
  output_sizes = [self._config.mlp_hidden_size] * self._config.mlp_layers
  output_sizes += [self._config.latent_size]
  build_mlp = functools.partial(
      _build_mlp,
      output_sizes=output_sizes,
      use_layer_norm=self._config.use_layer_norm,
  )

  shared_weights = self._config.shared_message_passing_weights
  node_reducer = _REDUCER_NAMES[self._config.node_reducer]
  global_reducer = _REDUCER_NAMES[self._config.global_reducer]

  def dropout_if_training(fn, dropout_rate: float):
    def wrapped(*args):
      out = fn(*args)
      if is_training:
        mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                          jnp.ones([out.shape[0], 1]))
        out = out * mask
      return out
    return wrapped

  num_mps = self._config.num_message_passing_steps
  for step in range(num_mps):
    if step == 0 or not shared_weights:
      suffix = "shared" if shared_weights else step

      update_edge_fn = dropout_if_training(
          build_mlp(f"edge_processor_{suffix}"),
          dropout_rate=self._config.dropedge_rate)

      update_node_fn = dropout_if_training(
          build_mlp(f"node_processor_{suffix}"),
          dropout_rate=self._config.dropnode_rate)

      if self._config.ignore_globals:
        gnn = jraph.InteractionNetwork(
            update_edge_fn=update_edge_fn,
            update_node_fn=update_node_fn,
            aggregate_edges_for_nodes_fn=node_reducer)
      else:
        gnn = jraph.GraphNetwork(
            update_edge_fn=update_edge_fn,
            update_node_fn=update_node_fn,
            update_global_fn=build_mlp(f"global_processor_{suffix}"),
            aggregate_edges_for_nodes_fn=node_reducer,
            aggregate_nodes_for_globals_fn=global_reducer,
            aggregate_edges_for_globals_fn=global_reducer,
        )

    mode = self._config.processor_mode

    if mode == "mlp":
      graph = gnn(graph)

    elif mode == "resnet":
      new_graph = gnn(graph)
      graph = graph._replace(
          nodes=graph.nodes + new_graph.nodes,
          edges=graph.edges + new_graph.edges,
          globals=graph.globals + new_graph.globals,
      )
    else:
      raise ValueError(f"Unknown processor_mode `{mode}`")

    if self._config.mask_padding_graph_at_every_step:
      graph = _mask_out_padding_graph(graph)

  return graph
def add_reverse_edges(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Add edges in the reverse direction, copy edge features."""
  senders = np.concatenate([graph.senders, graph.receivers], axis=0)
  receivers = np.concatenate([graph.receivers, graph.senders], axis=0)
  edges = np.concatenate([graph.edges, graph.edges], axis=0)
  return graph._replace(senders=senders, receivers=receivers, edges=edges)
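# A tiny usage sketch with an assumed toy graph: each directed edge gains a
# reverse counterpart with identical features. Note that `n_edge` is left
# unchanged by the helper above.
toy = jraph.GraphsTuple(
    nodes=np.ones((3, 2)), edges=np.ones((2, 2)),
    senders=np.array([0, 1]), receivers=np.array([1, 2]),
    n_node=np.array([3]), n_edge=np.array([2]), globals=None)
bidirectional = add_reverse_edges(toy)
assert bidirectional.senders.shape[0] == 4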
def add_graphs_tuples(graphs: jraph.GraphsTuple,
                      other_graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Adds the nodes, edges and global features from other_graphs to graphs."""
  return graphs._replace(
      nodes=graphs.nodes + other_graphs.nodes,
      edges=graphs.edges + other_graphs.edges,
      globals=graphs.globals + other_graphs.globals)
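# A minimal usage sketch with assumed toy inputs: element-wise addition of two
# GraphsTuples with identical structure, as typically used for residual (skip)
# connections around a message passing block.
g = jraph.GraphsTuple(
    nodes=jnp.ones((2, 3)), edges=jnp.ones((1, 3)), globals=jnp.ones((1, 3)),
    senders=jnp.array([0]), receivers=jnp.array([1]),
    n_node=jnp.array([2]), n_edge=jnp.array([1]))
doubled = add_graphs_tuples(g, g)  # All features are now 2s.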