Example #1
    def __call__(self, graph, train):
        dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

        graph = graph._replace(
            globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))

        embedder = jraph.GraphMapFeatures(
            embed_node_fn=_make_embed(self.latent_dim),
            embed_edge_fn=_make_embed(self.latent_dim))
        graph = embedder(graph)

        for _ in range(self.num_message_passing_steps):
            net = jraph.GraphNetwork(update_edge_fn=_make_mlp(self.hidden_dims,
                                                              dropout=dropout),
                                     update_node_fn=_make_mlp(self.hidden_dims,
                                                              dropout=dropout),
                                     update_global_fn=_make_mlp(
                                         self.hidden_dims, dropout=dropout))

            graph = net(graph)

        # Map globals to represent the final result
        decoder = jraph.GraphMapFeatures(
            embed_global_fn=nn.Dense(self.num_outputs))
        graph = decoder(graph)

        return graph.globals
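The helpers `_make_embed` and `_make_mlp` are not shown in this example. A minimal sketch of what they might look like follows (an assumption based on the calls above, not the original definitions; it presumes the surrounding Flax module's `__call__` is decorated with `nn.compact`):

import flax.linen as nn
import jraph


def _make_embed(latent_dim):
    # Embeds raw features into the latent space with a single Dense layer.
    def embed(inputs):
        return nn.Dense(features=latent_dim)(inputs)
    return embed


def _make_mlp(hidden_dims, dropout):
    # An MLP over the concatenated GraphNetwork inputs, with the shared
    # Dropout module applied after every hidden layer.
    @jraph.concatenated_args
    def mlp(inputs):
        x = inputs
        for dim in hidden_dims:
            x = nn.Dense(features=dim)(x)
            x = nn.relu(x)
            x = dropout(x)
        return x
    return mlp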
Example #2
def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    # Add a global parameter for graph classification.
    graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))
    embedder = jraph.GraphMapFeatures(hk.Linear(128), hk.Linear(128),
                                      hk.Linear(128))
    net = jraph.GraphNetwork(update_node_fn=node_update_fn,
                             update_edge_fn=edge_update_fn,
                             update_global_fn=update_global_fn)
    return net(embedder(graph))
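The update functions referenced by `net_fn` are defined elsewhere. A plausible Haiku-based sketch (an assumption, not the original code; the layer sizes are illustrative) is:

import haiku as hk
import jraph


@jraph.concatenated_args
def edge_update_fn(feats):
    """Edge update on the concatenated [edge, sender, receiver, global] features."""
    return hk.nets.MLP([128, 128])(feats)


@jraph.concatenated_args
def node_update_fn(feats):
    """Node update on the concatenated [node, aggregated-edge, global] features."""
    return hk.nets.MLP([128, 128])(feats)


@jraph.concatenated_args
def update_global_fn(feats):
    """Global update; the final layer size would match the number of classes."""
    return hk.nets.MLP([128, 2])(feats)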
Example #3
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        # We will first linearly project the original features as 'embeddings'.
        embedder = jraph.GraphMapFeatures(
            embed_node_fn=nn.Dense(self.latent_size),
            embed_edge_fn=nn.Dense(self.latent_size),
            embed_global_fn=nn.Dense(self.latent_size))
        processed_graphs = embedder(graphs)

        # Now, we will apply a Graph Network once for each message-passing round.
        mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers
        for _ in range(self.message_passing_steps):
            if self.use_edge_model:
                update_edge_fn = jraph.concatenated_args(
                    MLP(mlp_feature_sizes,
                        dropout_rate=self.dropout_rate,
                        deterministic=self.deterministic))
            else:
                update_edge_fn = None

            update_node_fn = jraph.concatenated_args(
                MLP(mlp_feature_sizes,
                    dropout_rate=self.dropout_rate,
                    deterministic=self.deterministic))
            update_global_fn = jraph.concatenated_args(
                MLP(mlp_feature_sizes,
                    dropout_rate=self.dropout_rate,
                    deterministic=self.deterministic))

            graph_net = jraph.GraphNetwork(update_node_fn=update_node_fn,
                                           update_edge_fn=update_edge_fn,
                                           update_global_fn=update_global_fn)

            if self.skip_connections:
                processed_graphs = add_graphs_tuples(
                    graph_net(processed_graphs), processed_graphs)
            else:
                processed_graphs = graph_net(processed_graphs)

            if self.layer_norm:
                processed_graphs = processed_graphs._replace(
                    nodes=nn.LayerNorm()(processed_graphs.nodes),
                    edges=nn.LayerNorm()(processed_graphs.edges),
                    globals=nn.LayerNorm()(processed_graphs.globals),
                )

        # Since our graph-level predictions live in the globals, we decode
        # them to obtain the required output logits.
        decoder = jraph.GraphMapFeatures(
            embed_global_fn=nn.Dense(self.output_globals_size))
        processed_graphs = decoder(processed_graphs)

        return processed_graphs
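The `MLP` module used above is assumed to be a standard Flax MLP with dropout; it is not shown here. The skip connections rely on an `add_graphs_tuples` helper, a minimal sketch of which (an assumption consistent with its use above) is:

def add_graphs_tuples(graphs: jraph.GraphsTuple,
                      other_graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Adds the node, edge and global features of two GraphsTuples."""
    return graphs._replace(
        nodes=graphs.nodes + other_graphs.nodes,
        edges=graphs.edges + other_graphs.edges,
        globals=graphs.globals + other_graphs.globals)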
Example #4
  def __call__(self, graph):
    # Add a global parameter for graph classification.
    graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))

    embedder = jraph.GraphMapFeatures(
        embed_node_fn=make_embed_fn(self.latent_size),
        embed_edge_fn=make_embed_fn(self.latent_size),
        embed_global_fn=make_embed_fn(self.latent_size))
    net = jraph.GraphNetwork(
        update_node_fn=make_mlp(self.mlp_features),
        update_edge_fn=make_mlp(self.mlp_features),
        # The global update outputs size 2 for binary classification.
        update_global_fn=make_mlp(self.mlp_features + (2,)))  # pytype: disable=unsupported-operands
    return net(embedder(graph))
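The helpers `make_embed_fn` and `make_mlp` are not shown. A minimal Flax sketch consistent with the calls above (an assumption; it presumes the surrounding module's `__call__` is decorated with `nn.compact`) is:

import flax.linen as nn
import jraph


def make_embed_fn(latent_size):
    def embed(inputs):
        return nn.Dense(latent_size)(inputs)
    return embed


def make_mlp(features):
    @jraph.concatenated_args
    def update_fn(inputs):
        x = inputs
        for size in features[:-1]:
            x = nn.relu(nn.Dense(size)(x))
        # No activation on the final layer, so the last size can act as logits.
        return nn.Dense(features[-1])(x)
    return update_fn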
Example #5
def hookes_hamiltonian_from_graph_fn(
        graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Computes Hamiltonian of a Hooke's potential system represented in a graph.

  While this function hardcodes the Hamiltonian for a Hooke's potential, a
  learned Hamiltonian Graph Network (https://arxiv.org/abs/1909.12790) could
  be implemented by replacing the hardcoded formulas by learnable MLPs that
  take as inputs all of the concatenated features to the edge_fn, node_fn,
  and global_fn, and outputs a single scalar value in the global_fn.

  Args:
    graph: `GraphsTuple` where the nodes contain:
        - "mass": [num_particles]
        - "position": [num_particles, num_dims]
        - "momentum": [num_particles, num_dims]
        and the edges contain:
        - "spring_constant": [num_interations]

  Returns:
    `GraphsTuple` with features:
        - edge features: "hookes_potential" [num_interactions]
        - node features: "kinetic_energy" [num_particles]
        - global features: "hamiltonian" [batch_size]

  """
    def update_edge_fn(edges, senders, receivers, globals_):
        del globals_
        distance = jnp.linalg.norm(senders["position"] - receivers["position"])
        hookes_potential_per_edge = 0.5 * edges["spring_constant"] * distance**2
        return frozendict({"hookes_potential": hookes_potential_per_edge})

    def update_node_fn(nodes, sent_edges, received_edges, globals_):
        del sent_edges, received_edges, globals_
        momentum_norm = jnp.linalg.norm(nodes["momentum"])
        kinetic_energy_per_node = momentum_norm**2 / (2 * nodes["mass"])
        return frozendict({"kinetic_energy": kinetic_energy_per_node})

    def update_global_fn(nodes, edges, globals_):
        del globals_
        # At this point we will receive node and edge features aggregated (summed)
        # for all nodes and edges in each graph.
        hamiltonian_per_graph = nodes["kinetic_energy"] + edges[
            "hookes_potential"]
        return frozendict({"hamiltonian": hamiltonian_per_graph})

    gn = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
                            update_node_fn=update_node_fn,
                            update_global_fn=update_global_fn)

    return gn(graph)
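A small usage sketch (not part of the original): two particles joined by one spring. The pytree registration for `frozendict` is included so the snippet is self-contained; the original file presumably performs an equivalent registration.

import jax
import jax.numpy as jnp
import jraph
from frozendict import frozendict

# Let JAX traverse frozendicts as pytrees (may already be registered upstream).
jax.tree_util.register_pytree_node(
    frozendict,
    lambda fd: (tuple(fd.values()), tuple(fd.keys())),
    lambda keys, values: frozendict(dict(zip(keys, values))))

two_particle_graph = jraph.GraphsTuple(
    n_node=jnp.asarray([2]),
    n_edge=jnp.asarray([1]),
    nodes=frozendict({
        "mass": jnp.asarray([1.0, 1.0]),
        "position": jnp.asarray([[0.0, 0.0], [1.0, 0.0]]),
        "momentum": jnp.asarray([[0.0, 1.0], [0.0, -1.0]]),
    }),
    edges=frozendict({"spring_constant": jnp.asarray([2.0])}),
    senders=jnp.asarray([0]),
    receivers=jnp.asarray([1]),
    globals=None)

out = hookes_hamiltonian_from_graph_fn(two_particle_graph)
print(out.globals["hamiltonian"])  # total (kinetic + potential) energy per graph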
Example #6
    def test_sharded_same_as_non_sharded(self, n_edge):
        in_tuple = _get_graphs_from_n_edge(n_edge)
        devices = 3
        sharded_tuple = sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple(
            in_tuple, devices)
        update_fn = jraph.concatenated_args(lambda x: x)
        sharded_gn = sharded_graphnet.ShardedEdgesGraphNetwork(
            update_fn, update_fn, update_fn, num_shards=devices)
        gn = jraph.GraphNetwork(update_fn, update_fn, update_fn)
        sharded_out = jax.pmap(sharded_gn, axis_name='i')(sharded_tuple)
        expected_out = gn(in_tuple)
        reduced_out = sharded_graphnet.broadcasted_sharded_graphs_tuple_to_graphs_tuple(
            sharded_out)
        jax.tree_util.tree_map(
            functools.partial(np.testing.assert_allclose, atol=1E-5, rtol=1E-5),
            expected_out, reduced_out)
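The `_get_graphs_from_n_edge` helper comes from the surrounding test module and is not shown. A plausible sketch (an assumption; node counts and feature sizes are illustrative) is:

import numpy as np
import jraph


def _get_graphs_from_n_edge(n_edge):
    """Builds one random graph per entry of `n_edge` and batches them."""
    graphs = []
    for edges_in_graph in n_edge:
        graphs.append(
            jraph.GraphsTuple(
                nodes=np.random.uniform(size=(128, 2)),
                edges=np.random.uniform(size=(edges_in_graph, 2)),
                senders=np.random.choice(128, edges_in_graph),
                receivers=np.random.choice(128, edges_in_graph),
                n_node=np.asarray([128]),
                n_edge=np.asarray([edges_in_graph]),
                globals=np.random.uniform(size=(1, 2))))
    return jraph.batch(graphs)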
Example #7
    def __init__(self,
                 n_recurrences: int,
                 mlp_sizes: Tuple[int, ...],
                 mlp_kwargs: Optional[Dict[str, Any]] = None,
                 format: partition.NeighborListFormat = partition.Dense,
                 name: str = 'GraphNetEncoder'):
        super(GraphNetEncoder, self).__init__(name=name)

        if mlp_kwargs is None:
            mlp_kwargs = {}

        self._n_recurrences = n_recurrences

        embedding_fn = lambda name: hk.nets.MLP(output_sizes=mlp_sizes,
                                                activate_final=True,
                                                name=name,
                                                **mlp_kwargs)

        model_fn = lambda name: lambda *args: hk.nets.MLP(
            output_sizes=mlp_sizes,
            activate_final=True,
            name=name,
            **mlp_kwargs)(jnp.concatenate(args, axis=-1))

        if format is partition.Dense:
            self._encoder = GraphMapFeatures(embedding_fn('EdgeEncoder'),
                                             embedding_fn('NodeEncoder'),
                                             embedding_fn('GlobalEncoder'))
            self._propagation_network = lambda: GraphNetwork(
                model_fn('EdgeFunction'), model_fn('NodeFunction'),
                model_fn('GlobalFunction'))
        elif format is partition.Sparse:
            self._encoder = jraph.GraphMapFeatures(
                embedding_fn('EdgeEncoder'), embedding_fn('NodeEncoder'),
                embedding_fn('GlobalEncoder'))
            self._propagation_network = lambda: jraph.GraphNetwork(
                model_fn('EdgeFunction'), model_fn('NodeFunction'),
                model_fn('GlobalFunction'))
        else:
            raise ValueError(f'Unknown neighbor list format: {format}.')
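The corresponding `__call__` is not shown. A simplified usage pattern consistent with this constructor (an assumption, not the original implementation) would embed the graph once and then apply the propagation network `n_recurrences` times:

    def __call__(self, graph):
        # Embed raw node/edge/global features, then message-pass repeatedly.
        output = self._encoder(graph)
        for _ in range(self._n_recurrences):
            output = self._propagation_network()(output)
        return output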
Example #8
    def _processor(
        self,
        graph: jraph.GraphsTuple,
        is_training: bool,
    ) -> jraph.GraphsTuple:
        """Builds the processor."""
        output_sizes = [self._config.mlp_hidden_size] * self._config.mlp_layers
        output_sizes += [self._config.latent_size]
        build_mlp = functools.partial(
            _build_mlp,
            output_sizes=output_sizes,
            use_layer_norm=self._config.use_layer_norm,
        )

        shared_weights = self._config.shared_message_passing_weights
        node_reducer = _REDUCER_NAMES[self._config.node_reducer]
        global_reducer = _REDUCER_NAMES[self._config.global_reducer]

        def dropout_if_training(fn, dropout_rate: float):
            def wrapped(*args):
                out = fn(*args)
                if is_training:
                    mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                                      jnp.ones([out.shape[0], 1]))
                    out = out * mask
                return out

            return wrapped

        num_mps = self._config.num_message_passing_steps
        for step in range(num_mps):
            if step == 0 or not shared_weights:
                suffix = "shared" if shared_weights else step

                update_edge_fn = dropout_if_training(
                    build_mlp(f"edge_processor_{suffix}"),
                    dropout_rate=self._config.dropedge_rate)

                update_node_fn = dropout_if_training(
                    build_mlp(f"node_processor_{suffix}"),
                    dropout_rate=self._config.dropnode_rate)

                if self._config.ignore_globals:
                    gnn = jraph.InteractionNetwork(
                        update_edge_fn=update_edge_fn,
                        update_node_fn=update_node_fn,
                        aggregate_edges_for_nodes_fn=node_reducer)
                else:
                    gnn = jraph.GraphNetwork(
                        update_edge_fn=update_edge_fn,
                        update_node_fn=update_node_fn,
                        update_global_fn=build_mlp(
                            f"global_processor_{suffix}"),
                        aggregate_edges_for_nodes_fn=node_reducer,
                        aggregate_nodes_for_globals_fn=global_reducer,
                        aggregate_edges_for_globals_fn=global_reducer,
                    )

            mode = self._config.processor_mode

            if mode == "mlp":
                graph = gnn(graph)

            elif mode == "resnet":
                new_graph = gnn(graph)
                graph = graph._replace(
                    nodes=graph.nodes + new_graph.nodes,
                    edges=graph.edges + new_graph.edges,
                    globals=graph.globals + new_graph.globals,
                )
            else:
                raise ValueError(f"Unknown processor_mode `{mode}`")

            if self._config.mask_padding_graph_at_every_step:
                graph = _mask_out_padding_graph(graph)

        return graph
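The `_build_mlp` helper partially applied above is defined elsewhere in that codebase. A minimal Haiku sketch consistent with how it is called (an assumption, not the original helper) is:

import haiku as hk
import jraph


def _build_mlp(name, output_sizes, use_layer_norm=False):
    # An MLP over the concatenated GraphNetwork inputs, optionally followed
    # by LayerNorm, as suggested by the `use_layer_norm` config flag.
    @jraph.concatenated_args
    def mlp(inputs):
        out = hk.nets.MLP(output_sizes, name=name)(inputs)
        if use_layer_norm:
            out = hk.LayerNorm(axis=-1, create_scale=True,
                               create_offset=True)(out)
        return out
    return mlp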
Example #9
def run():
    """Runs basic example."""

    # Creating graph tuples.

    # Creates a GraphsTuple from scratch containing a single graph.
    # The graph has 3 nodes and 2 edges.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    single_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes=np.ones((3, 4)),
                                     edges=np.ones((2, 5)),
                                     globals=np.ones((1, 6)),
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Single graph %r", single_graph)

    # Creates a GraphsTuple from scratch containing a single graph with nested
    # feature vectors.
    # The graph has 3 nodes and 2 edges.
    # Feature vectors can be arbitrarily nested structures of dict, list and
    # tuple, or any other type you registered with
    # jax.tree_util.register_pytree_node.
    nested_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes={"a": np.ones((3, 4))},
                                     edges={"b": np.ones((2, 5))},
                                     globals={"c": np.ones((1, 6))},
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Nested graph %r", nested_graph)

    # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
    # batch dimension.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    implicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([3, 1]),
                                                 n_edge=np.asarray([2, 1]),
                                                 nodes=np.ones((4, 4)),
                                                 edges=np.ones((3, 5)),
                                                 globals=np.ones((2, 6)),
                                                 senders=np.array([0, 1, 3]),
                                                 receivers=np.array([2, 2, 3]))
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates a GraphsTuple from two existing GraphsTuples using an implicit
    # batch dimension.
    # The GraphsTuple will contain three graphs.
    implicitly_batched_graph = jraph.batch(
        [single_graph, implicitly_batched_graph])
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
    # batch dimension.
    graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
    logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

    # Creates a padded GraphsTuple from an existing GraphsTuple.
    # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
    # Three graphs are added for the padding:
    # first a dummy graph which contains the padding nodes and edges, and then
    # two empty graphs without nodes or edges to pad out the number of graphs.
    padded_graph = jraph.pad_with_graphs(single_graph,
                                         n_node=10,
                                         n_edge=5,
                                         n_graph=4)
    logging.info("Padded graph %r", padded_graph)

    # Creates a GraphsTuple from an existing padded GraphsTuple.
    # The previously added padding is removed.
    single_graph = jraph.unpad_with_graphs(padded_graph)
    logging.info("Unpadded graph %r", single_graph)

    # Creates a GraphsTuple containing 2 graphs using an explicit batch
    # dimension.
    # An explicit batch dimension requires more memory, but can simplify
    # the definition of functions operating on the graph.
    # Explicitly batched graphs require the GraphNetwork to be transformed
    # by jax.mask followed by jax.vmap.
    # Using an explicit batch requires padding all feature vectors to
    # the maximum size of nodes and edges.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    explicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([[3], [1]]),
                                                 n_edge=np.asarray([[2], [1]]),
                                                 nodes=np.ones((2, 3, 4)),
                                                 edges=np.ones((2, 2, 5)),
                                                 globals=np.ones((2, 1, 6)),
                                                 senders=np.array([[0, 1],
                                                                   [0, -1]]),
                                                 receivers=np.array([[2, 2],
                                                                     [0, -1]]))
    logging.info("Explicitly batched graph %r", explicitly_batched_graph)

    # Running graph propagation steps.
    # First define the update functions for the edges, nodes and globals.
    # In this example we use the identity everywhere.
    # For graph neural networks, each update function is typically a neural
    # network.
    def update_edge_fn(edge_features, sender_node_features,
                       receiver_node_features, globals_):
        """Returns the update edge features."""
        del sender_node_features
        del receiver_node_features
        del globals_
        return edge_features

    def update_node_fn(node_features, aggregated_sender_edge_features,
                       aggregated_receiver_edge_features, globals_):
        """Returns the update node features."""
        del aggregated_sender_edge_features
        del aggregated_receiver_edge_features
        del globals_
        return node_features

    def update_globals_fn(aggregated_node_features, aggregated_edge_features,
                          globals_):
        del aggregated_node_features
        del aggregated_edge_features
        return globals_

    # Optionally define custom aggregation functions.
    # In this example we use the defaults (so no need to define them explicitly).
    aggregate_edges_for_nodes_fn = jax.ops.segment_sum
    aggregate_nodes_for_globals_fn = jax.ops.segment_sum
    aggregate_edges_for_globals_fn = jax.ops.segment_sum

    # Optionally define an attention logit function and an attention reduce
    # function. These can be used for graph attention.
    # The attention logit function calculates attention weights, and the
    # attention reduce function calculates the new edge features given those
    # weights.
    # We don't use graph attention here, and just pass the defaults.
    attention_logit_fn = None
    attention_reduce_fn = None

    # Creates a new GraphNetwork in its most general form.
    # Most of the arguments have defaults and can be omitted if a feature
    # is not used.
    # There are also predefined GraphNetworks available (see models.py)
    network = jraph.GraphNetwork(
        update_edge_fn=update_edge_fn,
        update_node_fn=update_node_fn,
        update_global_fn=update_globals_fn,
        attention_logit_fn=attention_logit_fn,
        aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
        aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
        aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
        attention_reduce_fn=attention_reduce_fn)

    # Runs graph propagation on (implicitly batched) graphs.
    updated_graph = network(single_graph)
    logging.info("Updated graph from single graph %r", updated_graph)

    updated_graph = network(nested_graph)
    logging.info("Updated graph from nested graph %r", nested_graph)

    updated_graph = network(implicitly_batched_graph)
    logging.info("Updated graph from implicitly batched graph %r",
                 updated_graph)

    updated_graph = network(padded_graph)
    logging.info("Updated graph from padded graph %r", updated_graph)

    # Runs graph propagation on an explicitly batched graph.
    # WARNING: This code relies on an undocumented JAX feature (jax.mask) which
    # might stop working at any time!
    graph_shape = jraph.GraphsTuple(
        n_node="(g)",
        n_edge="(g)",
        nodes="(n, {})".format(explicitly_batched_graph.nodes.shape[-1]),
        edges="(e, {})".format(explicitly_batched_graph.edges.shape[-1]),
        globals="(g, {})".format(explicitly_batched_graph.globals.shape[-1]),
        senders="(e)",
        receivers="(e)")
    batch_size = explicitly_batched_graph.globals.shape[0]
    logical_env = {
        "g": jnp.ones(batch_size, dtype=jnp.int32),
        "n": jnp.sum(explicitly_batched_graph.n_node, axis=-1),
        "e": jnp.sum(explicitly_batched_graph.n_edge, axis=-1)
    }
    try:
        propagation_fn = jax.vmap(
            jax.mask(network, in_shapes=[graph_shape], out_shape=graph_shape))
        updated_graph = propagation_fn([explicitly_batched_graph], logical_env)
        logging.info("Updated graph from explicitly batched graph %r",
                     updated_graph)
    except Exception:  # pylint: disable=broad-except
        logging.warning(MASK_BROKEN_MSG)

    # JIT-compile graph propagation.
    # Use padded graphs to avoid re-compilation at every step!
    jitted_network = jax.jit(network)
    updated_graph = jitted_network(padded_graph)
    logging.info("(JIT) updated graph from padded graph %r", updated_graph)

    # Or use an explicit batch dimension.
    try:
        jitted_propagation_fn = jax.jit(propagation_fn)
        updated_graph = jitted_propagation_fn([explicitly_batched_graph],
                                              logical_env)
        logging.info("(JIT) Updated graph from explicitly batched graph %r",
                     updated_graph)
    except Exception:  # pylint: disable=broad-except
        logging.warning(MASK_BROKEN_MSG)

    logging.info("basic.py complete!")
Example #10
def run():
    """Runs basic example."""

    # Creating graph tuples.

    # Creates a GraphsTuple from scratch containing a single graph.
    # The graph has 3 nodes and 2 edges.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    single_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes=np.ones((3, 4)),
                                     edges=np.ones((2, 5)),
                                     globals=np.ones((1, 6)),
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Single graph %r", single_graph)

    # Creates a GraphsTuple from scratch containing a single graph with nested
    # feature vectors.
    # The graph has 3 nodes and 2 edges.
    # Feature vectors can be arbitrarily nested structures of dict, list and
    # tuple, or any other type you registered with
    # jax.tree_util.register_pytree_node.
    nested_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes={"a": np.ones((3, 4))},
                                     edges={"b": np.ones((2, 5))},
                                     globals={"c": np.ones((1, 6))},
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Nested graph %r", nested_graph)

    # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
    # batch dimension.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    implicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([3, 1]),
                                                 n_edge=np.asarray([2, 1]),
                                                 nodes=np.ones((4, 4)),
                                                 edges=np.ones((3, 5)),
                                                 globals=np.ones((2, 6)),
                                                 senders=np.array([0, 1, 3]),
                                                 receivers=np.array([2, 2, 3]))
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Batching graphs can be challenging. There are in general two approaches:
    # 1. Implicit batching: Independent graphs are combined into the same
    #    GraphsTuple first, and any padding is added to the combined graph.
    # 2. Explicit batching: Pad all graphs to a maximum size, stack them along
    #    an explicit batch dimension, and apply the network with jax.vmap.
    # Both approaches are shown below.

    # Creates a GraphsTuple from two existing GraphsTuples using an implicit
    # batch dimension.
    # The GraphsTuple will contain three graphs.
    implicitly_batched_graph = jraph.batch(
        [single_graph, implicitly_batched_graph])
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
    # batch dimension.
    graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
    logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

    # Creates a padded GraphsTuple from an existing GraphsTuple.
    # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
    # Three graphs are added for the padding:
    # first a dummy graph which contains the padding nodes and edges, and then
    # two empty graphs without nodes or edges to pad out the number of graphs.
    padded_graph = jraph.pad_with_graphs(single_graph,
                                         n_node=10,
                                         n_edge=5,
                                         n_graph=4)
    logging.info("Padded graph %r", padded_graph)

    # Creates a GraphsTuple from an existing padded GraphsTuple.
    # The previously added padding is removed.
    single_graph = jraph.unpad_with_graphs(padded_graph)
    logging.info("Unpadded graph %r", single_graph)

    # Creates a GraphsTuple containing 2 graphs using an explicit batch
    # dimension.
    # An explicit batch dimension requires more memory, but can simplify
    # the definition of functions operating on the graph.
    # Explicitly batched graphs require the GraphNetwork to be transformed
    # by jax.vmap.
    # Using an explicit batch requires padding all feature vectors to
    # the maximum size of nodes and edges.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    explicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([[3], [1]]),
                                                 n_edge=np.asarray([[2], [1]]),
                                                 nodes=np.ones((2, 3, 4)),
                                                 edges=np.ones((2, 2, 5)),
                                                 globals=np.ones((2, 1, 6)),
                                                 senders=np.array([[0, 1],
                                                                   [0, -1]]),
                                                 receivers=np.array([[2, 2],
                                                                     [0, -1]]))
    logging.info("Explicitly batched graph %r", explicitly_batched_graph)

    # Running graph propagation steps.
    # First define the update functions for the edges, nodes and globals.
    # In this example we use the identity everywhere.
    # For graph neural networks, each update function is typically a neural
    # network.
    def update_edge_fn(edge_features, sender_node_features,
                       receiver_node_features, globals_):
        """Returns the update edge features."""
        del sender_node_features
        del receiver_node_features
        del globals_
        return edge_features

    def update_node_fn(node_features, aggregated_sender_edge_features,
                       aggregated_receiver_edge_features, globals_):
        """Returns the update node features."""
        del aggregated_sender_edge_features
        del aggregated_receiver_edge_features
        del globals_
        return node_features

    def update_globals_fn(aggregated_node_features, aggregated_edge_features,
                          globals_):
        del aggregated_node_features
        del aggregated_edge_features
        return globals_

    # Optionally define custom aggregation functions.
    # In this example we use the defaults (so no need to define them explicitly).
    aggregate_edges_for_nodes_fn = jraph.segment_sum
    aggregate_nodes_for_globals_fn = jraph.segment_sum
    aggregate_edges_for_globals_fn = jraph.segment_sum

    # Optionally define an attention logit function and an attention reduce
    # function. These can be used for graph attention.
    # The attention logit function calculates attention weights, and the
    # attention reduce function calculates the new edge features given those
    # weights.
    # We don't use graph attention here, and just pass the defaults.
    attention_logit_fn = None
    attention_reduce_fn = None

    # Creates a new GraphNetwork in its most general form.
    # Most of the arguments have defaults and can be omitted if a feature
    # is not used.
    # There are also predefined GraphNetworks available (see models.py)
    network = jraph.GraphNetwork(
        update_edge_fn=update_edge_fn,
        update_node_fn=update_node_fn,
        update_global_fn=update_globals_fn,
        attention_logit_fn=attention_logit_fn,
        aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
        aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
        aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
        attention_reduce_fn=attention_reduce_fn)

    # Runs graph propagation on (implicitly batched) graphs.
    updated_graph = network(single_graph)
    logging.info("Updated graph from single graph %r", updated_graph)

    updated_graph = network(nested_graph)
    logging.info("Updated graph from nested graph %r", nested_graph)

    updated_graph = network(implicitly_batched_graph)
    logging.info("Updated graph from implicitly batched graph %r",
                 updated_graph)

    updated_graph = network(padded_graph)
    logging.info("Updated graph from padded graph %r", updated_graph)

    # JIT-compile graph propagation.
    # Use padded graphs to avoid re-compilation at every step!
    jitted_network = jax.jit(network)
    updated_graph = jitted_network(padded_graph)
    logging.info("(JIT) updated graph from padded graph %r", updated_graph)
    logging.info("basic.py complete!")