Example #1
def pad_graph_to_nearest_power_of_two(
    graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Pads a batched `GraphsTuple` to the nearest power of two.

  For example, if a `GraphsTuple` has 7 nodes, 5 edges and 3 graphs, this method
  would pad the `GraphsTuple` nodes and edges:
    7 nodes --> 8 nodes (2^3)
    5 edges --> 8 edges (2^3)

  And since padding is accomplished using `jraph.pad_with_graphs`, an extra
  graph and node are added:
    8 nodes --> 9 nodes
    3 graphs --> 4 graphs

  Args:
    graphs_tuple: a batched `GraphsTuple` (can be batch size 1).

  Returns:
    A graphs_tuple batched to the nearest power of two.
  """
  # Add 1 since we need at least one padding node for pad_with_graphs.
  pad_nodes_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_node)) + 1
  pad_edges_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_edge))
  # Add 1 since we need at least one padding graph for pad_with_graphs.
  # We do not pad to nearest power of two because the batch size is fixed.
  pad_graphs_to = graphs_tuple.n_node.shape[0] + 1
  return jraph.pad_with_graphs(graphs_tuple, pad_nodes_to, pad_edges_to,
                               pad_graphs_to)
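The helper `_nearest_bigger_power_of_two` is called above but not shown. A minimal sketch consistent with the docstring's example (7 nodes --> 8 nodes); the body here is an assumption, not necessarily the original module's code:

def _nearest_bigger_power_of_two(x: int) -> int:
  # Smallest power of two that is >= x (and at least 2); assumed body.
  y = 2
  while y < x:
    y *= 2
  return y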
Example #2
def get_voting_problem(min_n_voters: int, max_n_voters: int) -> Problem:
    """Creates set of one-hot vectors representing a randomly generated election.

  Args:
    min_n_voters: minimum number of voters in the election.
    max_n_voters: maximum number of voters in the election.

  Returns:
    A `Problem` containing the padded graph of votes and a one-hot vector
    encoding the winner.
  """
    n_candidates = 20
    n_voters = random.randint(min_n_voters, max_n_voters)
    votes = np.random.randint(0, n_candidates, size=(n_voters, ))
    one_hot_votes = np.eye(n_candidates)[votes]
    winner = np.argmax(np.sum(one_hot_votes, axis=0))
    one_hot_winner = np.eye(n_candidates)[winner]

    graph = jraph.GraphsTuple(
        n_node=np.asarray([n_voters]),
        n_edge=np.asarray([0]),
        nodes=one_hot_votes,
        edges=None,
        globals=np.zeros((1, n_candidates)),
        # There are no edges in our graph.
        senders=np.array([], dtype=np.int32),
        receivers=np.array([], dtype=np.int32))

    # In order to jit compile our code, we have to pad the nodes and edges of
    # the GraphsTuple to a static shape.
    graph = jraph.pad_with_graphs(graph, max_n_voters + 1, 0)

    return Problem(graph=graph, labels=one_hot_winner)
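A short usage sketch (assuming the surrounding module's `Problem` namedtuple and imports): whatever `n_voters` is sampled, the padded graph always has `max_n_voters + 1` node rows and two graph entries, so a jitted model only ever sees one static shape.

problem = get_voting_problem(min_n_voters=5, max_n_voters=10)
assert problem.graph.nodes.shape == (11, 20)  # max_n_voters + 1 padded rows
assert problem.graph.n_node.shape == (2,)     # original graph + padding graph
assert problem.labels.shape == (20,)          # one-hot winner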
Example #3
def get_2sat_problem(min_n_literals: int, max_n_literals: int) -> Problem:
  """Creates bipartite-graph representing a randomly generated 2-sat problem.

  Args:
    min_n_literals: minimum number of literals in the 2-sat problem.
    max_n_literals: maximum number of literals in the 2-sat problem.

  Returns:
    A `Problem` containing the bipartite graph, node labels and node mask.
  """
  n_literals = random.randint(min_n_literals, max_n_literals)
  n_literals_true = random.randint(1, n_literals - 1)
  n_constraints = n_literals * (n_literals - 1) // 2

  n_node = n_literals + n_constraints
  # 0 indicates a literal node
  # 1 indicates a constraint node.
  nodes = [0 if i < n_literals else 1 for i in range(n_node)]
  edges = []
  senders = []
  for literal_node1 in range(n_literals):
    for literal_node2 in range(literal_node1 + 1, n_literals):
      senders.append(literal_node1)
      senders.append(literal_node2)
      # 1 indicates that the literal must be true for this constraint.
      # 0 indicates that the literal must be false for this constraint.
      # I.e. with literals a and b, we have the following possible constraints:
      # 0, 0 -> not a or not b
      # 1, 0 -> a or not b
      # 0, 1 -> not a or b
      # 1, 1 -> a or b
      edges.append(1 if literal_node1 < n_literals_true else 0)
      edges.append(1 if literal_node2 < n_literals_true else 0)

  graph = jraph.GraphsTuple(
      n_node=np.asarray([n_node]),
      n_edge=np.asarray([2 * n_constraints]),
      # One-hot encoding for nodes and edges.
      nodes=np.eye(2)[nodes],
      edges=np.eye(2)[edges],
      globals=None,
      senders=np.asarray(senders),
      receivers=np.repeat(np.arange(n_constraints) + n_literals, 2))

  # In order to jit compile our code, we have to pad the nodes and edges of
  # the GraphsTuple to a static shape.
  max_n_constraints = max_n_literals * (max_n_literals - 1) // 2
  max_nodes = max_n_literals + max_n_constraints + 1
  max_edges = 2 * max_n_constraints
  graph = jraph.pad_with_graphs(graph, max_nodes, max_edges)

  # The ground truth solution for the 2-sat problem.
  labels = (np.arange(max_nodes) < n_literals_true).astype(np.int32)
  labels = np.eye(2)[labels]

  # For the loss calculation we create a mask for the nodes, which masks out
  # the constraint nodes and the padding nodes.
  mask = (np.arange(max_nodes) < n_literals).astype(np.int32)
  return Problem(graph=graph, labels=labels, mask=mask)
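A small sanity check (illustrative only; `clause` is a hypothetical helper, not part of the module): decode the edge labels back into clauses per the mapping above and confirm that the ground-truth assignment (literal i is true iff i < n_literals_true) satisfies every constraint.

def clause(e1: int, e2: int, a: bool, b: bool) -> bool:
  # Label 1 -> the literal appears positively, label 0 -> negated.
  return (a if e1 else not a) or (b if e2 else not b)

n_literals, n_literals_true = 5, 2
truth = [i < n_literals_true for i in range(n_literals)]
for i in range(n_literals):
  for j in range(i + 1, n_literals):
    e1 = 1 if i < n_literals_true else 0
    e2 = 1 if j < n_literals_true else 0
    assert clause(e1, e2, truth[i], truth[j])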
Example #4
  def test_graph_model(self):
    """Test forward pass of the GNN model."""
    edge_input_shape = (5,)
    node_input_shape = (5,)
    output_shape = (5,)
    model_str = 'gnn'
    model_hps = models.get_model_hparams(model_str)
    model_hps.update({'output_shape': output_shape,
                      'latent_dim': 10,
                      'hidden_dims': (10,),
                      'batch_size': 5,
                      'normalizer': 'batch_norm'})
    model_cls = models.get_model(model_str)
    rng = jax.random.PRNGKey(0)
    dropout_rng, params_rng = jax.random.split(rng)
    loss = 'sigmoid_binary_cross_entropy'
    metrics = 'binary_classification_metrics'
    model = model_cls(model_hps, {}, loss, metrics)

    num_graphs = 5
    node_per_graph = 3
    edge_per_graph = 9
    inputs = jraph.get_fully_connected_graph(
        n_node_per_graph=node_per_graph,
        n_graph=num_graphs,
        node_features=np.ones((num_graphs * node_per_graph,) +
                              node_input_shape),
    )
    inputs = inputs._replace(
        edges=np.ones((num_graphs * edge_per_graph,) + edge_input_shape))
    padded_inputs = jraph.pad_with_graphs(inputs, 20, 50, 7)
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, padded_inputs)
    params = init_dict['params']
    batch_stats = init_dict['batch_stats']

    # Check that the forward pass works with mutated batch_stats.
    outputs, _ = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats},
        padded_inputs,
        mutable=['batch_stats'],
        rngs={'dropout': dropout_rng},
        train=True)
    self.assertEqual(outputs.shape, (7,) + output_shape)
Example #5
def make_graph_from_static_structure(positions, types, box, edge_threshold):
    """Returns graph representing the static structure of the glass.

  Each particle is represented by a node in the graph. The particle type is
  stored as a node feature.
  Two particles at a distance less than the threshold are connected by an edge.
  The relative distance vector is stored as an edge feature.

  Args:
    positions: particle positions with shape [n_particles, 3].
    types: particle types with shape [n_particles].
    box: dimensions of the cubic box that contains the particles with shape [3].
    edge_threshold: particles at distance less than threshold are connected by
      an edge.
  """
    # Calculate pairwise relative distances between particles: shape [n, n, 3].
    cross_positions = positions[None, :, :] - positions[:, None, :]
    # Enforces periodic boundary conditions.
    box_ = box[None, None, :]
    cross_positions += (cross_positions < -box_ / 2.).astype(np.float32) * box_
    cross_positions -= (cross_positions > box_ / 2.).astype(np.float32) * box_
    # Calculates adjacency matrix in a sparse format (indices), based on the given
    # distances and threshold.
    distances = np.linalg.norm(cross_positions, axis=-1)
    indices = np.where(distances < edge_threshold)
    # Defines graph.
    nodes = types[:, None]
    senders = indices[0]
    receivers = indices[1]
    edges = cross_positions[indices]

    return jraph.pad_with_graphs(jraph.GraphsTuple(
        nodes=nodes.astype(np.float32),
        n_node=np.reshape(nodes.shape[0], [1]),
        edges=edges.astype(np.float32),
        n_edge=np.reshape(edges.shape[0], [1]),
        globals=np.zeros((1, 1), dtype=np.float32),
        receivers=receivers.astype(np.int32),
        senders=senders.astype(np.int32)),
                                 n_node=4097,
                                 n_edge=200000)
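The two wrap lines above implement the minimum-image convention: any displacement component larger than half the box is folded back by one box length, so distances are measured through the nearest periodic image. A small numeric sketch with assumed values (note also that `np.where(distances < edge_threshold)` keeps the zero-distance diagonal, so every node receives a self-edge):

import numpy as np

box = np.array([10.0, 10.0, 10.0])
positions = np.array([[0.5, 0.0, 0.0],
                      [9.5, 0.0, 0.0]])  # 1.0 apart across the boundary
cross = positions[None, :, :] - positions[:, None, :]
box_ = box[None, None, :]
cross += (cross < -box_ / 2.).astype(np.float32) * box_
cross -= (cross > box_ / 2.).astype(np.float32) * box_
print(np.linalg.norm(cross, axis=-1)[0, 1])  # 1.0 rather than 9.0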
Example #6
def get_higgs_problem(min_n_photons: int, max_n_photons: int) -> Problem:
  """Creates fully connected graph containing the detected photons.

  Args:
    min_n_photons: minimum number of photons in the detector.
    max_n_photons: maximum number of photons in the detector.

  Returns:
    A `Problem` containing the padded graph and a one-hot label encoding
    whether a Higgs was present or not.
  """
  assert min_n_photons >= 2, "Number of photons must be at least 2."
  n_photons = random.randint(min_n_photons, max_n_photons)
  photons = np.stack([get_random_background_photon() for _ in range(n_photons)])

  # Add a Higgs with probability 0.5 by replacing two background photons.
  if random.random() > 0.5:
    label = np.eye(2)[0]
    photons[:2] = np.stack(get_random_higgs_photons())
  else:
    label = np.eye(2)[1]

  # The graph is fully connected.
  senders = np.repeat(np.arange(n_photons), n_photons)
  receivers = np.tile(np.arange(n_photons), n_photons)
  graph = jraph.GraphsTuple(
      n_node=np.asarray([n_photons]),
      n_edge=np.asarray([len(senders)]),
      nodes=photons,
      edges=None,
      globals=None,
      senders=senders,
      receivers=receivers)

  # In order to jit compile our code, we have to pad the nodes and edges of
  # the GraphsTuple to a static shape.
  graph = jraph.pad_with_graphs(graph, max_n_photons + 1,
                                max_n_photons * max_n_photons)

  return Problem(graph=graph, labels=label)
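The `repeat`/`tile` pair enumerates all ordered node pairs, which is why the graph is padded to `max_n_photons * max_n_photons` edges. A quick illustration (assumed n = 3):

import numpy as np

n = 3
senders = np.repeat(np.arange(n), n)  # [0 0 0 1 1 1 2 2 2]
receivers = np.tile(np.arange(n), n)  # [0 1 2 0 1 2 0 1 2]
# Together these give all n * n ordered pairs, including self-edges.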
Example #7
def run():
    """Runs basic example."""

    # Creating graph tuples.

    # Creates a GraphsTuple from scratch containing a single graph.
    # The graph has 3 nodes and 2 edges.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    single_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes=np.ones((3, 4)),
                                     edges=np.ones((2, 5)),
                                     globals=np.ones((1, 6)),
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Single graph %r", single_graph)

    # Creates a GraphsTuple from scratch containing a single graph with nested
    # feature vectors.
    # The graph has 3 nodes and 2 edges.
    # The feature vector can be arbitrary nested types of dict, list and tuple,
    # or any other type you registered with jax.tree_util.register_pytree_node.
    nested_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes={"a": np.ones((3, 4))},
                                     edges={"b": np.ones((2, 5))},
                                     globals={"c": np.ones((1, 6))},
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Nested graph %r", nested_graph)

    # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
    # batch dimension.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    implicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([3, 1]),
                                                 n_edge=np.asarray([2, 1]),
                                                 nodes=np.ones((4, 4)),
                                                 edges=np.ones((3, 5)),
                                                 globals=np.ones((2, 6)),
                                                 senders=np.array([0, 1, 3]),
                                                 receivers=np.array([2, 2, 3]))
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates a GraphsTuple from two existing GraphsTuples using an implicit
    # batch dimension.
    # The GraphsTuple will contain three graphs.
    implicitly_batched_graph = jraph.batch(
        [single_graph, implicitly_batched_graph])
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
    # batch dimension.
    graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
    logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

    # Creates a padded GraphsTuple from an existing GraphsTuple.
    # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
    # Three graphs are added for the padding.
    # First a dummy graph which contains the padding nodes and edges, and then
    # two empty graphs without nodes or edges to pad out the number of graphs.
    padded_graph = jraph.pad_with_graphs(single_graph,
                                         n_node=10,
                                         n_edge=5,
                                         n_graph=4)
    logging.info("Padded graph %r", padded_graph)

    # Creates a GraphsTuple from an existing padded GraphsTuple.
    # The previously added padding is removed.
    single_graph = jraph.unpad_with_graphs(padded_graph)
    logging.info("Unpadded graph %r", single_graph)

    # Creates a GraphsTuple containing 2 graphs using an explicit batch
    # dimension.
    # An explicit batch dimension requires more memory, but can simplify
    # the definition of functions operating on the graph.
    # Explicitly batched graphs require the GraphNetwork to be transformed
    # by jax.mask followed by jax.vmap.
    # Using an explicit batch requires padding all feature vectors to
    # the maximum size of nodes and edges.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    explicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([[3], [1]]),
                                                 n_edge=np.asarray([[2], [1]]),
                                                 nodes=np.ones((2, 3, 4)),
                                                 edges=np.ones((2, 2, 5)),
                                                 globals=np.ones((2, 1, 6)),
                                                 senders=np.array([[0, 1],
                                                                   [0, -1]]),
                                                 receivers=np.array([[2, 2],
                                                                     [0, -1]]))
    logging.info("Explicitly batched graph %r", explicitly_batched_graph)

    # Running graph propagation steps.
    # First define the update functions for the edges, nodes and globals.
    # In this example we use the identity everywhere.
    # For graph neural networks, each update function is typically a neural
    # network.
    def update_edge_fn(edge_features, sender_node_features,
                       receiver_node_features, globals_):
        """Returns the update edge features."""
        del sender_node_features
        del receiver_node_features
        del globals_
        return edge_features

    def update_node_fn(node_features, aggregated_sender_edge_features,
                       aggregated_receiver_edge_features, globals_):
        """Returns the update node features."""
        del aggregated_sender_edge_features
        del aggregated_receiver_edge_features
        del globals_
        return node_features

    def update_globals_fn(aggregated_node_features, aggregated_edge_features,
                          globals_):
        del aggregated_node_features
        del aggregated_edge_features
        return globals_

    # Optionally define custom aggregation functions.
    # In this example we use the defaults (so no need to define them explicitly).
    aggregate_edges_for_nodes_fn = jax.ops.segment_sum
    aggregate_nodes_for_globals_fn = jax.ops.segment_sum
    aggregate_edges_for_globals_fn = jax.ops.segment_sum

    # Optionally define attention logit function and attention reduce function.
    # This can be used for graph attention.
    # The attention function calculates attention weights, and the apply
    # attention function calculates the new edge feature given the weights.
    # We don't use graph attention here, and just pass the defaults.
    attention_logit_fn = None
    attention_reduce_fn = None

    # Creates a new GraphNetwork in its most general form.
    # Most of the arguments have defaults and can be omitted if a feature
    # is not used.
    # There are also predefined GraphNetworks available (see models.py)
    network = jraph.GraphNetwork(
        update_edge_fn=update_edge_fn,
        update_node_fn=update_node_fn,
        update_global_fn=update_globals_fn,
        attention_logit_fn=attention_logit_fn,
        aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
        aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
        aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
        attention_reduce_fn=attention_reduce_fn)

    # Runs graph propagation on (implicitly batched) graphs.
    updated_graph = network(single_graph)
    logging.info("Updated graph from single graph %r", updated_graph)

    updated_graph = network(nested_graph)
    logging.info("Updated graph from nested graph %r", nested_graph)

    updated_graph = network(implicitly_batched_graph)
    logging.info("Updated graph from implicitly batched graph %r",
                 updated_graph)

    updated_graph = network(padded_graph)
    logging.info("Updated graph from padded graph %r", updated_graph)

    # Runs graph propagation on an explicitly batched graph.
    # WARNING: This code relies on an undocumented JAX feature (jax.mask) which
    # might stop working at any time!
    graph_shape = jraph.GraphsTuple(
        n_node="(g)",
        n_edge="(g)",
        nodes="(n, {})".format(explicitly_batched_graph.nodes.shape[-1]),
        edges="(e, {})".format(explicitly_batched_graph.edges.shape[-1]),
        globals="(g, {})".format(explicitly_batched_graph.globals.shape[-1]),
        senders="(e)",
        receivers="(e)")
    batch_size = explicitly_batched_graph.globals.shape[0]
    logical_env = {
        "g": jnp.ones(batch_size, dtype=jnp.int32),
        "n": jnp.sum(explicitly_batched_graph.n_node, axis=-1),
        "e": jnp.sum(explicitly_batched_graph.n_edge, axis=-1)
    }
    try:
        propagation_fn = jax.vmap(
            jax.mask(network, in_shapes=[graph_shape], out_shape=graph_shape))
        updated_graph = propagation_fn([explicitly_batched_graph], logical_env)
        logging.info("Updated graph from explicitly batched graph %r",
                     updated_graph)
    except Exception:  # pylint: disable=broad-except
        logging.warning(MASK_BROKEN_MSG)

    # JIT-compile graph propagation.
    # Use padded graphs to avoid re-compilation at every step!
    jitted_network = jax.jit(network)
    updated_graph = jitted_network(padded_graph)
    logging.info("(JIT) updated graph from padded graph %r", updated_graph)

    # Or use an explicit batch dimension.
    try:
        jitted_propagation_fn = jax.jit(propagation_fn)
        updated_graph = jitted_propagation_fn([explicitly_batched_graph],
                                              logical_env)
        logging.info("(JIT) Updated graph from explicitly batched graph %r",
                     updated_graph)
    except Exception:  # pylint: disable=broad-except
        logging.warning(MASK_BROKEN_MSG)

    logging.info("basic.py complete!")
Example #8
def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple],
                      n_node: int, n_edge: int,
                      n_graph: int) -> Generator[jraph.GraphsTuple, None, None]:
  """Dynamically batches trees with `jraph.GraphsTuples` to `graph_batch_size`.

  Elements of the `graphs_tuple_iterator` will be incrementally added to a batch
  until the limits defined by `n_node`, `n_edge` and `n_graph` are reached. This
  means each element yielded by this generator is a `GraphsTuple` padded to
  exactly `n_node` nodes, `n_edge` edges and `n_graph` graphs.

  For situations where you have variable sized data, it's useful to be able to
  have variable sized batches. This is especially the case if you have a loss
  defined on the variable shaped element (for example, nodes in a graph).

  Args:
    graphs_tuple_iterator: An iterator of `jraph.GraphsTuples`.
    n_node: The maximum number of nodes in a batch.
    n_edge: The maximum number of edges in a batch.
    n_graph: The maximum number of graphs in a batch.

  Yields:
    A `jraph.GraphsTuple` batch of graphs.

  Raises:
    ValueError: if the number of graphs is < 2.
    RuntimeError: if the `graphs_tuple_iterator` contains elements which are not
      `jraph.GraphsTuple`s.
    RuntimeError: if a graph is found which is larger than the batch size.
  """
  if n_graph < 2:
    raise ValueError("The number of graphs in a batch size must be greater or "
                     f"equal to `2` for padding with graphs, got {n_graph}.")
  valid_batch_size = (n_node - 1, n_edge, n_graph - 1)
  accumulated_graphs = []
  num_accumulated_nodes = 0
  num_accumulated_edges = 0
  num_accumulated_graphs = 0
  for element in graphs_tuple_iterator:
    element_nodes, element_edges, element_graphs = _get_graph_size(element)
    if _is_over_batch_size(element, valid_batch_size):
      graph_size = element_nodes, element_edges, element_graphs
      graph_size = {k: v for k, v in zip(_NUMBER_FIELDS, graph_size)}
      batch_size = {k: v for k, v in zip(_NUMBER_FIELDS, valid_batch_size)}
      raise RuntimeError("Found graph bigger than batch size. Valid Batch "
                         f"Size: {batch_size}, Graph Size: {graph_size}")

    if not accumulated_graphs:
      # If this is the first element of the batch, set it and continue.
      accumulated_graphs = [element]
      num_accumulated_nodes = element_nodes
      num_accumulated_edges = element_edges
      num_accumulated_graphs = element_graphs
      continue
    else:
      # Otherwise check if there is space for the graph in the batch:
      if ((num_accumulated_graphs + element_graphs > n_graph - 1) or
          (num_accumulated_nodes + element_nodes > n_node - 1) or
          (num_accumulated_edges + element_edges > n_edge)):
        # If there is not, yield the current batch padded to the target sizes
        # and start a new batch with this element.
        batched_graph = _batch_np(accumulated_graphs)
        yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
        accumulated_graphs = [element]
        num_accumulated_nodes = element_nodes
        num_accumulated_edges = element_edges
        num_accumulated_graphs = element_graphs
      else:
        # Otherwise, add the element to the current batch.
        accumulated_graphs.append(element)
        num_accumulated_nodes += element_nodes
        num_accumulated_edges += element_edges
        num_accumulated_graphs += element_graphs

  # We may still have leftover graphs in the accumulated batch.
  if accumulated_graphs:
    batched_graph = _batch_np(accumulated_graphs)
    yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
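The helpers `_NUMBER_FIELDS`, `_get_graph_size`, `_is_over_batch_size` and `_batch_np` are referenced above but not shown. A plausible reconstruction, consistent with how they are used (names are from the code above, bodies are assumptions):

import numpy as np
import jraph

_NUMBER_FIELDS = ("n_node", "n_edge", "n_graph")

def _get_graph_size(graphs_tuple):
  # Total nodes, edges and graphs in a (possibly batched) GraphsTuple.
  n_node = np.sum(graphs_tuple.n_node)
  n_edge = len(graphs_tuple.senders)
  n_graph = len(graphs_tuple.n_node)
  return n_node, n_edge, n_graph

def _is_over_batch_size(graph, graph_batch_size):
  graph_size = _get_graph_size(graph)
  return any([x > y for x, y in zip(graph_size, graph_batch_size)])

def _batch_np(graphs_tuples):
  # jraph.batch also accepts numpy inputs; recent jraph versions expose
  # jraph.batch_np as a numpy-only fast path.
  return jraph.batch(graphs_tuples)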
Example #9
def run():
    """Runs basic example."""

    # Creating graph tuples.

    # Creates a GraphsTuple from scratch containing a single graph.
    # The graph has 3 nodes and 2 edges.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    single_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes=np.ones((3, 4)),
                                     edges=np.ones((2, 5)),
                                     globals=np.ones((1, 6)),
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Single graph %r", single_graph)

    # Creates a GraphsTuple from scratch containing a single graph with nested
    # feature vectors.
    # The graph has 3 nodes and 2 edges.
    # The feature vector can be arbitrary nested types of dict, list and tuple,
    # or any other type you registered with jax.tree_util.register_pytree_node.
    nested_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes={"a": np.ones((3, 4))},
                                     edges={"b": np.ones((2, 5))},
                                     globals={"c": np.ones((1, 6))},
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Nested graph %r", nested_graph)

    # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
    # batch dimension.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    implicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([3, 1]),
                                                 n_edge=np.asarray([2, 1]),
                                                 nodes=np.ones((4, 4)),
                                                 edges=np.ones((3, 5)),
                                                 globals=np.ones((2, 6)),
                                                 senders=np.array([0, 1, 3]),
                                                 receivers=np.array([2, 2, 3]))
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Batching graphs can be challenging. There are in general two approaches:
    # 1. Implicit batching: Independent graphs are combined into the same
    #    GraphsTuple first, and the padding is added to the combined graph.
    # 2. Explicit batching: Pad all graphs to a maximum size, stack them along
    #    an explicit batch dimension, and transform the network with jax.vmap
    #    (see the sketch at the end of this example).
    # Both approaches are shown below.

    # Creates a GraphsTuple from two existing GraphsTuples using an implicit
    # batch dimension.
    # The GraphsTuple will contain three graphs.
    implicitly_batched_graph = jraph.batch(
        [single_graph, implicitly_batched_graph])
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
    # batch dimension.
    graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
    logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

    # Creates a padded GraphsTuple from an existing GraphsTuple.
    # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
    # Three graphs are added for the padding.
    # First a dummy graph which contains the padding nodes and edges, and then
    # two empty graphs without nodes or edges to pad out the number of graphs.
    padded_graph = jraph.pad_with_graphs(single_graph,
                                         n_node=10,
                                         n_edge=5,
                                         n_graph=4)
    logging.info("Padded graph %r", padded_graph)

    # Creates a GraphsTuple from an existing padded GraphsTuple.
    # The previously added padding is removed.
    single_graph = jraph.unpad_with_graphs(padded_graph)
    logging.info("Unpadded graph %r", single_graph)

    # Creates a GraphsTuple containing 2 graphs using an explicit batch
    # dimension.
    # An explicit batch dimension requires more memory, but can simplify
    # the definition of functions operating on the graph.
    # Explicitly batched graphs require the GraphNetwork to be transformed
    # by jax.vmap.
    # Using an explicit batch requires padding all feature vectors to
    # the maximum size of nodes and edges.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    explicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([[3], [1]]),
                                                 n_edge=np.asarray([[2], [1]]),
                                                 nodes=np.ones((2, 3, 4)),
                                                 edges=np.ones((2, 2, 5)),
                                                 globals=np.ones((2, 1, 6)),
                                                 senders=np.array([[0, 1],
                                                                   [0, -1]]),
                                                 receivers=np.array([[2, 2],
                                                                     [0, -1]]))
    logging.info("Explicitly batched graph %r", explicitly_batched_graph)

    # Running graph propagation steps.
    # First define the update functions for the edges, nodes and globals.
    # In this example we use the identity everywhere.
    # For graph neural networks, each update function is typically a neural
    # network.
    def update_edge_fn(edge_features, sender_node_features,
                       receiver_node_features, globals_):
        """Returns the update edge features."""
        del sender_node_features
        del receiver_node_features
        del globals_
        return edge_features

    def update_node_fn(node_features, aggregated_sender_edge_features,
                       aggregated_receiver_edge_features, globals_):
        """Returns the update node features."""
        del aggregated_sender_edge_features
        del aggregated_receiver_edge_features
        del globals_
        return node_features

    def update_globals_fn(aggregated_node_features, aggregated_edge_features,
                          globals_):
        del aggregated_node_features
        del aggregated_edge_features
        return globals_

    # Optionally define custom aggregation functions.
    # In this example we use the defaults (so no need to define them explicitly).
    aggregate_edges_for_nodes_fn = jraph.segment_sum
    aggregate_nodes_for_globals_fn = jraph.segment_sum
    aggregate_edges_for_globals_fn = jraph.segment_sum

    # Optionally define attention logit function and attention reduce function.
    # This can be used for graph attention.
    # The attention function calculates attention weights, and the apply
    # attention function calculates the new edge feature given the weights.
    # We don't use graph attention here, and just pass the defaults.
    attention_logit_fn = None
    attention_reduce_fn = None

    # Creates a new GraphNetwork in its most general form.
    # Most of the arguments have defaults and can be omitted if a feature
    # is not used.
    # There are also predefined GraphNetworks available (see models.py)
    network = jraph.GraphNetwork(
        update_edge_fn=update_edge_fn,
        update_node_fn=update_node_fn,
        update_global_fn=update_globals_fn,
        attention_logit_fn=attention_logit_fn,
        aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
        aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
        aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
        attention_reduce_fn=attention_reduce_fn)

    # Runs graph propagation on (implicitly batched) graphs.
    updated_graph = network(single_graph)
    logging.info("Updated graph from single graph %r", updated_graph)

    updated_graph = network(nested_graph)
    logging.info("Updated graph from nested graph %r", nested_graph)

    updated_graph = network(implicitly_batched_graph)
    logging.info("Updated graph from implicitly batched graph %r",
                 updated_graph)

    updated_graph = network(padded_graph)
    logging.info("Updated graph from padded graph %r", updated_graph)

    # JIT-compile graph propagation.
    # Use padded graphs to avoid re-compilation at every step!
    jitted_network = jax.jit(network)
    updated_graph = jitted_network(padded_graph)
    logging.info("(JIT) updated graph from padded graph %r", updated_graph)
    logging.info("basic.py complete!")