def network_definition(
    graph: jraph.GraphsTuple,
    num_message_passing_steps: int = 5) -> jraph.ArrayTree:
  """Defines a graph neural network.

  Args:
    graph: GraphsTuple the network processes.
    num_message_passing_steps: Number of message passing steps.

  Returns:
    Decoded nodes.
  """
  embedding = jraph.GraphMapFeatures(
      embed_edge_fn=jax.vmap(hk.Linear(output_size=16)),
      embed_node_fn=jax.vmap(hk.Linear(output_size=16)))
  graph = embedding(graph)

  @jax.vmap
  @jraph.concatenated_args
  def update_fn(features):
    net = hk.Sequential([
        hk.Linear(10), jax.nn.relu,
        hk.Linear(10), jax.nn.relu,
        hk.Linear(10), jax.nn.relu])
    return net(features)

  for _ in range(num_message_passing_steps):
    gn = jraph.InteractionNetwork(
        update_edge_fn=update_fn,
        update_node_fn=update_fn,
        include_sent_messages_in_node_update=True)
    graph = gn(graph)

  return hk.Linear(2)(graph.nodes)
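# A minimal usage sketch for the network above, assuming the standard Haiku
# workflow; the toy `GraphsTuple` below is illustrative and not part of the
# original example.
def _example_forward_pass():
  toy_graph = jraph.GraphsTuple(
      nodes=jnp.ones((4, 3)),              # 4 nodes with 3 features each.
      edges=jnp.ones((5, 2)),              # 5 edges with 2 features each.
      senders=jnp.array([0, 1, 2, 3, 0]),
      receivers=jnp.array([1, 2, 3, 0, 2]),
      n_node=jnp.array([4]),
      n_edge=jnp.array([5]),
      globals=None)
  # Haiku modules must be created inside a transform; the network uses no
  # dropout, so `without_apply_rng` is sufficient here.
  net = hk.without_apply_rng(hk.transform(network_definition))
  params = net.init(jax.random.PRNGKey(42), toy_graph)
  decoded_nodes = net.apply(params, toy_graph)  # Shape [4, 2].
  return decoded_nodes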
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
  """`InteractionNetwork` with an LSTM in the edge update."""
  # LSTM that will keep a memory of the inputs to the edge model.
  edge_fn_lstm = hk.LSTM(hidden_size=HIDDEN_SIZE)

  # MLPs used in the edge and the node model. Note that in this instance
  # the output size matches the input size so the same model can be run
  # iteratively multiple times. In a real model, this would usually be achieved
  # by first using an encoder to map the input data into a common
  # `EMBEDDING_SIZE`.
  edge_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])
  node_fn_mlp = hk.nets.MLP([HIDDEN_SIZE, EMBEDDING_SIZE])

  # Initialize the edge features to contain both the input edge embedding
  # and the initial LSTM state. Note for the nodes we only have an embedding,
  # since in this example nodes do not use a `node_fn_lstm`, but for analogy we
  # still wrap them in a `StatefulField`.
  graph = graph._replace(
      edges=StatefulField(
          embedding=graph.edges,
          state=edge_fn_lstm.initial_state(graph.edges.shape[0])),
      nodes=StatefulField(embedding=graph.nodes, state=None),
  )

  def update_edge_fn(edges, sender_nodes, receiver_nodes):
    # We will run an LSTM memory on the inputs first, and then
    # process the output of the LSTM with an MLP.
    edge_inputs = jnp.concatenate([
        edges.embedding,
        sender_nodes.embedding,
        receiver_nodes.embedding], axis=-1)
    lstm_output, updated_state = edge_fn_lstm(edge_inputs, edges.state)
    updated_edges = StatefulField(
        embedding=edge_fn_mlp(lstm_output),
        state=updated_state,
    )
    return updated_edges

  def update_node_fn(nodes, received_edges):
    # Note `received_edges.state` will also contain the aggregated state for
    # all received edges, which we may choose to use in the node update.
    node_inputs = jnp.concatenate(
        [nodes.embedding, received_edges.embedding], axis=-1)
    updated_nodes = StatefulField(
        embedding=node_fn_mlp(node_inputs),
        state=None)
    return updated_nodes

  recurrent_graph_network = jraph.InteractionNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn)

  # Apply the model recurrently for 10 message passing steps.
  # If instead we intended to use the LSTM to process a sequence of features
  # for each node/edge, here we would select the corresponding inputs from the
  # sequence along the sequence axis of the nodes/edges features to build the
  # correct input graph for each step of the iteration.
  num_message_passing_steps = 10
  for _ in range(num_message_passing_steps):
    graph = recurrent_graph_network(graph)

  return graph
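# The snippet above assumes module-level constants `HIDDEN_SIZE` and
# `EMBEDDING_SIZE` plus a `StatefulField` container defined elsewhere. A
# minimal sketch of what that container could look like, inferred from its
# usage above (field names come from the snippet, everything else is an
# assumption):
from typing import NamedTuple, Optional


class StatefulField(NamedTuple):
  """Pairs a feature embedding with an optional recurrent state."""
  embedding: jnp.ndarray
  state: Optional[hk.LSTMState]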
def build_gn(
    output_sizes: Sequence[int],
    activation: Callable[[jnp.ndarray], jnp.ndarray],
    suffix: str,
    use_sent_edges: bool,
    is_training: bool,
    dropedge_rate: float,
    normalization_type: str,
    aggregation_function: str,
):
  """Builds an InteractionNetwork with MLP update functions."""
  node_update_fn = build_update_fn(
      f'node_processor_{suffix}',
      output_sizes,
      activation=activation,
      normalization_type=normalization_type,
      is_training=is_training,
  )
  edge_update_fn = build_update_fn(
      f'edge_processor_{suffix}',
      output_sizes,
      activation=activation,
      normalization_type=normalization_type,
      is_training=is_training,
  )

  def maybe_dropedge(x):
    """Dropout on edge messages."""
    if not is_training:
      return x
    return x * hk.dropout(
        hk.next_rng_key(),
        dropedge_rate,
        jnp.ones([x.shape[0], 1]),
    )

  dropped_edge_update_fn = lambda *args: maybe_dropedge(edge_update_fn(*args))

  return jraph.InteractionNetwork(
      update_edge_fn=dropped_edge_update_fn,
      update_node_fn=node_update_fn,
      aggregate_edges_for_nodes_fn=_REDUCER_NAMES[aggregation_function],
      include_sent_messages_in_node_update=use_sent_edges,
  )
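# `build_update_fn` and `_REDUCER_NAMES` are defined elsewhere in the original
# code. Below is a hedged sketch of plausible definitions, inferred from how
# they are called above; the reducer table contents and the layer-norm
# placement are assumptions, not the original implementation.
_REDUCER_NAMES = {
    'sum': jraph.segment_sum,
    'mean': jraph.segment_mean,
}


def build_update_fn(name, output_sizes, activation, normalization_type,
                    is_training):
  """Sketch: MLP over the concatenated inputs, optionally layer-normalized."""
  del is_training  # Only needed by stateful normalizations, e.g. batch norm.

  def update_fn(*args):
    out = hk.nets.MLP(output_sizes, activation=activation, name=name)(
        jnp.concatenate(args, axis=-1))
    if normalization_type == 'layer_norm':
      out = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                         name=f'{name}_layer_norm')(out)
    return out

  return update_fn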
def network_definition(graph):
  """Defines a graph neural network.

  Args:
    graph: GraphsTuple the network processes.

  Returns:
    Decoded nodes.
  """
  model_fn = functools.partial(
      hk.nets.MLP,
      w_init=hk.initializers.VarianceScaling(1.0),
      b_init=hk.initializers.VarianceScaling(1.0))
  mlp_sizes = (64, 64)
  num_message_passing_steps = 7

  node_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True)
  edge_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True)
  node_decoder = model_fn(output_sizes=mlp_sizes + (1,), activate_final=False)

  node_encoding = node_encoder(graph.nodes)
  edge_encoding = edge_encoder(graph.edges)
  graph = graph._replace(nodes=node_encoding, edges=edge_encoding)

  update_edge_fn = jraph.concatenated_args(
      model_fn(output_sizes=mlp_sizes, activate_final=True))
  update_node_fn = jraph.concatenated_args(
      model_fn(output_sizes=mlp_sizes, activate_final=True))

  gn = jraph.InteractionNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
      include_sent_messages_in_node_update=True)

  for _ in range(num_message_passing_steps):
    graph = graph._replace(
        nodes=jnp.concatenate([graph.nodes, node_encoding], axis=-1),
        edges=jnp.concatenate([graph.edges, edge_encoding], axis=-1))
    graph = gn(graph)

  return jnp.squeeze(node_decoder(graph.nodes), axis=-1)
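# A hedged training-step sketch for the encode-process-decode network above.
# The MSE loss and the per-node `node_targets` are illustrative assumptions,
# not part of the original example.
def _example_loss_and_grads(graph, node_targets, rng):
  net = hk.without_apply_rng(hk.transform(network_definition))
  params = net.init(rng, graph)

  def mse_loss(p, g, targets):
    predictions = net.apply(p, g)  # One decoded value per node.
    return jnp.mean((predictions - targets) ** 2)

  # Loss and gradients w.r.t. `params`, ready for an optimizer update.
  return jax.value_and_grad(mse_loss)(params, graph, node_targets)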
def net_fn(graph: jraph.GraphsTuple):
  unf = jraph.concatenated_args(conway_mlp)
  net = jraph.InteractionNetwork(
      update_edge_fn=lambda e, n_s, n_r: n_s,
      update_node_fn=jax.vmap(unf))
  return net(graph)
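# `conway_mlp` is defined elsewhere in the original example. Because the edge
# update simply forwards the sender's features, each node update sees its own
# state concatenated with the aggregated states of its neighbours. A stand-in
# with a compatible (per-node, unbatched) signature is sketched below; the
# fixed weights are illustrative only and are NOT the original function.
def conway_mlp(features: jnp.ndarray) -> jnp.ndarray:
  # `features` is the concatenation of [own state, aggregated neighbour state].
  w_hidden = jnp.ones((features.shape[-1], 4))
  w_out = jnp.ones((4, 1))
  hidden = jax.nn.relu(features @ w_hidden)
  return jax.nn.sigmoid(hidden @ w_out)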
def _processor(
    self,
    graph: jraph.GraphsTuple,
    is_training: bool,
) -> jraph.GraphsTuple:
  """Builds the processor."""
  output_sizes = [self._config.mlp_hidden_size] * self._config.mlp_layers
  output_sizes += [self._config.latent_size]
  build_mlp = functools.partial(
      _build_mlp,
      output_sizes=output_sizes,
      use_layer_norm=self._config.use_layer_norm,
  )

  shared_weights = self._config.shared_message_passing_weights
  node_reducer = _REDUCER_NAMES[self._config.node_reducer]
  global_reducer = _REDUCER_NAMES[self._config.global_reducer]

  def dropout_if_training(fn, dropout_rate: float):
    def wrapped(*args):
      out = fn(*args)
      if is_training:
        mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                          jnp.ones([out.shape[0], 1]))
        out = out * mask
      return out
    return wrapped

  num_mps = self._config.num_message_passing_steps
  for step in range(num_mps):
    if step == 0 or not shared_weights:
      suffix = "shared" if shared_weights else step

      update_edge_fn = dropout_if_training(
          build_mlp(f"edge_processor_{suffix}"),
          dropout_rate=self._config.dropedge_rate)

      update_node_fn = dropout_if_training(
          build_mlp(f"node_processor_{suffix}"),
          dropout_rate=self._config.dropnode_rate)

      if self._config.ignore_globals:
        gnn = jraph.InteractionNetwork(
            update_edge_fn=update_edge_fn,
            update_node_fn=update_node_fn,
            aggregate_edges_for_nodes_fn=node_reducer)
      else:
        gnn = jraph.GraphNetwork(
            update_edge_fn=update_edge_fn,
            update_node_fn=update_node_fn,
            update_global_fn=build_mlp(f"global_processor_{suffix}"),
            aggregate_edges_for_nodes_fn=node_reducer,
            aggregate_nodes_for_globals_fn=global_reducer,
            aggregate_edges_for_globals_fn=global_reducer,
        )

    mode = self._config.processor_mode

    if mode == "mlp":
      graph = gnn(graph)

    elif mode == "resnet":
      new_graph = gnn(graph)
      graph = graph._replace(
          nodes=graph.nodes + new_graph.nodes,
          edges=graph.edges + new_graph.edges,
          globals=graph.globals + new_graph.globals,
      )

    else:
      raise ValueError(f"Unknown processor_mode `{mode}`")

    if self._config.mask_padding_graph_at_every_step:
      graph = _mask_out_padding_graph(graph)

  return graph
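# `_build_mlp`, `_REDUCER_NAMES` and `_mask_out_padding_graph` are helpers
# defined elsewhere in the original model. A hedged sketch of what `_build_mlp`
# might look like, inferred from the calls above; concatenating the update
# inputs and the layer-norm placement are assumptions, not the original code.
def _build_mlp(name, output_sizes, use_layer_norm=False):
  def mlp(*args):
    # Concatenate all update inputs (e.g. edge, sender, receiver and global
    # features) along the feature axis, as `jraph.concatenated_args` would.
    out = hk.nets.MLP(output_sizes, name=name)(jnp.concatenate(args, axis=-1))
    if use_layer_norm:
      out = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                         name=f'{name}_layer_norm')(out)
    return out
  return mlp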