def batch_graphs_by_device(graphs: List[jraph.GraphsTuple],
                           num_devices: int) -> List[jraph.GraphsTuple]:
    """Merges `graphs` into `num_devices` equally sized, consecutive groups.

    Each group of consecutive graphs is combined into a single GraphsTuple with
    `jraph.batch`, yielding one merged graph per device so the batch can be
    distributed with pmap.

    Args:
      graphs: graphs to merge; len(graphs) must be a multiple of num_devices.
      num_devices: number of local devices, i.e. number of groups to produce.

    Returns:
      A list of `num_devices` merged graphs. With k = len(graphs) // num_devices,
      element d merges graphs[d * k:(d + 1) * k].

    Raises:
      AssertionError: if len(graphs) is not divisible by num_devices (checked
        with `assert`, so the check is skipped under `python -O`).
    """
    total = len(graphs)
    assert total % num_devices == 0, (
        'Batch size {} is not divisible by {} devices.'.format(
            total, num_devices))
    per_device = total // num_devices
    offsets = [device * per_device for device in range(num_devices)]
    return [jraph.batch(graphs[start:start + per_device]) for start in offsets]
# Example #2
# 0
def preprocess(batch, model_type, num_devices=1):
    """Prepares `batch` in place for the given model type and returns it.

    * 'text': any 'graphs' entry is removed.
    * 'bow2text': the batch is returned untouched.
    * any other value (graph2text): the graphs in batch['graphs'] are merged
      and padded (grouped per device when num_devices != 1), and a
      'max_graph_size' entry is added next to them.

    Args:
      batch: mutable dict-like batch; modified in place.
      model_type: 'text', 'bow2text', or anything else for graph2text.
      num_devices: number of devices the graphs are split across for pmap.

    Returns:
      The same `batch` object that was passed in.
    """
    if model_type == 'bow2text':
        # Bag-of-words batches are already in the form the model expects.
        return batch

    if model_type == 'text':
        # Text-only models never read graphs, so drop them when present.
        batch.pop('graphs', None)
        return batch

    # graph2text
    if num_devices != 1:
        # Group the graphs into one merged graph per device, then pad the
        # groups to the largest graph size in the batch and concatenate them
        # so pmap can hand one slice to each device.
        per_device = gn.batch_graphs_by_device(batch['graphs'], num_devices)
        padded = gn.pad_graphs_by_device(per_device)
    else:
        padded = gn.pad_graphs(jraph.batch(batch['graphs']))
    batch.update({
        'graphs': padded,
        'max_graph_size': gn.pad_size(padded.n_node.max()),
    })
    return batch
# Example #3
# 0
 def _get_batch(n_nodes_list, batch_size):
     """Builds a batch of fully connected graphs with labels and weights.

     One graph is built per entry of `n_nodes_list`. Node and edge features are
     all ones, with per-item shapes `hps.input_node_shape` and
     `hps.input_edge_shape`. Graphs with 4 or 6 nodes get all-ones labels and
     every other graph gets all-zeros labels; weights are always all ones.

     Args:
       n_nodes_list: number of nodes for each graph in the batch.
       batch_size: unused; the batch size is len(n_nodes_list).

     Returns:
       A dict with 'inputs' (all graphs merged with `jraph.batch`), 'targets'
       (stacked per-graph labels, each of shape `hps.output_shape`) and
       'weights' (stacked all-ones arrays of the same shape).
     """
     # The batch size is implied by n_nodes_list (presumably callers pass two
     # entries), so the explicit argument is ignored.
     del batch_size
     graphs_list = []
     labels_list = []
     weights_list = []
     for n_nodes in n_nodes_list:
         # One edge per ordered node pair, self-edges included; assumes
         # get_fully_connected_graph adds self-edges by default - TODO confirm.
         n_edges = n_nodes**2
         graph = jraph.get_fully_connected_graph(
             n_nodes, 1, np.ones((n_nodes, *hps.input_node_shape)))
         graph = graph._replace(edges=np.ones((n_edges,
                                               *hps.input_edge_shape)))
         labels = np.ones(
             hps.output_shape) * (1 if n_nodes in [4, 6] else 0)
         # Pass the shape as a single argument, exactly like the labels above.
         # Unpacking it (np.ones(*shape)) would feed the second dimension into
         # np.ones' `dtype` parameter and fail for non-1-D shapes.
         weights = np.ones(hps.output_shape)
         graphs_list.append(graph)
         labels_list.append(labels)
         weights_list.append(weights)
     return {
         'inputs': jraph.batch(graphs_list),
         'targets': np.stack(labels_list),
         'weights': np.stack(weights_list),
     }
# Example #4
# 0
# File: basic.py  Project: wayne9qiu/jraph
def run():
    """Runs the basic jraph example end to end.

    Builds single, nested, implicitly batched, padded and explicitly batched
    GraphsTuples and logs each of them. It then runs an identity GraphNetwork
    over them, both eagerly and under jax.jit. The explicitly batched runs rely
    on jax.mask; if that is unavailable or fails, MASK_BROKEN_MSG is logged
    instead.
    """

    # Creating graph tuples.

    # Creates a GraphsTuple from scratch containing a single graph.
    # The graph has 3 nodes and 2 edges.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    single_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes=np.ones((3, 4)),
                                     edges=np.ones((2, 5)),
                                     globals=np.ones((1, 6)),
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Single graph %r", single_graph)

    # Creates a GraphsTuple from scratch containing a single graph with nested
    # feature vectors.
    # The graph has 3 nodes and 2 edges.
    # The feature vector can be arbitrary nested types of dict, list and tuple,
    # or any other type you registered with jax.tree_util.register_pytree_node.
    nested_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes={"a": np.ones((3, 4))},
                                     edges={"b": np.ones((2, 5))},
                                     globals={"c": np.ones((1, 6))},
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Nested graph %r", nested_graph)

    # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
    # batch dimension.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    implicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([3, 1]),
                                                 n_edge=np.asarray([2, 1]),
                                                 nodes=np.ones((4, 4)),
                                                 edges=np.ones((3, 5)),
                                                 globals=np.ones((2, 6)),
                                                 senders=np.array([0, 1, 3]),
                                                 receivers=np.array([2, 2, 3]))
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates a GraphsTuple from two existing GraphsTuples using an implicit
    # batch dimension.
    # The GraphsTuple will contain three graphs.
    implicitly_batched_graph = jraph.batch(
        [single_graph, implicitly_batched_graph])
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
    # batch dimension.
    graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
    logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

    # Creates a padded GraphsTuple from an existing GraphsTuple.
    # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
    # Three graphs are added for the padding:
    # first a dummy graph which contains the padding nodes and edges, and then
    # two empty graphs without nodes or edges to pad out the graph count.
    padded_graph = jraph.pad_with_graphs(single_graph,
                                         n_node=10,
                                         n_edge=5,
                                         n_graph=4)
    logging.info("Padded graph %r", padded_graph)

    # Creates a GraphsTuple from an existing padded GraphsTuple.
    # The previously added padding is removed.
    single_graph = jraph.unpad_with_graphs(padded_graph)
    logging.info("Unpadded graph %r", single_graph)

    # Creates a GraphsTuple containing 2 graphs using an explicit batch
    # dimension.
    # An explicit batch dimension requires more memory, but can simplify
    # the definition of functions operating on the graph.
    # Explicitly batched graphs require the GraphNetwork to be transformed
    # by jax.mask followed by jax.vmap.
    # Using an explicit batch requires padding all feature vectors to
    # the maximum size of nodes and edges.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge; its unused edge slot is marked
    # with -1 in senders/receivers.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    explicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([[3], [1]]),
                                                 n_edge=np.asarray([[2], [1]]),
                                                 nodes=np.ones((2, 3, 4)),
                                                 edges=np.ones((2, 2, 5)),
                                                 globals=np.ones((2, 1, 6)),
                                                 senders=np.array([[0, 1],
                                                                   [0, -1]]),
                                                 receivers=np.array([[2, 2],
                                                                     [0, -1]]))
    logging.info("Explicitly batched graph %r", explicitly_batched_graph)

    # Running graph propagation steps.
    # First define the update functions for the edges, nodes and globals.
    # In this example we use the identity everywhere.
    # For graph neural networks, each update function is typically a neural
    # network.
    def update_edge_fn(edge_features, sender_node_features,
                       receiver_node_features, globals_):
        """Returns the update edge features."""
        del sender_node_features
        del receiver_node_features
        del globals_
        return edge_features

    def update_node_fn(node_features, aggregated_sender_edge_features,
                       aggregated_receiver_edge_features, globals_):
        """Returns the update node features."""
        del aggregated_sender_edge_features
        del aggregated_receiver_edge_features
        del globals_
        return node_features

    def update_globals_fn(aggregated_node_features, aggregated_edge_features,
                          globals_):
        """Returns the update global features."""
        del aggregated_node_features
        del aggregated_edge_features
        return globals_

    # Optionally define custom aggregation functions.
    # In this example we use the defaults (so no need to define them explicitly).
    aggregate_edges_for_nodes_fn = jax.ops.segment_sum
    aggregate_nodes_for_globals_fn = jax.ops.segment_sum
    aggregate_edges_for_globals_fn = jax.ops.segment_sum

    # Optionally define attention logit function and attention reduce function.
    # This can be used for graph attention.
    # The attention function calculates attention weights, and the apply
    # attention function calculates the new edge feature given the weights.
    # We don't use graph attention here, and just pass the defaults.
    attention_logit_fn = None
    attention_reduce_fn = None

    # Creates a new GraphNetwork in its most general form.
    # Most of the arguments have defaults and can be omitted if a feature
    # is not used.
    # There are also predefined GraphNetworks available (see models.py)
    network = jraph.GraphNetwork(
        update_edge_fn=update_edge_fn,
        update_node_fn=update_node_fn,
        update_global_fn=update_globals_fn,
        attention_logit_fn=attention_logit_fn,
        aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
        aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
        aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
        attention_reduce_fn=attention_reduce_fn)

    # Runs graph propagation on (implicitly batched) graphs.
    updated_graph = network(single_graph)
    logging.info("Updated graph from single graph %r", updated_graph)

    updated_graph = network(nested_graph)
    # Log the propagation result, not the input graph.
    logging.info("Updated graph from nested graph %r", updated_graph)

    updated_graph = network(implicitly_batched_graph)
    logging.info("Updated graph from implicitly batched graph %r",
                 updated_graph)

    updated_graph = network(padded_graph)
    logging.info("Updated graph from padded graph %r", updated_graph)

    # Runs graph propagation on an explicitly batched graph.
    # WARNING: This code relies on an undocumented JAX feature (jax.mask) which
    # might stop working at any time!
    graph_shape = jraph.GraphsTuple(
        n_node="(g)",
        n_edge="(g)",
        nodes="(n, {})".format(explicitly_batched_graph.nodes.shape[-1]),
        edges="(e, {})".format(explicitly_batched_graph.edges.shape[-1]),
        globals="(g, {})".format(explicitly_batched_graph.globals.shape[-1]),
        senders="(e)",
        receivers="(e)")
    batch_size = explicitly_batched_graph.globals.shape[0]
    logical_env = {
        "g": jnp.ones(batch_size, dtype=jnp.int32),
        "n": jnp.sum(explicitly_batched_graph.n_node, axis=-1),
        "e": jnp.sum(explicitly_batched_graph.n_edge, axis=-1)
    }
    # Stays None when jax.mask/jax.vmap cannot build the masked function, so
    # the JIT section below can skip it explicitly instead of hitting a
    # NameError.
    propagation_fn = None
    try:
        propagation_fn = jax.vmap(
            jax.mask(network, in_shapes=[graph_shape], out_shape=graph_shape))
        updated_graph = propagation_fn([explicitly_batched_graph], logical_env)
        logging.info("Updated graph from explicitly batched graph %r",
                     updated_graph)
    except Exception:  # pylint: disable=broad-except
        logging.warning(MASK_BROKEN_MSG)

    # JIT-compile graph propagation.
    # Use padded graphs to avoid re-compilation at every step!
    jitted_network = jax.jit(network)
    updated_graph = jitted_network(padded_graph)
    logging.info("(JIT) updated graph from padded graph %r", updated_graph)

    # Or use an explicit batch dimension (requires jax.mask to have worked).
    if propagation_fn is None:
        logging.warning(MASK_BROKEN_MSG)
    else:
        try:
            jitted_propagation_fn = jax.jit(propagation_fn)
            updated_graph = jitted_propagation_fn([explicitly_batched_graph],
                                                  logical_env)
            logging.info("(JIT) Updated graph from explicitly batched graph %r",
                         updated_graph)
        except Exception:  # pylint: disable=broad-except
            logging.warning(MASK_BROKEN_MSG)

    logging.info("basic.py complete!")
# Example #5
# 0
def _sample(eval_dataset, tokenizer, devices, batch_size=1):
    """Generates samples from a saved checkpoint and pickles them to disk.

    Loads `checkpoint.pkl` from FLAGS.checkpoint_dir and builds one sampler per
    entry of `devices`. Every sampler is kept busy on its own worker thread
    until at least FLAGS.num_samples samples have been produced. Every sample
    is logged, and the collected records are written to `samples.pkl` in
    FLAGS.checkpoint_dir.

    Args:
      eval_dataset: iterable of batch dicts. 'obs' is always read. 'graphs' is
        read for the graph2text and bow2text model types. For graph2text,
        'target', 'should_reset' and 'mask' are read as well.
      tokenizer: tokenizer forwarded to the utils sampling helpers.
      devices: devices to build samplers on; one worker thread per device.
      batch_size: batch size passed to utils.construct_prompts.
    """
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'checkpoint.pkl')
    logging.info('Loading checkpoint from %s', checkpoint_path)
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # point FLAGS.checkpoint_dir at trusted checkpoints.
    with open(checkpoint_path, 'rb') as f:
        state = pickle.load(f)

    if FLAGS.model_type == 'graph2text':
        # Merge each example's list of graphs into a single batched graph.
        eval_dataset = map(
            lambda x: dict(  # pylint: disable=g-long-lambda
                obs=x['obs'],
                target=x['target'],
                should_reset=x['should_reset'],
                mask=x['mask'],
                graphs=jraph.batch(x['graphs']),
            ),
            eval_dataset)
    eval_dataset = utils.take_unique_graphs(eval_dataset, FLAGS.model_type)

    samplers = [
        utils.build_sampler(tokenizer, device=device) for device in devices
    ]

    step = state['step']
    params = state['params']
    sample_logger = []
    # Graph-conditioned models also receive (and log) the batch's graphs.
    uses_graphs = FLAGS.model_type in ['graph2text', 'bow2text']

    def submit(executor, futures, sampler):
        """Draws the next batch and schedules sampling for it on `sampler`."""
        batch = next(eval_dataset)
        prompts = utils.construct_prompts(batch['obs'],
                                          batch_size,
                                          FLAGS.sample_length,
                                          tokenizer,
                                          prompt_title=FLAGS.prompt_title)
        graphs = batch['graphs'] if uses_graphs else None
        future = executor.submit(utils.generate_samples,
                                 params,
                                 tokenizer,
                                 sampler,
                                 model_type=FLAGS.model_type,
                                 prompts=prompts,
                                 graphs=graphs)
        if uses_graphs:
            futures[future] = (sampler, graphs, batch['obs'])
        else:
            futures[future] = (sampler, batch['obs'])

    def record(samples, tokens, future_items):
        """Logs every sample of a finished batch and appends it to sample_logger."""
        if not uses_graphs:
            _, text = future_items
            for s, tk, txt in zip(samples, tokens, text):
                logging.info('[step %d]', step)
                logging.info('sample=\n%s', s)
                sample_logger.append({
                    'step': step,
                    'sample': s,
                    'sample_tokens': tk,
                    'ground_truth_text': txt,
                })
            return

        _, graphs, text = future_items
        if FLAGS.model_type == 'graph2text':
            # Split the merged graph back into one graph per sample.
            graphs = jraph.unbatch(graphs)
        for s, g, tk, txt in zip(samples, graphs, tokens, text):
            logging.info('[step %d]', step)
            logging.info('graph=\n%r', g)
            logging.info('sample=\n%s', s)
            if FLAGS.model_type == 'bow2text':
                sample_logger.append({
                    'step': step,
                    'bow': g,
                    'sample': s,
                    'sample_tokens': tk,
                    'ground_truth_text': txt,
                })
            else:
                sample_logger.append({
                    'step': step,
                    'sample': s,
                    'sample_tokens': tk,
                    'ground_truth_text': txt,
                })

    with concurrent.futures.ThreadPoolExecutor(
            max_workers=len(samplers)) as executor:
        futures = dict()
        for sampler in samplers:
            submit(executor, futures, sampler)

        n_samples = 0
        while n_samples < FLAGS.num_samples:
            # Block until at least one sampler finishes instead of spinning on
            # future.done() in a tight loop.
            concurrent.futures.wait(
                list(futures), return_when=concurrent.futures.FIRST_COMPLETED)
            for future, future_items in list(futures.items()):
                if not future.done():
                    continue
                samples, tokens = future.result()
                record(samples, tokens, future_items)
                n_samples += len(samples)
                logging.info('Finished generating %d samples', n_samples)

                del futures[future]

                if n_samples < FLAGS.num_samples:
                    # Keep this sampler busy with the next batch.
                    submit(executor, futures, future_items[0])

    logging.info('Finished')
    path = os.path.join(FLAGS.checkpoint_dir, 'samples.pkl')
    with open(path, 'wb') as f:
        pickle.dump(dict(samples=sample_logger), f)
    logging.info('Samples saved to %s', path)
# Example #6
# 0
# File: basic.py  Project: deepmind/jraph
def run():
    """Runs the basic jraph example end to end.

    Builds single, nested, implicitly batched, padded and explicitly batched
    GraphsTuples and logs each of them. It then runs an identity GraphNetwork
    over the implicitly batched and padded graphs, both eagerly and under
    jax.jit.
    """

    # Creating graph tuples.

    # Creates a GraphsTuple from scratch containing a single graph.
    # The graph has 3 nodes and 2 edges.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    single_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes=np.ones((3, 4)),
                                     edges=np.ones((2, 5)),
                                     globals=np.ones((1, 6)),
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Single graph %r", single_graph)

    # Creates a GraphsTuple from scratch containing a single graph with nested
    # feature vectors.
    # The graph has 3 nodes and 2 edges.
    # The feature vector can be arbitrary nested types of dict, list and tuple,
    # or any other type you registered with jax.tree_util.register_pytree_node.
    nested_graph = jraph.GraphsTuple(n_node=np.asarray([3]),
                                     n_edge=np.asarray([2]),
                                     nodes={"a": np.ones((3, 4))},
                                     edges={"b": np.ones((2, 5))},
                                     globals={"c": np.ones((1, 6))},
                                     senders=np.array([0, 1]),
                                     receivers=np.array([2, 2]))
    logging.info("Nested graph %r", nested_graph)

    # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
    # batch dimension.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    implicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([3, 1]),
                                                 n_edge=np.asarray([2, 1]),
                                                 nodes=np.ones((4, 4)),
                                                 edges=np.ones((3, 5)),
                                                 globals=np.ones((2, 6)),
                                                 senders=np.array([0, 1, 3]),
                                                 receivers=np.array([2, 2, 3]))
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Batching graphs can be challenging. There are in general two approaches:
    # 1. Implicit batching: Independent graphs are combined into the same
    #    GraphsTuple first, and the padding is added to the combined graph.
    # 2. Explicit batching: Pad all graphs to a maximum size, stack them together
    #    using an explicit batch dimension followed by jax.vmap.
    # Both approaches are shown below.

    # Creates a GraphsTuple from two existing GraphsTuples using an implicit
    # batch dimension.
    # The GraphsTuple will contain three graphs.
    implicitly_batched_graph = jraph.batch(
        [single_graph, implicitly_batched_graph])
    logging.info("Implicitly batched graph %r", implicitly_batched_graph)

    # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
    # batch dimension.
    graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
    logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

    # Creates a padded GraphsTuple from an existing GraphsTuple.
    # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
    # Three graphs are added for the padding:
    # first a dummy graph which contains the padding nodes and edges, and then
    # two empty graphs without nodes or edges to pad out the graph count.
    padded_graph = jraph.pad_with_graphs(single_graph,
                                         n_node=10,
                                         n_edge=5,
                                         n_graph=4)
    logging.info("Padded graph %r", padded_graph)

    # Creates a GraphsTuple from an existing padded GraphsTuple.
    # The previously added padding is removed.
    single_graph = jraph.unpad_with_graphs(padded_graph)
    logging.info("Unpadded graph %r", single_graph)

    # Creates a GraphsTuple containing 2 graphs using an explicit batch
    # dimension.
    # An explicit batch dimension requires more memory, but can simplify
    # the definition of functions operating on the graph.
    # Explicitly batched graphs require the GraphNetwork to be transformed
    # by jax.vmap.
    # Using an explicit batch requires padding all feature vectors to
    # the maximum size of nodes and edges.
    # The first graph has 3 nodes and 2 edges.
    # The second graph has 1 node and 1 edge; its unused edge slot is marked
    # with -1 in senders/receivers.
    # Each node has a 4-dimensional feature vector.
    # Each edge has a 5-dimensional feature vector.
    # The graph itself has a 6-dimensional feature vector.
    explicitly_batched_graph = jraph.GraphsTuple(n_node=np.asarray([[3], [1]]),
                                                 n_edge=np.asarray([[2], [1]]),
                                                 nodes=np.ones((2, 3, 4)),
                                                 edges=np.ones((2, 2, 5)),
                                                 globals=np.ones((2, 1, 6)),
                                                 senders=np.array([[0, 1],
                                                                   [0, -1]]),
                                                 receivers=np.array([[2, 2],
                                                                     [0, -1]]))
    logging.info("Explicitly batched graph %r", explicitly_batched_graph)

    # Running graph propagation steps.
    # First define the update functions for the edges, nodes and globals.
    # In this example we use the identity everywhere.
    # For graph neural networks, each update function is typically a neural
    # network.
    def update_edge_fn(edge_features, sender_node_features,
                       receiver_node_features, globals_):
        """Returns the update edge features."""
        del sender_node_features
        del receiver_node_features
        del globals_
        return edge_features

    def update_node_fn(node_features, aggregated_sender_edge_features,
                       aggregated_receiver_edge_features, globals_):
        """Returns the update node features."""
        del aggregated_sender_edge_features
        del aggregated_receiver_edge_features
        del globals_
        return node_features

    def update_globals_fn(aggregated_node_features, aggregated_edge_features,
                          globals_):
        """Returns the update global features."""
        del aggregated_node_features
        del aggregated_edge_features
        return globals_

    # Optionally define custom aggregation functions.
    # In this example we use the defaults (so no need to define them explicitly).
    aggregate_edges_for_nodes_fn = jraph.segment_sum
    aggregate_nodes_for_globals_fn = jraph.segment_sum
    aggregate_edges_for_globals_fn = jraph.segment_sum

    # Optionally define attention logit function and attention reduce function.
    # This can be used for graph attention.
    # The attention function calculates attention weights, and the apply
    # attention function calculates the new edge feature given the weights.
    # We don't use graph attention here, and just pass the defaults.
    attention_logit_fn = None
    attention_reduce_fn = None

    # Creates a new GraphNetwork in its most general form.
    # Most of the arguments have defaults and can be omitted if a feature
    # is not used.
    # There are also predefined GraphNetworks available (see models.py)
    network = jraph.GraphNetwork(
        update_edge_fn=update_edge_fn,
        update_node_fn=update_node_fn,
        update_global_fn=update_globals_fn,
        attention_logit_fn=attention_logit_fn,
        aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
        aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
        aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
        attention_reduce_fn=attention_reduce_fn)

    # Runs graph propagation on (implicitly batched) graphs.
    updated_graph = network(single_graph)
    logging.info("Updated graph from single graph %r", updated_graph)

    updated_graph = network(nested_graph)
    # Log the propagation result, not the input graph.
    logging.info("Updated graph from nested graph %r", updated_graph)

    updated_graph = network(implicitly_batched_graph)
    logging.info("Updated graph from implicitly batched graph %r",
                 updated_graph)

    updated_graph = network(padded_graph)
    logging.info("Updated graph from padded graph %r", updated_graph)

    # JIT-compile graph propagation.
    # Use padded graphs to avoid re-compilation at every step!
    jitted_network = jax.jit(network)
    updated_graph = jitted_network(padded_graph)
    logging.info("(JIT) updated graph from padded graph %r", updated_graph)
    logging.info("basic.py complete!")