def __call__(self, graph, train):
  """Runs an encode-process-decode graph network and returns graph outputs.

  Args:
    graph: input `jraph.GraphsTuple`. Its `globals` are replaced with zeros of
      shape [num_graphs, self.num_outputs] before processing.
    train: when True, the shared dropout module is active (non-deterministic).

  Returns:
    The decoded global features of the processed graph (one row per graph).
  """
  # One dropout module is created here and handed to every update MLP.
  dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

  num_graphs = graph.n_node.shape[0]
  graph = graph._replace(globals=jnp.zeros([num_graphs, self.num_outputs]))

  # Embed nodes and edges; globals keep the zero placeholder set above.
  node_embed = _make_embed(self.latent_dim)
  edge_embed = _make_embed(self.latent_dim)
  graph = jraph.GraphMapFeatures(
      embed_node_fn=node_embed, embed_edge_fn=edge_embed)(graph)

  for _ in range(self.num_message_passing_steps):
    # Update functions are built in edge -> node -> global order; if these
    # helpers create Flax submodules, their auto-generated names depend on
    # this creation order.
    edge_update = _make_mlp(self.hidden_dims, dropout=dropout)
    node_update = _make_mlp(self.hidden_dims, dropout=dropout)
    global_update = _make_mlp(self.hidden_dims, dropout=dropout)
    step = jraph.GraphNetwork(
        update_edge_fn=edge_update,
        update_node_fn=node_update,
        update_global_fn=global_update)
    graph = step(graph)

  # Map globals to represent the final result.
  readout = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
  return readout(graph).globals
def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Embeds a graph and runs a single GraphNetwork step for classification.

  The incoming globals are replaced by a single zero feature per graph so the
  global update function has a slot to write graph-level predictions into.
  """
  num_graphs = graph.n_node.shape[0]
  # Add a global parameter for graph classification.
  graph = graph._replace(globals=jnp.zeros([num_graphs, 1]))

  # 128-wide linear embeddings, created in (edge, node, global) order to match
  # jraph.GraphMapFeatures' argument order.
  edge_embedding = hk.Linear(128)
  node_embedding = hk.Linear(128)
  global_embedding = hk.Linear(128)
  embedder = jraph.GraphMapFeatures(
      embed_edge_fn=edge_embedding,
      embed_node_fn=node_embedding,
      embed_global_fn=global_embedding)

  network = jraph.GraphNetwork(
      update_node_fn=node_update_fn,
      update_edge_fn=edge_update_fn,
      update_global_fn=update_global_fn)
  embedded = embedder(graph)
  return network(embedded)
def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Encodes, processes and decodes a batch of graphs.

  Args:
    graphs: input `jraph.GraphsTuple` with node, edge and global features.

  Returns:
    A `jraph.GraphsTuple` whose `globals` hold the decoded graph-level outputs
    (width `self.output_globals_size`).
  """
  # Encoder: linearly project the original features as 'embeddings'.
  encoder = jraph.GraphMapFeatures(
      embed_node_fn=nn.Dense(self.latent_size),
      embed_edge_fn=nn.Dense(self.latent_size),
      embed_global_fn=nn.Dense(self.latent_size))
  current = encoder(graphs)

  # All processor MLPs share the same layer widths.
  mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers

  def new_update_fn():
    """Returns a fresh MLP that consumes its arguments concatenated."""
    return jraph.concatenated_args(
        MLP(mlp_feature_sizes,
            dropout_rate=self.dropout_rate,
            deterministic=self.deterministic))

  # Processor: a newly parameterised GraphNetwork for every round.
  for _ in range(self.message_passing_steps):
    # Built in edge -> node -> global order so auto-generated submodule names
    # (which follow creation order) are stable.
    update_edge_fn = new_update_fn() if self.use_edge_model else None
    update_node_fn = new_update_fn()
    update_global_fn = new_update_fn()
    graph_net = jraph.GraphNetwork(
        update_node_fn=update_node_fn,
        update_edge_fn=update_edge_fn,
        update_global_fn=update_global_fn)

    updated = graph_net(current)
    if self.skip_connections:
      # Residual connection around the whole GraphNetwork step.
      updated = add_graphs_tuples(updated, current)
    current = updated

    if self.layer_norm:
      normed_nodes = nn.LayerNorm()(current.nodes)
      normed_edges = nn.LayerNorm()(current.edges)
      normed_globals = nn.LayerNorm()(current.globals)
      current = current._replace(
          nodes=normed_nodes, edges=normed_edges, globals=normed_globals)

  # Decoder: graph-level predictions are read from the globals.
  decoder = jraph.GraphMapFeatures(
      embed_global_fn=nn.Dense(self.output_globals_size))
  return decoder(current)
def __call__(self, graph):
  """Embeds `graph` and applies one GraphNetwork step for binary classification.

  Args:
    graph: input `jraph.GraphsTuple`; its globals are replaced by one zero
      feature per graph.

  Returns:
    The updated `jraph.GraphsTuple`; the global MLP ends in a width-2 layer.
  """
  num_graphs = graph.n_node.shape[0]
  # Add a global parameter for graph classification.
  graph = graph._replace(globals=jnp.zeros([num_graphs, 1]))

  # Embedding functions are created before the update MLPs, in node -> edge ->
  # global order.
  node_embed = make_embed_fn(self.latent_size)
  edge_embed = make_embed_fn(self.latent_size)
  global_embed = make_embed_fn(self.latent_size)
  embedder = jraph.GraphMapFeatures(
      embed_node_fn=node_embed,
      embed_edge_fn=edge_embed,
      embed_global_fn=global_embed)

  node_mlp = make_mlp(self.mlp_features)
  edge_mlp = make_mlp(self.mlp_features)
  # The global update outputs size 2 for binary classification.
  global_mlp = make_mlp(self.mlp_features + (2,))  # pytype: disable=unsupported-operands
  net = jraph.GraphNetwork(
      update_node_fn=node_mlp,
      update_edge_fn=edge_mlp,
      update_global_fn=global_mlp)

  embedded = embedder(graph)
  return net(embedded)
def hookes_hamiltonian_from_graph_fn(
    graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Computes Hamiltonian of a Hooke's potential system represented in a graph.

  While this function hardcodes the Hamiltonian for a Hooke's potential, a
  learned Hamiltonian Graph Network (https://arxiv.org/abs/1909.12790) could
  be implemented by replacing the hardcoded formulas by learnable MLPs that
  take as inputs all of the concatenated features to the edge_fn, node_fn,
  and global_fn, and outputs a single scalar value in the global_fn.

  Args:
    graph: `GraphsTuple` where the nodes contain:
      - "mass": [num_particles]
      - "position": [num_particles, num_dims]
      - "momentum": [num_particles, num_dims]
      and the edges contain:
      - "spring_constant": [num_interactions]

  Returns:
    `GraphsTuple` with features:
      - edge features: "hookes_potential" [num_interactions]
      - node features: "kinetic_energy" [num_particles]
      - global features: "hamiltonian" [batch_size]
  """

  def update_edge_fn(edges, senders, receivers, globals_):
    del globals_
    displacement = senders["position"] - receivers["position"]
    # Squared distance per edge: reduce only over the spatial (last) axis so
    # each interaction keeps its own value. Squaring directly (instead of
    # norm(...)**2) also avoids the NaN gradient of a norm at zero
    # displacement.
    squared_distance = jnp.sum(displacement**2, axis=-1)
    hookes_potential_per_edge = 0.5 * edges["spring_constant"] * squared_distance
    return frozendict({"hookes_potential": hookes_potential_per_edge})

  def update_node_fn(nodes, sent_edges, received_edges, globals_):
    del sent_edges, received_edges, globals_
    # |p|^2 per particle, reduced over the spatial axis only.
    squared_momentum = jnp.sum(nodes["momentum"]**2, axis=-1)
    kinetic_energy_per_node = squared_momentum / (2 * nodes["mass"])
    return frozendict({"kinetic_energy": kinetic_energy_per_node})

  def update_global_fn(nodes, edges, globals_):
    del globals_
    # At this point we will receive node and edge features aggregated (summed)
    # for all nodes and edges in each graph.
    hamiltonian_per_graph = nodes["kinetic_energy"] + edges["hookes_potential"]
    return frozendict({"hamiltonian": hamiltonian_per_graph})

  gn = jraph.GraphNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
      update_global_fn=update_global_fn)

  return gn(graph)
def test_sharded_same_as_non_sharded(self, n_edge):
  """Edge-sharded graph network must match the plain GraphNetwork output."""
  num_shards = 3
  graphs = _get_graphs_from_n_edge(n_edge)
  sharded_graphs = (
      sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple(
          graphs, num_shards))

  # The same identity update (over concatenated arguments) is used for edges,
  # nodes and globals in both networks.
  identity_fn = jraph.concatenated_args(lambda x: x)
  sharded_net = sharded_graphnet.ShardedEdgesGraphNetwork(
      identity_fn, identity_fn, identity_fn, num_shards=num_shards)
  reference_net = jraph.GraphNetwork(identity_fn, identity_fn, identity_fn)

  sharded_result = jax.pmap(sharded_net, axis_name='i')(sharded_graphs)
  expected = reference_net(graphs)
  actual = (
      sharded_graphnet.broadcasted_sharded_graphs_tuple_to_graphs_tuple(
          sharded_result))

  assert_close = functools.partial(
      np.testing.assert_allclose, atol=1E-5, rtol=1E-5)
  jax.tree_util.tree_map(assert_close, expected, actual)
def __init__(self,
             n_recurrences: int,
             mlp_sizes: Tuple[int, ...],
             mlp_kwargs: Optional[Dict[str, Any]] = None,
             format: partition.NeighborListFormat = partition.Dense,
             name: str = 'GraphNetEncoder'):
  """Builds the encoder and the (lazily constructed) propagation network.

  Args:
    n_recurrences: number of message-passing recurrences (stored for use
      elsewhere in the class).
    mlp_sizes: output sizes of every MLP layer.
    mlp_kwargs: extra keyword arguments forwarded to `hk.nets.MLP`.
    format: neighbor-list format; selects the dense graph-network
      implementation (`partition.Dense`) or jraph's (`partition.Sparse`).
    name: Haiku module name.

  Raises:
    ValueError: if `format` is neither `partition.Dense` nor
      `partition.Sparse`.
  """
  super(GraphNetEncoder, self).__init__(name=name)

  if mlp_kwargs is None:
    mlp_kwargs = {}

  self._n_recurrences = n_recurrences

  # Choose the graph-network implementation for this neighbor-list format.
  if format is partition.Dense:
    map_features_cls, graph_network_cls = GraphMapFeatures, GraphNetwork
  elif format is partition.Sparse:
    map_features_cls, graph_network_cls = (
        jraph.GraphMapFeatures, jraph.GraphNetwork)
  else:
    raise ValueError(
        f'Unsupported neighbor list format {format}; GraphNetEncoder only '
        'supports partition.Dense and partition.Sparse.')

  def embedding_fn(module_name):
    # Encoder MLPs are created eagerly, here in __init__.
    return hk.nets.MLP(output_sizes=mlp_sizes, activate_final=True,
                       name=module_name, **mlp_kwargs)

  def model_fn(module_name):
    # Update functions build their MLP on each call and feed it the
    # concatenation of all incoming features.
    def apply(*args):
      return hk.nets.MLP(output_sizes=mlp_sizes, activate_final=True,
                         name=module_name,
                         **mlp_kwargs)(jnp.concatenate(args, axis=-1))
    return apply

  self._encoder = map_features_cls(
      embedding_fn('EdgeEncoder'),
      embedding_fn('NodeEncoder'),
      embedding_fn('GlobalEncoder'))
  self._propagation_network = lambda: graph_network_cls(
      model_fn('EdgeFunction'),
      model_fn('NodeFunction'),
      model_fn('GlobalFunction'))
def _processor(
    self,
    graph: jraph.GraphsTuple,
    is_training: bool,
) -> jraph.GraphsTuple:
  """Builds the processor: a stack of message-passing steps.

  Args:
    graph: latent graph to process.
    is_training: when True, whole edge/node rows of each update output are
      randomly dropped (rates `dropedge_rate` / `dropnode_rate`).

  Returns:
    The processed graph.

  Raises:
    ValueError: if `processor_mode` is neither "mlp" nor "resnet" (raised on
      the first message-passing step).
  """
  config = self._config
  output_sizes = [config.mlp_hidden_size] * config.mlp_layers
  output_sizes += [config.latent_size]
  build_mlp = functools.partial(
      _build_mlp,
      output_sizes=output_sizes,
      use_layer_norm=config.use_layer_norm,
  )

  shared_weights = config.shared_message_passing_weights
  node_reducer = _REDUCER_NAMES[config.node_reducer]
  global_reducer = _REDUCER_NAMES[config.global_reducer]

  def dropout_if_training(fn, dropout_rate: float):
    """Wraps `fn` so that entire output rows are dropped during training."""

    def wrapped(*args):
      out = fn(*args)
      if is_training:
        # A [rows, 1] mask: one keep/drop decision per row, broadcast across
        # the feature dimension.
        mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                          jnp.ones([out.shape[0], 1]))
        out = out * mask
      return out

    return wrapped

  def residual_add(old, new):
    """Returns `old + new`, or None when both fields are absent."""
    # Avoids `None + None` (TypeError) for graphs without e.g. globals.
    if old is None and new is None:
      return None
    return old + new

  num_mps = config.num_message_passing_steps
  for step in range(num_mps):
    # With shared weights the network is built once and reused every step.
    if step == 0 or not shared_weights:
      suffix = "shared" if shared_weights else step
      update_edge_fn = dropout_if_training(
          build_mlp(f"edge_processor_{suffix}"),
          dropout_rate=config.dropedge_rate)
      update_node_fn = dropout_if_training(
          build_mlp(f"node_processor_{suffix}"),
          dropout_rate=config.dropnode_rate)

      if config.ignore_globals:
        gnn = jraph.InteractionNetwork(
            update_edge_fn=update_edge_fn,
            update_node_fn=update_node_fn,
            aggregate_edges_for_nodes_fn=node_reducer)
      else:
        gnn = jraph.GraphNetwork(
            update_edge_fn=update_edge_fn,
            update_node_fn=update_node_fn,
            update_global_fn=build_mlp(f"global_processor_{suffix}"),
            aggregate_edges_for_nodes_fn=node_reducer,
            aggregate_nodes_for_globals_fn=global_reducer,
            aggregate_edges_for_globals_fn=global_reducer,
        )

    mode = config.processor_mode
    if mode == "mlp":
      graph = gnn(graph)
    elif mode == "resnet":
      new_graph = gnn(graph)
      # NOTE(review): with ignore_globals the InteractionNetwork appears to
      # pass globals through unchanged, so this residual doubles them each
      # step — confirm that is intended.
      graph = graph._replace(
          nodes=residual_add(graph.nodes, new_graph.nodes),
          edges=residual_add(graph.edges, new_graph.edges),
          globals=residual_add(graph.globals, new_graph.globals),
      )
    else:
      raise ValueError(f"Unknown processor_mode `{mode}`")

    if config.mask_padding_graph_at_every_step:
      graph = _mask_out_padding_graph(graph)

  return graph
def run():
  """Runs basic example."""

  # Creating graph tuples.

  # Creates a GraphsTuple from scratch containing a single graph.
  # The graph has 3 nodes and 2 edges.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  single_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]), n_edge=np.asarray([2]),
      nodes=np.ones((3, 4)), edges=np.ones((2, 5)),
      globals=np.ones((1, 6)),
      senders=np.array([0, 1]), receivers=np.array([2, 2]))
  logging.info("Single graph %r", single_graph)

  # Creates a GraphsTuple from scratch containing a single graph with nested
  # feature vectors.
  # The graph has 3 nodes and 2 edges.
  # The feature vector can be arbitrary nested types of dict, list and tuple,
  # or any other type you registered with jax.tree_util.register_pytree_node.
  nested_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]), n_edge=np.asarray([2]),
      nodes={"a": np.ones((3, 4))}, edges={"b": np.ones((2, 5))},
      globals={"c": np.ones((1, 6))},
      senders=np.array([0, 1]), receivers=np.array([2, 2]))
  logging.info("Nested graph %r", nested_graph)

  # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
  # batch dimension.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # Each graph itself has a 6-dimensional feature vector.
  implicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([3, 1]), n_edge=np.asarray([2, 1]),
      nodes=np.ones((4, 4)), edges=np.ones((3, 5)),
      globals=np.ones((2, 6)),
      senders=np.array([0, 1, 3]), receivers=np.array([2, 2, 3]))
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Creates a GraphsTuple from two existing GraphsTuple using an implicit
  # batch dimension.
  # The GraphsTuple will contain three graphs.
  implicitly_batched_graph = jraph.batch(
      [single_graph, implicitly_batched_graph])
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
  # batch dimension.
  graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
  logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

  # Creates a padded GraphsTuple from an existing GraphsTuple.
  # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
  # Three graphs are added for the padding.
  # First a dummy graph which contains the padding nodes and edges and secondly
  # two empty graphs without nodes or edges to pad out the graphs.
  padded_graph = jraph.pad_with_graphs(
      single_graph, n_node=10, n_edge=5, n_graph=4)
  logging.info("Padded graph %r", padded_graph)

  # Creates a GraphsTuple from an existing padded GraphsTuple.
  # The previously added padding is removed.
  single_graph = jraph.unpad_with_graphs(padded_graph)
  logging.info("Unpadded graph %r", single_graph)

  # Creates a GraphsTuple containing 2 graphs using an explicit batch
  # dimension.
  # An explicit batch dimension requires more memory, but can simplify
  # the definition of functions operating on the graph.
  # Explicitly batched graphs require the GraphNetwork to be transformed
  # by jax.mask followed by jax.vmap.
  # Using an explicit batch requires padding all feature vectors to
  # the maximum size of nodes and edges.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # Each graph itself has a 6-dimensional feature vector.
  explicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([[3], [1]]), n_edge=np.asarray([[2], [1]]),
      nodes=np.ones((2, 3, 4)), edges=np.ones((2, 2, 5)),
      globals=np.ones((2, 1, 6)),
      senders=np.array([[0, 1], [0, -1]]),
      receivers=np.array([[2, 2], [0, -1]]))
  logging.info("Explicitly batched graph %r", explicitly_batched_graph)

  # Running graph propagation steps.
  # First define the update functions for the edges, nodes and globals.
  # In this example we use the identity everywhere.
  # For Graph neural networks, each update function is typically a neural
  # network.
  def update_edge_fn(
      edge_features,
      sender_node_features,
      receiver_node_features,
      globals_):
    """Returns the update edge features."""
    del sender_node_features
    del receiver_node_features
    del globals_
    return edge_features

  def update_node_fn(
      node_features,
      aggregated_sender_edge_features,
      aggregated_receiver_edge_features,
      globals_):
    """Returns the update node features."""
    del aggregated_sender_edge_features
    del aggregated_receiver_edge_features
    del globals_
    return node_features

  def update_globals_fn(
      aggregated_node_features,
      aggregated_edge_features,
      globals_):
    """Returns the update global features."""
    del aggregated_node_features
    del aggregated_edge_features
    return globals_

  # Optionally define custom aggregation functions.
  # In this example we use the defaults (so no need to define them explicitly).
  aggregate_edges_for_nodes_fn = jax.ops.segment_sum
  aggregate_nodes_for_globals_fn = jax.ops.segment_sum
  aggregate_edges_for_globals_fn = jax.ops.segment_sum

  # Optionally define attention logit function and attention reduce function.
  # This can be used for graph attention.
  # The attention function calculates attention weights, and the apply
  # attention function calculates the new edge feature given the weights.
  # We don't use graph attention here, and just pass the defaults.
  attention_logit_fn = None
  attention_reduce_fn = None

  # Creates a new GraphNetwork in its most general form.
  # Most of the arguments have defaults and can be omitted if a feature
  # is not used.
  # There are also predefined GraphNetworks available (see models.py)
  network = jraph.GraphNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
      update_global_fn=update_globals_fn,
      attention_logit_fn=attention_logit_fn,
      aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
      aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
      aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
      attention_reduce_fn=attention_reduce_fn)

  # Runs graph propagation on (implicitly batched) graphs.
  updated_graph = network(single_graph)
  logging.info("Updated graph from single graph %r", updated_graph)

  updated_graph = network(nested_graph)
  # Log the propagation result (previously the unchanged input was logged).
  logging.info("Updated graph from nested graph %r", updated_graph)

  updated_graph = network(implicitly_batched_graph)
  logging.info("Updated graph from implicitly batched graph %r", updated_graph)

  updated_graph = network(padded_graph)
  logging.info("Updated graph from padded graph %r", updated_graph)

  # Runs graph propagation on an explicitly batched graph.
  # WARNING: This code relies on an undocumented JAX feature (jax.mask) which
  # might stop working at any time!
  graph_shape = jraph.GraphsTuple(
      n_node="(g)",
      n_edge="(g)",
      nodes="(n, {})".format(explicitly_batched_graph.nodes.shape[-1]),
      edges="(e, {})".format(explicitly_batched_graph.edges.shape[-1]),
      globals="(g, {})".format(explicitly_batched_graph.globals.shape[-1]),
      senders="(e)",
      receivers="(e)")
  batch_size = explicitly_batched_graph.globals.shape[0]
  logical_env = {
      "g": jnp.ones(batch_size, dtype=jnp.int32),
      "n": jnp.sum(explicitly_batched_graph.n_node, axis=-1),
      "e": jnp.sum(explicitly_batched_graph.n_edge, axis=-1)
  }
  # Stays None if jax.mask is unavailable, so the JIT section below can skip
  # explicitly instead of relying on a NameError.
  propagation_fn = None
  try:
    propagation_fn = jax.vmap(
        jax.mask(network, in_shapes=[graph_shape], out_shape=graph_shape))
    updated_graph = propagation_fn([explicitly_batched_graph], logical_env)
    logging.info("Updated graph from explicitly batched graph %r",
                 updated_graph)
  except Exception:  # pylint: disable=broad-except
    logging.warning(MASK_BROKEN_MSG)

  # JIT-compile graph propagation.
  # Use padded graphs to avoid re-compilation at every step!
  jitted_network = jax.jit(network)
  updated_graph = jitted_network(padded_graph)
  logging.info("(JIT) updated graph from padded graph %r", updated_graph)

  # Or use an explicit batch dimension.
  if propagation_fn is None:
    logging.warning(MASK_BROKEN_MSG)
  else:
    try:
      jitted_propagation_fn = jax.jit(propagation_fn)
      updated_graph = jitted_propagation_fn([explicitly_batched_graph],
                                            logical_env)
      logging.info("(JIT) Updated graph from explicitly batched graph %r",
                   updated_graph)
    except Exception:  # pylint: disable=broad-except
      logging.warning(MASK_BROKEN_MSG)

  logging.info("basic.py complete!")
def run():
  """Runs basic example."""

  # Creating graph tuples.

  # Creates a GraphsTuple from scratch containing a single graph.
  # The graph has 3 nodes and 2 edges.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  single_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]), n_edge=np.asarray([2]),
      nodes=np.ones((3, 4)), edges=np.ones((2, 5)),
      globals=np.ones((1, 6)),
      senders=np.array([0, 1]), receivers=np.array([2, 2]))
  logging.info("Single graph %r", single_graph)

  # Creates a GraphsTuple from scratch containing a single graph with nested
  # feature vectors.
  # The graph has 3 nodes and 2 edges.
  # The feature vector can be arbitrary nested types of dict, list and tuple,
  # or any other type you registered with jax.tree_util.register_pytree_node.
  nested_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]), n_edge=np.asarray([2]),
      nodes={"a": np.ones((3, 4))}, edges={"b": np.ones((2, 5))},
      globals={"c": np.ones((1, 6))},
      senders=np.array([0, 1]), receivers=np.array([2, 2]))
  logging.info("Nested graph %r", nested_graph)

  # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
  # batch dimension.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # Each graph itself has a 6-dimensional feature vector.
  implicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([3, 1]), n_edge=np.asarray([2, 1]),
      nodes=np.ones((4, 4)), edges=np.ones((3, 5)),
      globals=np.ones((2, 6)),
      senders=np.array([0, 1, 3]), receivers=np.array([2, 2, 3]))
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Batching graphs can be challenging. There are in general two approaches:
  # 1. Implicit batching: Independent graphs are combined into the same
  #    GraphsTuple first, and the padding is added to the combined graph.
  # 2. Explicit batching: Pad all graphs to a maximum size, stack them together
  #    using an explicit batch dimension followed by jax.vmap.
  # Both layouts are constructed below; graph propagation in this example is
  # only run on implicitly batched (and padded) graphs.

  # Creates a GraphsTuple from two existing GraphsTuple using an implicit
  # batch dimension.
  # The GraphsTuple will contain three graphs.
  implicitly_batched_graph = jraph.batch(
      [single_graph, implicitly_batched_graph])
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
  # batch dimension.
  graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
  logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

  # Creates a padded GraphsTuple from an existing GraphsTuple.
  # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
  # Three graphs are added for the padding.
  # First a dummy graph which contains the padding nodes and edges and secondly
  # two empty graphs without nodes or edges to pad out the graphs.
  padded_graph = jraph.pad_with_graphs(
      single_graph, n_node=10, n_edge=5, n_graph=4)
  logging.info("Padded graph %r", padded_graph)

  # Creates a GraphsTuple from an existing padded GraphsTuple.
  # The previously added padding is removed.
  single_graph = jraph.unpad_with_graphs(padded_graph)
  logging.info("Unpadded graph %r", single_graph)

  # Creates a GraphsTuple containing 2 graphs using an explicit batch
  # dimension.
  # An explicit batch dimension requires more memory, but can simplify
  # the definition of functions operating on the graph.
  # Explicitly batched graphs require the GraphNetwork to be transformed
  # by jax.vmap.
  # Using an explicit batch requires padding all feature vectors to
  # the maximum size of nodes and edges.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # Each graph itself has a 6-dimensional feature vector.
  explicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([[3], [1]]), n_edge=np.asarray([[2], [1]]),
      nodes=np.ones((2, 3, 4)), edges=np.ones((2, 2, 5)),
      globals=np.ones((2, 1, 6)),
      senders=np.array([[0, 1], [0, -1]]),
      receivers=np.array([[2, 2], [0, -1]]))
  logging.info("Explicitly batched graph %r", explicitly_batched_graph)

  # Running graph propagation steps.
  # First define the update functions for the edges, nodes and globals.
  # In this example we use the identity everywhere.
  # For Graph neural networks, each update function is typically a neural
  # network.
  def update_edge_fn(
      edge_features,
      sender_node_features,
      receiver_node_features,
      globals_):
    """Returns the update edge features."""
    del sender_node_features
    del receiver_node_features
    del globals_
    return edge_features

  def update_node_fn(
      node_features,
      aggregated_sender_edge_features,
      aggregated_receiver_edge_features,
      globals_):
    """Returns the update node features."""
    del aggregated_sender_edge_features
    del aggregated_receiver_edge_features
    del globals_
    return node_features

  def update_globals_fn(
      aggregated_node_features,
      aggregated_edge_features,
      globals_):
    """Returns the update global features."""
    del aggregated_node_features
    del aggregated_edge_features
    return globals_

  # Optionally define custom aggregation functions.
  # In this example we use the defaults (so no need to define them explicitly).
  aggregate_edges_for_nodes_fn = jraph.segment_sum
  aggregate_nodes_for_globals_fn = jraph.segment_sum
  aggregate_edges_for_globals_fn = jraph.segment_sum

  # Optionally define attention logit function and attention reduce function.
  # This can be used for graph attention.
  # The attention function calculates attention weights, and the apply
  # attention function calculates the new edge feature given the weights.
  # We don't use graph attention here, and just pass the defaults.
  attention_logit_fn = None
  attention_reduce_fn = None

  # Creates a new GraphNetwork in its most general form.
  # Most of the arguments have defaults and can be omitted if a feature
  # is not used.
  # There are also predefined GraphNetworks available (see models.py)
  network = jraph.GraphNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
      update_global_fn=update_globals_fn,
      attention_logit_fn=attention_logit_fn,
      aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
      aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
      aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
      attention_reduce_fn=attention_reduce_fn)

  # Runs graph propagation on (implicitly batched) graphs.
  updated_graph = network(single_graph)
  logging.info("Updated graph from single graph %r", updated_graph)

  updated_graph = network(nested_graph)
  # Log the propagation result (previously the unchanged input was logged).
  logging.info("Updated graph from nested graph %r", updated_graph)

  updated_graph = network(implicitly_batched_graph)
  logging.info("Updated graph from implicitly batched graph %r", updated_graph)

  updated_graph = network(padded_graph)
  logging.info("Updated graph from padded graph %r", updated_graph)

  # JIT-compile graph propagation.
  # Use padded graphs to avoid re-compilation at every step!
  jitted_network = jax.jit(network)
  updated_graph = jitted_network(padded_graph)
  logging.info("(JIT) updated graph from padded graph %r", updated_graph)
  logging.info("basic.py complete!")