def batch_graphs_by_device(
    graphs: List[jraph.GraphsTuple],
    num_devices: int) -> List[jraph.GraphsTuple]:
  """Merge a flat list of graphs into one batched graph per device.

  The input is cut into `num_devices` contiguous, equally sized slices and
  every slice is merged with `jraph.batch`, so that the result can be
  distributed over devices (e.g. with pmap).

  Args:
    graphs: the graphs to distribute; its length must be a multiple of
      `num_devices`.
    num_devices: the number of local devices.

  Returns:
    A list with one merged graph per device, in input order.

  Raises:
    AssertionError: if `len(graphs)` is not divisible by `num_devices`.
  """
  total = len(graphs)
  assert total % num_devices == 0, (
      'Batch size {} is not divisible by {} devices.'.format(
          total, num_devices))
  per_device = total // num_devices
  # Start index of the slice that belongs to each device.
  offsets = [device * per_device for device in range(num_devices)]
  return [jraph.batch(graphs[lo:lo + per_device]) for lo in offsets]
def preprocess(batch, model_type, num_devices=1):
  """Prepare a batch for the model selected by `model_type`.

  The batch is modified in place and also returned.

  Args:
    batch: a dict-like batch; for graph models it holds the list of graphs
      under the 'graphs' key.
    model_type: 'text', 'bow2text', or anything else (treated as graph2text).
    num_devices: number of devices the graphs are split across.

  Returns:
    The same `batch` object. For 'text' the 'graphs' entry is dropped, for
    'bow2text' nothing changes, otherwise 'graphs' is replaced by the padded
    batched graphs and 'max_graph_size' is added.
  """
  if model_type == 'bow2text':
    # bow2text data is already in a good form.
    return batch

  if model_type == 'text':
    # Text-only models have no use for the graphs.
    batch.pop('graphs', None)
    return batch

  # graph2text
  raw_graphs = batch['graphs']
  if num_devices == 1:
    padded = gn.pad_graphs(jraph.batch(raw_graphs))
  else:
    # Group the graphs into one batch per device, then pad every group to the
    # maximum graph size in the batch and concatenate, so that pmap can hand
    # one slice to each device.
    per_device = gn.batch_graphs_by_device(raw_graphs, num_devices)
    padded = gn.pad_graphs_by_device(per_device)

  largest = gn.pad_size(padded.n_node.max())
  batch['graphs'] = padded
  batch['max_graph_size'] = largest
  return batch
def _get_batch(n_nodes_list, batch_size):
  """Builds a batch of fully connected graphs with all-ones features.

  One graph is created per entry of `n_nodes_list`, so the number of graphs in
  the batch is `len(n_nodes_list)`. Feature and output shapes are read from
  `hps`, which is not defined in this block (it comes from the surrounding
  scope).

  Args:
    n_nodes_list: iterable with the number of nodes of each graph.
    batch_size: unused; kept only for interface compatibility.

  Returns:
    A dict with
      'inputs': the graphs merged by `jraph.batch`,
      'targets': stacked labels, all ones for graphs with 4 or 6 nodes and all
        zeros otherwise,
      'weights': stacked all-ones weights of shape `hps.output_shape`.
  """
  # The batch size is implied by n_nodes_list; the argument is ignored.
  del batch_size
  graphs_list = []
  labels_list = []
  weights_list = []
  for n_nodes in n_nodes_list:
    # One edge-feature row per (sender, receiver) pair; presumably this matches
    # the edge count of the fully connected graph including self edges.
    n_edges = n_nodes**2
    graph = jraph.get_fully_connected_graph(
        n_nodes, 1, np.ones((n_nodes, *hps.input_node_shape)))
    graph = graph._replace(edges=np.ones((n_edges, *hps.input_edge_shape)))
    labels = np.ones(hps.output_shape) * (1 if n_nodes in [4, 6] else 0)
    # Pass the shape as a single tuple. Unpacking it (np.ones(*shape)) only
    # works for 1-element shapes, because a second positional argument is
    # interpreted as the dtype.
    weights = np.ones(hps.output_shape)
    graphs_list.append(graph)
    labels_list.append(labels)
    weights_list.append(weights)
  return {
      'inputs': jraph.batch(graphs_list),
      'targets': np.stack(labels_list),
      'weights': np.stack(weights_list),
  }
def run():
  """Runs the basic jraph example.

  Builds several `jraph.GraphsTuple` instances (single, nested, implicitly
  batched, padded and explicitly batched), runs an identity `GraphNetwork`
  over them, optionally through `jax.mask`/`jax.vmap` and `jax.jit`, and logs
  every intermediate result. Takes no arguments and returns nothing.
  """

  # Creating graph tuples.

  # Creates a GraphsTuple from scratch containing a single graph.
  # The graph has 3 nodes and 2 edges.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  single_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]), n_edge=np.asarray([2]),
      nodes=np.ones((3, 4)), edges=np.ones((2, 5)),
      globals=np.ones((1, 6)),
      senders=np.array([0, 1]), receivers=np.array([2, 2]))
  logging.info("Single graph %r", single_graph)

  # Creates a GraphsTuple from scratch containing a single graph with nested
  # feature vectors.
  # The graph has 3 nodes and 2 edges.
  # The feature vector can be arbitrary nested types of dict, list and tuple,
  # or any other type you registered with jax.tree_util.register_pytree_node.
  nested_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]), n_edge=np.asarray([2]),
      nodes={"a": np.ones((3, 4))}, edges={"b": np.ones((2, 5))},
      globals={"c": np.ones((1, 6))},
      senders=np.array([0, 1]), receivers=np.array([2, 2]))
  logging.info("Nested graph %r", nested_graph)

  # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
  # batch dimension.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  implicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([3, 1]), n_edge=np.asarray([2, 1]),
      nodes=np.ones((4, 4)), edges=np.ones((3, 5)),
      globals=np.ones((2, 6)),
      senders=np.array([0, 1, 3]), receivers=np.array([2, 2, 3]))
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Creates a GraphsTuple from two existing GraphsTuple using an implicit
  # batch dimension.
  # The GraphsTuple will contain three graphs.
  implicitly_batched_graph = jraph.batch(
      [single_graph, implicitly_batched_graph])
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
  # batch dimension.
  graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
  logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

  # Creates a padded GraphsTuple from an existing GraphsTuple.
  # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
  # Three graphs are added for the padding.
  # First a dummy graph which contains the padding nodes and edges and secondly
  # two empty graphs without nodes or edges to pad out the graphs.
  padded_graph = jraph.pad_with_graphs(
      single_graph, n_node=10, n_edge=5, n_graph=4)
  logging.info("Padded graph %r", padded_graph)

  # Creates a GraphsTuple from an existing padded GraphsTuple.
  # The previously added padding is removed.
  single_graph = jraph.unpad_with_graphs(padded_graph)
  logging.info("Unpadded graph %r", single_graph)

  # Creates a GraphsTuple containing 2 graphs using an explicit batch
  # dimension.
  # An explicit batch dimension requires more memory, but can simplify
  # the definition of functions operating on the graph.
  # Explicitly batched graphs require the GraphNetwork to be transformed
  # by jax.mask followed by jax.vmap.
  # Using an explicit batch requires padding all feature vectors to
  # the maximum size of nodes and edges.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  explicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([[3], [1]]), n_edge=np.asarray([[2], [1]]),
      nodes=np.ones((2, 3, 4)), edges=np.ones((2, 2, 5)),
      globals=np.ones((2, 1, 6)),
      senders=np.array([[0, 1], [0, -1]]),
      receivers=np.array([[2, 2], [0, -1]]))
  logging.info("Explicitly batched graph %r", explicitly_batched_graph)

  # Running graph propagation steps.
  # First define the update functions for the edges, nodes and globals.
  # In this example we use the identity everywhere.
  # For Graph neural networks, each update function is typically a neural
  # network.
  def update_edge_fn(
      edge_features,
      sender_node_features,
      receiver_node_features,
      globals_):
    """Returns the update edge features."""
    del sender_node_features
    del receiver_node_features
    del globals_
    return edge_features

  def update_node_fn(
      node_features,
      aggregated_sender_edge_features,
      aggregated_receiver_edge_features,
      globals_):
    """Returns the update node features."""
    del aggregated_sender_edge_features
    del aggregated_receiver_edge_features
    del globals_
    return node_features

  def update_globals_fn(
      aggregated_node_features,
      aggregated_edge_features,
      globals_):
    """Returns the update global features."""
    del aggregated_node_features
    del aggregated_edge_features
    return globals_

  # Optionally define custom aggregation functions.
  # In this example we use the defaults (so no need to define them explicitly).
  aggregate_edges_for_nodes_fn = jax.ops.segment_sum
  aggregate_nodes_for_globals_fn = jax.ops.segment_sum
  aggregate_edges_for_globals_fn = jax.ops.segment_sum

  # Optionally define attention logit function and attention reduce function.
  # This can be used for graph attention.
  # The attention function calculates attention weights, and the apply
  # attention function calculates the new edge feature given the weights.
  # We don't use graph attention here, and just pass the defaults.
  attention_logit_fn = None
  attention_reduce_fn = None

  # Creates a new GraphNetwork in its most general form.
  # Most of the arguments have defaults and can be omitted if a feature
  # is not used.
  # There are also predefined GraphNetworks available (see models.py)
  network = jraph.GraphNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
      update_global_fn=update_globals_fn,
      attention_logit_fn=attention_logit_fn,
      aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
      aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
      aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
      attention_reduce_fn=attention_reduce_fn)

  # Runs graph propagation on (implicitly batched) graphs.
  updated_graph = network(single_graph)
  logging.info("Updated graph from single graph %r", updated_graph)

  updated_graph = network(nested_graph)
  # Log the propagated result, not the input graph.
  logging.info("Updated graph from nested graph %r", updated_graph)

  updated_graph = network(implicitly_batched_graph)
  logging.info("Updated graph from implicitly batched graph %r", updated_graph)

  updated_graph = network(padded_graph)
  logging.info("Updated graph from padded graph %r", updated_graph)

  # Runs graph propagation on an explicitly batched graph.
  # WARNING: This code relies on an undocumented JAX feature (jax.mask) which
  # might stop working at any time!
  graph_shape = jraph.GraphsTuple(
      n_node="(g)",
      n_edge="(g)",
      nodes="(n, {})".format(explicitly_batched_graph.nodes.shape[-1]),
      edges="(e, {})".format(explicitly_batched_graph.edges.shape[-1]),
      globals="(g, {})".format(explicitly_batched_graph.globals.shape[-1]),
      senders="(e)",
      receivers="(e)")
  batch_size = explicitly_batched_graph.globals.shape[0]
  logical_env = {
      "g": jnp.ones(batch_size, dtype=jnp.int32),
      "n": jnp.sum(explicitly_batched_graph.n_node, axis=-1),
      "e": jnp.sum(explicitly_batched_graph.n_edge, axis=-1)
  }
  # Stays None when building the masked function fails, so the JIT section
  # below can detect that case explicitly.
  propagation_fn = None
  try:
    propagation_fn = jax.vmap(
        jax.mask(network, in_shapes=[graph_shape], out_shape=graph_shape))
    updated_graph = propagation_fn([explicitly_batched_graph], logical_env)
    logging.info("Updated graph from explicitly batched graph %r",
                 updated_graph)
  except Exception:  # pylint: disable=broad-except
    logging.warning(MASK_BROKEN_MSG)

  # JIT-compile graph propagation.
  # Use padded graphs to avoid re-compilation at every step!
  jitted_network = jax.jit(network)
  updated_graph = jitted_network(padded_graph)
  logging.info("(JIT) updated graph from padded graph %r", updated_graph)

  # Or use an explicit batch dimension.
  if propagation_fn is None:
    # The masked function could not be built above; nothing to compile.
    logging.warning(MASK_BROKEN_MSG)
  else:
    try:
      jitted_propagation_fn = jax.jit(propagation_fn)
      updated_graph = jitted_propagation_fn([explicitly_batched_graph],
                                            logical_env)
      logging.info("(JIT) Updated graph from explicitly batched graph %r",
                   updated_graph)
    except Exception:  # pylint: disable=broad-except
      logging.warning(MASK_BROKEN_MSG)

  logging.info("basic.py complete!")
def _sample(eval_dataset, tokenizer, devices, batch_size=1):
  """Generates samples from a saved checkpoint and writes them to disk.

  Loads `checkpoint.pkl` from `FLAGS.checkpoint_dir`, builds one sampler per
  device, keeps one generation job in flight per sampler and collects results
  until at least `FLAGS.num_samples` samples were produced. All collected
  samples are pickled to `samples.pkl` in `FLAGS.checkpoint_dir`.

  Args:
    eval_dataset: iterable of batches. Every batch is indexed with 'obs'; for
      the 'graph2text' model type it is also indexed with 'target',
      'should_reset', 'mask' and 'graphs' (a list of graphs that gets merged
      with `jraph.batch`).
    tokenizer: tokenizer forwarded to the sampler, prompt and generation
      utilities.
    devices: devices to build samplers for; one worker thread per device.
    batch_size: batch size forwarded to `utils.construct_prompts`.

  Raises:
    StopIteration: if the dataset runs out of batches before enough samples
      were generated (batches are fetched with a bare `next`).
  """
  checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'checkpoint.pkl')
  logging.info('Loading checkpoint from %s', checkpoint_path)
  # NOTE(review): pickle.load can execute arbitrary code; only point
  # FLAGS.checkpoint_dir at trusted checkpoints.
  with open(checkpoint_path, 'rb') as f:
    state = pickle.load(f)

  if FLAGS.model_type == 'graph2text':
    # process list of graphs into a batch
    eval_dataset = map(
        lambda x: dict(  # pylint: disable=g-long-lambda
            obs=x['obs'],
            target=x['target'],
            should_reset=x['should_reset'],
            mask=x['mask'],
            graphs=jraph.batch(x['graphs']),
        ),
        eval_dataset)
  eval_dataset = utils.take_unique_graphs(eval_dataset, FLAGS.model_type)

  samplers = [
      utils.build_sampler(tokenizer, device=device) for device in devices
  ]

  step = state['step']
  params = state['params']
  sample_logger = []
  # These model types condition generation on batch['graphs'].
  uses_graphs = FLAGS.model_type in ['graph2text', 'bow2text']

  def submit_next_batch(executor, futures, sampler):
    """Fetches the next batch and schedules generation for it on `sampler`."""
    batch = next(eval_dataset)
    prompts = utils.construct_prompts(
        batch['obs'], batch_size, FLAGS.sample_length, tokenizer,
        prompt_title=FLAGS.prompt_title)
    graphs = batch['graphs'] if uses_graphs else None
    future = executor.submit(
        utils.generate_samples, params, tokenizer, sampler,
        model_type=FLAGS.model_type, prompts=prompts, graphs=graphs)
    # Remember what the job was started with so results can be paired up.
    futures[future] = (sampler, graphs, batch['obs'])

  with concurrent.futures.ThreadPoolExecutor(
      max_workers=len(samplers)) as executor:
    futures = dict()
    for sampler in samplers:
      submit_next_batch(executor, futures, sampler)

    n_samples = 0
    while n_samples < FLAGS.num_samples:
      # Block until at least one job has finished instead of busy-polling
      # future.done() in a tight loop.
      concurrent.futures.wait(
          list(futures), return_when=concurrent.futures.FIRST_COMPLETED)

      # Iterate over a snapshot because `futures` is modified in the loop.
      for future, (sampler, graphs, text) in list(futures.items()):
        if not future.done():
          continue

        samples, tokens = future.result()
        if FLAGS.model_type == 'graph2text':
          # Split the batched graph back into one graph per sample.
          graphs = jraph.unbatch(graphs)

        if uses_graphs:
          # Every generated sample is logged and recorded.
          for s, g, tk, txt in zip(samples, graphs, tokens, text):
            logging.info('[step %d]', step)
            logging.info('graph=\n%r', g)
            logging.info('sample=\n%s', s)
            if FLAGS.model_type == 'graph2text':
              sample_logger.append({
                  'step': step,
                  'sample': s,
                  'sample_tokens': tk,
                  'ground_truth_text': txt,
              })
            elif FLAGS.model_type == 'bow2text':
              sample_logger.append({
                  'step': step,
                  'bow': g,
                  'sample': s,
                  'sample_tokens': tk,
                  'ground_truth_text': txt,
              })
        else:
          # Every generated sample is logged and recorded.
          for s, tk, txt in zip(samples, tokens, text):
            logging.info('[step %d]', step)
            logging.info('sample=\n%s', s)
            sample_logger.append({
                'step': step,
                'sample': s,
                'sample_tokens': tk,
                'ground_truth_text': txt,
            })

        n_samples += len(samples)
        logging.info('Finished generating %d samples', n_samples)
        del futures[future]

        # Keep the freed sampler busy while more samples are still needed.
        if n_samples < FLAGS.num_samples:
          submit_next_batch(executor, futures, sampler)

  logging.info('Finished')
  path = os.path.join(FLAGS.checkpoint_dir, 'samples.pkl')
  with open(path, 'wb') as f:
    pickle.dump(dict(samples=sample_logger), f)
  logging.info('Samples saved to %s', path)
def run():
  """Runs the basic jraph example.

  Builds several `jraph.GraphsTuple` instances (single, nested, implicitly
  batched, padded and explicitly batched), runs an identity `GraphNetwork`
  over the implicitly batched ones, both eagerly and through `jax.jit`, and
  logs every intermediate result. Takes no arguments and returns nothing.
  """

  # Creating graph tuples.

  # Creates a GraphsTuple from scratch containing a single graph.
  # The graph has 3 nodes and 2 edges.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  single_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]), n_edge=np.asarray([2]),
      nodes=np.ones((3, 4)), edges=np.ones((2, 5)),
      globals=np.ones((1, 6)),
      senders=np.array([0, 1]), receivers=np.array([2, 2]))
  logging.info("Single graph %r", single_graph)

  # Creates a GraphsTuple from scratch containing a single graph with nested
  # feature vectors.
  # The graph has 3 nodes and 2 edges.
  # The feature vector can be arbitrary nested types of dict, list and tuple,
  # or any other type you registered with jax.tree_util.register_pytree_node.
  nested_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]), n_edge=np.asarray([2]),
      nodes={"a": np.ones((3, 4))}, edges={"b": np.ones((2, 5))},
      globals={"c": np.ones((1, 6))},
      senders=np.array([0, 1]), receivers=np.array([2, 2]))
  logging.info("Nested graph %r", nested_graph)

  # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
  # batch dimension.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  implicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([3, 1]), n_edge=np.asarray([2, 1]),
      nodes=np.ones((4, 4)), edges=np.ones((3, 5)),
      globals=np.ones((2, 6)),
      senders=np.array([0, 1, 3]), receivers=np.array([2, 2, 3]))
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Batching graphs can be challenging. There are in general two approaches:
  # 1. Implicit batching: Independent graphs are combined into the same
  #    GraphsTuple first, and the padding is added to the combined graph.
  # 2. Explicit batching: Pad all graphs to a maximum size, stack them together
  #    using an explicit batch dimension followed by jax.vmap.
  # Both approaches are shown below.

  # Creates a GraphsTuple from two existing GraphsTuple using an implicit
  # batch dimension.
  # The GraphsTuple will contain three graphs.
  implicitly_batched_graph = jraph.batch(
      [single_graph, implicitly_batched_graph])
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
  # batch dimension.
  graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
  logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

  # Creates a padded GraphsTuple from an existing GraphsTuple.
  # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
  # Three graphs are added for the padding.
  # First a dummy graph which contains the padding nodes and edges and secondly
  # two empty graphs without nodes or edges to pad out the graphs.
  padded_graph = jraph.pad_with_graphs(
      single_graph, n_node=10, n_edge=5, n_graph=4)
  logging.info("Padded graph %r", padded_graph)

  # Creates a GraphsTuple from an existing padded GraphsTuple.
  # The previously added padding is removed.
  single_graph = jraph.unpad_with_graphs(padded_graph)
  logging.info("Unpadded graph %r", single_graph)

  # Creates a GraphsTuple containing 2 graphs using an explicit batch
  # dimension.
  # An explicit batch dimension requires more memory, but can simplify
  # the definition of functions operating on the graph.
  # Explicitly batched graphs require the GraphNetwork to be transformed
  # by jax.vmap.
  # Using an explicit batch requires padding all feature vectors to
  # the maximum size of nodes and edges.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  explicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([[3], [1]]), n_edge=np.asarray([[2], [1]]),
      nodes=np.ones((2, 3, 4)), edges=np.ones((2, 2, 5)),
      globals=np.ones((2, 1, 6)),
      senders=np.array([[0, 1], [0, -1]]),
      receivers=np.array([[2, 2], [0, -1]]))
  logging.info("Explicitly batched graph %r", explicitly_batched_graph)

  # Running graph propagation steps.
  # First define the update functions for the edges, nodes and globals.
  # In this example we use the identity everywhere.
  # For Graph neural networks, each update function is typically a neural
  # network.
  def update_edge_fn(
      edge_features,
      sender_node_features,
      receiver_node_features,
      globals_):
    """Returns the update edge features."""
    del sender_node_features
    del receiver_node_features
    del globals_
    return edge_features

  def update_node_fn(
      node_features,
      aggregated_sender_edge_features,
      aggregated_receiver_edge_features,
      globals_):
    """Returns the update node features."""
    del aggregated_sender_edge_features
    del aggregated_receiver_edge_features
    del globals_
    return node_features

  def update_globals_fn(
      aggregated_node_features,
      aggregated_edge_features,
      globals_):
    """Returns the update global features."""
    del aggregated_node_features
    del aggregated_edge_features
    return globals_

  # Optionally define custom aggregation functions.
  # In this example we use the defaults (so no need to define them explicitly).
  aggregate_edges_for_nodes_fn = jraph.segment_sum
  aggregate_nodes_for_globals_fn = jraph.segment_sum
  aggregate_edges_for_globals_fn = jraph.segment_sum

  # Optionally define attention logit function and attention reduce function.
  # This can be used for graph attention.
  # The attention function calculates attention weights, and the apply
  # attention function calculates the new edge feature given the weights.
  # We don't use graph attention here, and just pass the defaults.
  attention_logit_fn = None
  attention_reduce_fn = None

  # Creates a new GraphNetwork in its most general form.
  # Most of the arguments have defaults and can be omitted if a feature
  # is not used.
  # There are also predefined GraphNetworks available (see models.py)
  network = jraph.GraphNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
      update_global_fn=update_globals_fn,
      attention_logit_fn=attention_logit_fn,
      aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
      aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
      aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
      attention_reduce_fn=attention_reduce_fn)

  # Runs graph propagation on (implicitly batched) graphs.
  updated_graph = network(single_graph)
  logging.info("Updated graph from single graph %r", updated_graph)

  updated_graph = network(nested_graph)
  # Log the propagated result, not the input graph.
  logging.info("Updated graph from nested graph %r", updated_graph)

  updated_graph = network(implicitly_batched_graph)
  logging.info("Updated graph from implicitly batched graph %r", updated_graph)

  updated_graph = network(padded_graph)
  logging.info("Updated graph from padded graph %r", updated_graph)

  # JIT-compile graph propagation.
  # Use padded graphs to avoid re-compilation at every step!
  jitted_network = jax.jit(network)
  updated_graph = jitted_network(padded_graph)
  logging.info("(JIT) updated graph from padded graph %r", updated_graph)
  logging.info("basic.py complete!")