def _get_random_graph(max_n_graph=10):
  """Builds a random batched ``GraphsTuple`` for tests.

  Args:
    max_n_graph: inclusive upper bound on the number of graphs in the batch
      (at least one graph is always generated).

  Returns:
    A ``GraphsTuple`` holding between 1 and ``max_n_graph`` graphs. Each graph
    has 0-9 nodes and 0-19 edges (always 0 edges when it has no nodes). Nodes
    carry 4 random features, edges 3 and globals 5. ``senders``/``receivers``
    are integer arrays even when the batch contains no edges at all.
  """
  n_graph = np.random.randint(1, max_n_graph + 1)
  n_node = np.random.randint(0, 10, n_graph)
  n_edge = np.random.randint(0, 20, n_graph)
  # We cannot have any edges if there are no nodes.
  n_edge[n_node == 0] = 0
  senders = []
  receivers = []
  offset = 0
  for n_node_in_graph, n_edge_in_graph in zip(n_node, n_edge):
    if n_edge_in_graph != 0:
      senders += list(
          np.random.randint(0, n_node_in_graph, n_edge_in_graph) + offset)
      receivers += list(
          np.random.randint(0, n_node_in_graph, n_edge_in_graph) + offset)
    # Node indices of the next graph start after this graph's nodes.
    offset += n_node_in_graph
  # The keyword arguments below are evaluated in order, so the random draws
  # for nodes, edges and globals keep their original order.
  return graph.GraphsTuple(
      n_node=jnp.asarray(n_node),
      n_edge=jnp.asarray(n_edge),
      nodes=jnp.asarray(np.random.random(size=(np.sum(n_node), 4))),
      edges=jnp.asarray(np.random.random(size=(np.sum(n_edge), 3))),
      globals=jnp.asarray(np.random.random(size=(n_graph, 5))),
      # Force an integer dtype: when no graph has edges the lists are empty,
      # and jnp.asarray([]) would give a float array that cannot be used
      # to index node features.
      senders=jnp.asarray(np.asarray(senders, dtype=n_node.dtype)),
      receivers=jnp.asarray(np.asarray(receivers, dtype=n_node.dtype)))
def test_pad_with_graphs(self):
  """Checks that padding appends one dummy graph followed by empty graphs."""
  _, unpadded = _get_list_and_batched_graph()
  padded = utils.pad_with_graphs(unpadded, 10, 12, 9)

  def _append_zero_rows(nest, zeros_shape):
    # Appends a zero block of shape ``zeros_shape`` (same dtype) to each leaf.
    return tree.tree_map(
        lambda leaf: jnp.concatenate(
            [leaf, jnp.zeros(zeros_shape, dtype=leaf.dtype)]),
        nest)

  # The fixture has 7 graphs, 7 nodes and 8 edges. Padding to (10 nodes,
  # 12 edges, 9 graphs) therefore adds a dummy graph with 3 nodes and 4 edges
  # plus one empty graph. Every padding edge links node 7, the first
  # padding node, to itself.
  first_padding_node = 7
  padding_edge_endpoints = jnp.array([first_padding_node] * 4)
  expected = graph.GraphsTuple(
      n_node=jnp.concatenate([unpadded.n_node, jnp.array([3, 0])]),
      n_edge=jnp.concatenate([unpadded.n_edge, jnp.array([4, 0])]),
      nodes=_append_zero_rows(unpadded.nodes, (3, 2)),
      edges=_append_zero_rows(unpadded.edges, (4, 3)),
      globals=_append_zero_rows(unpadded.globals, (2, 2)),
      senders=jnp.concatenate([unpadded.senders, padding_edge_endpoints]),
      receivers=jnp.concatenate([unpadded.receivers, padding_edge_endpoints]),
  )
  self.assertAllClose(padded, expected, check_dtypes=True)
def unpad_with_graphs(
    padded_graph: gn_graph.GraphsTuple) -> gn_graph.GraphsTuple:
  """Unpads the given graph by removing the dummy graph and empty graphs.

  This function assumes that the given graph was padded with the
  ``pad_with_graphs`` function.

  This function does not support jax.jit, because the shape of the output
  is data-dependent!

  Padding with zero padding edges is supported: ``pad_with_graphs`` accepts
  ``n_edge`` equal to the number of real edges, in which case the edge,
  sender and receiver arrays are returned untouched.

  Args:
    padded_graph: ``GraphsTuple`` padded with a dummy graph and empty graphs.

  Returns:
    The unpadded graph.
  """
  # Concrete Python ints are needed for slicing. This is also why the
  # function cannot be jitted.
  n_padding_graph = int(get_number_of_padding_with_graphs_graphs(padded_graph))
  n_padding_node = int(get_number_of_padding_with_graphs_nodes(padded_graph))
  n_padding_edge = int(get_number_of_padding_with_graphs_edges(padded_graph))

  def _drop_trailing(n_rows):
    # Slice to an explicit stop index. ``x[:-n_rows]`` would wrongly yield an
    # empty array when ``n_rows == 0``.
    return lambda x: x[:x.shape[0] - n_rows]

  drop_graphs = _drop_trailing(n_padding_graph)
  drop_nodes = _drop_trailing(n_padding_node)
  drop_edges = _drop_trailing(n_padding_edge)

  return gn_graph.GraphsTuple(
      n_node=drop_graphs(padded_graph.n_node),
      n_edge=drop_graphs(padded_graph.n_edge),
      nodes=tree.tree_map(drop_nodes, padded_graph.nodes),
      edges=tree.tree_map(drop_edges, padded_graph.edges),
      globals=tree.tree_map(drop_graphs, padded_graph.globals),
      senders=drop_edges(padded_graph.senders),
      receivers=drop_edges(padded_graph.receivers),
  )
def get_fully_connected_graph(n_node_per_graph: int,
                              n_graph: int,
                              node_features: Optional[ArrayTree] = None,
                              global_features: Optional[ArrayTree] = None,
                              add_self_edges: bool = True):
  """Gets a fully connected graph given n_node_per_graph and n_graph.

  This method is jittable (``n_node_per_graph`` and ``n_graph`` must be
  static Python ints).

  Every graph in the batch has ``n_node_per_graph`` nodes. Each node receives
  an edge from every node of its own graph, including itself unless
  ``add_self_edges`` is False.

  Args:
    n_node_per_graph: The number of nodes in each graph.
    n_graph: The number of graphs in the `jraph.GraphsTuple`.
    node_features: Optional node features. The leading dimension of its leaves
      must equal ``n_node_per_graph * n_graph``.
    global_features: Optional global features. The leading dimension of its
      leaves must equal ``n_graph``.
    add_self_edges: Whether to add self edges to the graph.

  Returns:
    `jraph.GraphsTuple` with ``edges=None``.

  Raises:
    ValueError: if the node or global feature counts do not match
      ``n_node_per_graph`` and ``n_graph``.
  """
  if node_features is not None:
    num_node_features = jax.tree_util.tree_leaves(node_features)[0].shape[0]
    if n_node_per_graph * n_graph != num_node_features:
      raise ValueError(
          'Number of nodes is not equal to num_nodes_per_graph * n_graph.'
      )
  if global_features is not None:
    if n_graph != jax.tree_util.tree_leaves(global_features)[0].shape[0]:
      raise ValueError('The number of globals is not equal to n_graph.')

  # Edge list of a single graph. With the default 'xy' meshgrid indexing,
  # row i has senders [0, ..., n-1] and receiver i.
  tmp_senders, tmp_receivers = jnp.meshgrid(jnp.arange(n_node_per_graph),
                                            jnp.arange(n_node_per_graph))
  if not add_self_edges:
    # Rotate row i left by i so that sender i (the self edge) moves to
    # column 0, then drop that column. jnp.roll shifts right for positive
    # shifts, so the shifts must be negated.
    tmp_senders = jax.vmap(jnp.roll)(
        tmp_senders, -jnp.arange(len(tmp_senders)))[:, 1:]
    tmp_receivers = tmp_receivers[:, 1:]
  # Flatten the senders and receivers.
  tmp_senders = tmp_senders.flatten()
  tmp_receivers = tmp_receivers.flatten()
  n_edge_per_graph = tmp_senders.shape[0]

  # Copy the single-graph edge list once per graph, offsetting node indices
  # by the number of nodes in the preceding graphs (graph-major order).
  offsets = jnp.arange(n_graph) * n_node_per_graph
  senders = (tmp_senders[None, :] + offsets[:, None]).reshape(-1)
  receivers = (tmp_receivers[None, :] + offsets[:, None]).reshape(-1)

  if n_graph > 0:
    n_edge = jnp.array([n_edge_per_graph] * n_graph)
  else:
    # NOTE(review): kept for backward compatibility, even though n_edge then
    # has length 1 while n_node has length 0.
    n_edge = jnp.array([0])

  return gn_graph.GraphsTuple(
      nodes=node_features,
      edges=None,
      # dtype=int keeps n_node integer even when n_graph == 0.
      n_node=jnp.array([n_node_per_graph] * n_graph, dtype=int),
      n_edge=n_edge,
      senders=senders,
      receivers=receivers,
      globals=global_features)
def test_unpad(self):
  """Checks that unpadding strips the dummy graph and trailing empty graphs."""
  _, padded = _get_list_and_batched_graph()
  unpadded = utils.unpad_with_graphs(padded)
  # The last three graphs of the fixture (2 nodes / 1 edge, then two empty
  # graphs) count as padding, so only the first four graphs, their 5 nodes
  # and their 7 edges remain.
  expected = graph.GraphsTuple(
      n_node=jnp.array([1, 3, 1, 0]),
      n_edge=jnp.array([2, 5, 0, 0]),
      nodes=_make_nest(jnp.arange(10).reshape(5, 2)),
      edges=_make_nest(jnp.arange(21).reshape(7, 3)),
      globals=_make_nest(jnp.arange(8).reshape(4, 2)),
      senders=jnp.array([0, 0, 1, 1, 2, 3, 3]),
      receivers=jnp.array([0, 0, 2, 1, 3, 2, 1]),
  )
  self.assertAllClose(unpadded, expected, check_dtypes=True)
def test_connect_gnns(self, network_fn):
  """Compares ``network_fn``'s output with its expected output, jit and not.

  ``network_fn`` maps a graph to a pair ``(out, expected_out)``.
  """
  batched_graphs_tuple = graph.GraphsTuple(
      n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]),
      n_edge=jnp.array([1, 7, 1, 0, 3, 0, 0]),
      nodes=jnp.arange(14).reshape(7, 2),
      edges=jnp.arange(36).reshape(12, 3),
      globals=jnp.arange(14).reshape(7, 2),
      senders=jnp.array([0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 3, 6]),
      receivers=jnp.array([0, 1, 2, 3, 4, 5, 6, 2, 3, 2, 1, 5]))
  variants = (('nojit', network_fn), ('jit', jax.jit(network_fn)))
  for subtest_name, fn in variants:
    with self.subTest(subtest_name):
      actual, wanted = fn(batched_graphs_tuple)
      jax.tree_util.tree_map(np.testing.assert_allclose, actual, wanted)
def _get_random_graph(max_n_graph=10,
                      include_node_features=True,
                      include_edge_features=True,
                      include_globals=True):
  """Builds a random batched ``GraphsTuple`` for tests.

  Args:
    max_n_graph: inclusive upper bound on the number of graphs in the batch
      (at least one graph is always generated).
    include_node_features: if False, ``nodes`` is None; otherwise 4 random
      features per node.
    include_edge_features: if False, ``edges`` is None; otherwise 3 random
      features per edge.
    include_globals: if False, ``globals`` is None; otherwise 5 random
      features per graph.

  Returns:
    A ``GraphsTuple`` whose graphs each have 0-9 nodes and 0-19 edges (always
    0 edges when a graph has no nodes). ``senders``/``receivers`` are integer
    arrays even when the batch contains no edges at all.
  """
  n_graph = np.random.randint(1, max_n_graph + 1)
  n_node = np.random.randint(0, 10, n_graph)
  n_edge = np.random.randint(0, 20, n_graph)
  # We cannot have any edges if there are no nodes.
  n_edge[n_node == 0] = 0
  senders = []
  receivers = []
  offset = 0
  for n_node_in_graph, n_edge_in_graph in zip(n_node, n_edge):
    if n_edge_in_graph != 0:
      senders += list(
          np.random.randint(0, n_node_in_graph, n_edge_in_graph) + offset)
      receivers += list(
          np.random.randint(0, n_node_in_graph, n_edge_in_graph) + offset)
    # Node indices of the next graph start after this graph's nodes.
    offset += n_node_in_graph

  # Features are drawn in the original order (globals, nodes, edges) so the
  # random stream is unchanged for seeded tests.
  if include_globals:
    global_features = jnp.asarray(np.random.random(size=(n_graph, 5)))
  else:
    global_features = None
  if include_node_features:
    nodes = jnp.asarray(np.random.random(size=(np.sum(n_node), 4)))
  else:
    nodes = None
  if include_edge_features:
    edges = jnp.asarray(np.random.random(size=(np.sum(n_edge), 3)))
  else:
    edges = None

  return graph.GraphsTuple(
      n_node=jnp.asarray(n_node),
      n_edge=jnp.asarray(n_edge),
      nodes=nodes,
      edges=edges,
      globals=global_features,
      # Force an integer dtype: when no graph has edges the lists are empty,
      # and jnp.asarray([]) would give a float array that cannot be used
      # to index node features.
      senders=jnp.asarray(np.asarray(senders, dtype=n_node.dtype)),
      receivers=jnp.asarray(np.asarray(receivers, dtype=n_node.dtype)))
def _batch(graphs, np_):
  """Returns batched graph given a list of graphs and a numpy-like module.

  Args:
    graphs: sequence of ``GraphsTuple`` to concatenate, in order.
    np_: numpy-like module (e.g. ``numpy`` or ``jax.numpy``) that provides
      ``cumsum``, ``array``, ``sum`` and ``concatenate``.

  Returns:
    A single ``GraphsTuple``. Every field is the concatenation of the inputs'
    fields. Sender/receiver indices are shifted by the number of nodes in all
    preceding graphs.
  """
  # Calculates offsets for sender and receiver arrays, caused by concatenating
  # the nodes arrays.
  offsets = np_.cumsum(
      np_.array([0] + [np_.sum(g.n_node) for g in graphs[:-1]]))

  def _map_concat(nests):
    concat = lambda *args: np_.concatenate(args)
    # ``tree_multimap`` was deprecated and then removed from jax.tree_util.
    # ``tree_map`` accepts any number of trees with the same structure.
    return tree.tree_map(concat, *nests)

  return gn_graph.GraphsTuple(
      n_node=np_.concatenate([g.n_node for g in graphs]),
      n_edge=np_.concatenate([g.n_edge for g in graphs]),
      nodes=_map_concat([g.nodes for g in graphs]),
      edges=_map_concat([g.edges for g in graphs]),
      globals=_map_concat([g.globals for g in graphs]),
      senders=np_.concatenate(
          [g.senders + o for g, o in zip(graphs, offsets)]),
      receivers=np_.concatenate(
          [g.receivers + o for g, o in zip(graphs, offsets)]))
def test_connect_graphnetwork_nones(self, network_fn):
  """Checks that ``network_fn`` returns its input when fields are None/empty.

  Four variants are tested: globals set to None or [], and edges set to None
  or []. Each is run both eagerly and under ``jax.jit``.
  """
  base_graph = graph.GraphsTuple(
      n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]),
      n_edge=jnp.array([2, 5, 0, 0, 1, 0, 0]),
      nodes=self._make_nest(jnp.arange(14).reshape(7, 2)),
      edges=self._make_nest(jnp.arange(24).reshape(8, 3)),
      globals=self._make_nest(jnp.arange(14).reshape(7, 2)),
      senders=jnp.array([0, 0, 1, 1, 2, 3, 3, 6]),
      receivers=jnp.array([0, 0, 2, 1, 3, 2, 1, 5]))
  field_overrides = (
      ('no_globals', {'globals': None}),
      ('empty_globals', {'globals': []}),
      ('no_edges', {'edges': None}),
      ('empty_edges', {'edges': []}),
  )
  jitted_fn = jax.jit(network_fn)
  for variant_name, overrides in field_overrides:
    variant = base_graph._replace(**overrides)
    for suffix, fn in (('_nojit', network_fn), ('_jit', jitted_fn)):
      with self.subTest(variant_name + suffix):
        result = fn(variant)
        jax.tree_util.tree_map(np.testing.assert_allclose, result, variant)
def _get_list_and_batched_graph():
  """Returns a list of individual graphs and a batched version.

  This test-case includes the following corner-cases:
    - single node,
    - multiple nodes,
    - no edges,
    - single edge,
    - and multiple edges.
  The list additionally ends with a ``GraphsTuple`` that contains zero graphs
  (``n_node`` is empty). It contributes nothing to the batch.

  Returns:
    A tuple ``(list_graphs, batched_graph)``. ``batched_graph`` holds the seven
    non-empty-tuple entries of ``list_graphs`` concatenated in order.
  """
  # Node, edge and global features are consecutive aranges, so each graph's
  # features are contiguous slices of the batched arrays. Sender/receiver
  # indices in the batch are the per-graph indices shifted by the number of
  # nodes in the preceding graphs (e.g. the fifth graph's edge 1->0 becomes
  # 6->5).
  batched_graph = graph.GraphsTuple(
      n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]),
      n_edge=jnp.array([2, 5, 0, 0, 1, 0, 0]),
      nodes=_make_nest(jnp.arange(14).reshape(7, 2)),
      edges=_make_nest(jnp.arange(24).reshape(8, 3)),
      globals=_make_nest(jnp.arange(14).reshape(7, 2)),
      senders=jnp.array([0, 0, 1, 1, 2, 3, 3, 6]),
      receivers=jnp.array([0, 0, 2, 1, 3, 2, 1, 5]))

  # NOTE(review): the empty ``jnp.array([])`` senders/receivers below have a
  # floating dtype. Presumably this is tolerated because batching concatenates
  # them with integer arrays; confirm if dtype checks are tightened.
  list_graphs = [
      # Single node with two self edges.
      graph.GraphsTuple(
          n_node=jnp.array([1]),
          n_edge=jnp.array([2]),
          nodes=_make_nest(jnp.array([[0, 1]])),
          edges=_make_nest(jnp.array([[0, 1, 2], [3, 4, 5]])),
          globals=_make_nest(jnp.array([[0, 1]])),
          senders=jnp.array([0, 0]),
          receivers=jnp.array([0, 0])),
      # Multiple nodes and multiple edges.
      graph.GraphsTuple(
          n_node=jnp.array([3]),
          n_edge=jnp.array([5]),
          nodes=_make_nest(jnp.array([[2, 3], [4, 5], [6, 7]])),
          edges=_make_nest(
              jnp.array([[6, 7, 8], [9, 10, 11], [12, 13, 14], [15, 16, 17],
                         [18, 19, 20]])),
          globals=_make_nest(jnp.array([[2, 3]])),
          senders=jnp.array([0, 0, 1, 2, 2]),
          receivers=jnp.array([1, 0, 2, 1, 0])),
      # Single node, no edges.
      graph.GraphsTuple(
          n_node=jnp.array([1]),
          n_edge=jnp.array([0]),
          nodes=_make_nest(jnp.array([[8, 9]])),
          edges=_make_nest(jnp.zeros((0, 3))),
          globals=_make_nest(jnp.array([[4, 5]])),
          senders=jnp.array([]),
          receivers=jnp.array([])),
      # No nodes and no edges, but still one graph with globals.
      graph.GraphsTuple(
          n_node=jnp.array([0]),
          n_edge=jnp.array([0]),
          nodes=_make_nest(jnp.zeros((0, 2))),
          edges=_make_nest(jnp.zeros((0, 3))),
          globals=_make_nest(jnp.array([[6, 7]])),
          senders=jnp.array([]),
          receivers=jnp.array([])),
      # Two nodes, single edge.
      graph.GraphsTuple(
          n_node=jnp.array([2]),
          n_edge=jnp.array([1]),
          nodes=_make_nest(jnp.array([[10, 11], [12, 13]])),
          edges=_make_nest(jnp.array([[21, 22, 23]])),
          globals=_make_nest(jnp.array([[8, 9]])),
          senders=jnp.array([1]),
          receivers=jnp.array([0])),
      # Two trailing node-less graphs. In the batch these look like padding
      # empty graphs.
      graph.GraphsTuple(
          n_node=jnp.array([0]),
          n_edge=jnp.array([0]),
          nodes=_make_nest(jnp.zeros((0, 2))),
          edges=_make_nest(jnp.zeros((0, 3))),
          globals=_make_nest(jnp.array([[10, 11]])),
          senders=jnp.array([]),
          receivers=jnp.array([])),
      graph.GraphsTuple(
          n_node=jnp.array([0]),
          n_edge=jnp.array([0]),
          nodes=_make_nest(jnp.zeros((0, 2))),
          edges=_make_nest(jnp.zeros((0, 3))),
          globals=_make_nest(jnp.array([[12, 13]])),
          senders=jnp.array([]),
          receivers=jnp.array([])),
      # A GraphsTuple with zero graphs (empty n_node/n_edge/globals).
      graph.GraphsTuple(
          n_node=jnp.array([]),
          n_edge=jnp.array([]),
          nodes=_make_nest(jnp.zeros((0, 2))),
          edges=_make_nest(jnp.zeros((0, 3))),
          globals=_make_nest(jnp.zeros((0, 2))),
          senders=jnp.array([]),
          receivers=jnp.array([])),
  ]
  return list_graphs, batched_graph
def pad_with_graphs(graph: gn_graph.GraphsTuple,
                    n_node: int,
                    n_edge: int,
                    n_graph: int = 2) -> gn_graph.GraphsTuple:
  """Pads a ``GraphsTuple`` to size by adding computation preserving graphs.

  The ``GraphsTuple`` is padded by first adding a dummy graph which contains the
  padding nodes and edges, and then empty graphs without nodes or edges.

  The empty graphs and the dummy graph do not interfere with the graphnet
  calculations on the original graph, and so are computation preserving.

  The padding graph requires at least one node and one graph. It may have zero
  edges.

  This function does not support jax.jit, because the shape of the output
  is data-dependent.

  Args:
    graph: ``GraphsTuple`` to pad. It is fetched to host with
      ``jax.device_get`` and padded with NumPy.
    n_node: the number of nodes in the padded ``GraphsTuple``.
    n_edge: the number of edges in the padded ``GraphsTuple``.
    n_graph: the number of graphs in the padded ``GraphsTuple``. Default is 2,
      which is the lowest possible value, because we always have at least one
      graph in the original ``GraphsTuple`` and we need one dummy graph for the
      padding.

  Raises:
    ValueError: if the passed ``n_graph`` is smaller than 2.
    RuntimeError: if the given ``GraphsTuple`` is too large for the given
      padding.

  Returns:
    A padded ``GraphsTuple`` made of NumPy arrays. All padding features are
    zero. Every padding edge connects the first padding node to itself.
  """
  if n_graph < 2:
    raise ValueError(
        f'n_graph is {n_graph}, which is smaller than minimum value of 2.')
  graph = jax.device_get(graph)
  pad_n_node = int(n_node - np.sum(graph.n_node))
  pad_n_edge = int(n_edge - np.sum(graph.n_edge))
  pad_n_graph = int(n_graph - graph.n_node.shape[0])
  # The dummy graph needs at least one node (padding edges point at it) and
  # one graph slot. Zero padding edges is allowed.
  if pad_n_node <= 0 or pad_n_edge < 0 or pad_n_graph <= 0:
    raise RuntimeError(
        'Given graph is too large for the given padding. difference: '
        f'n_node {pad_n_node}, n_edge {pad_n_edge}, n_graph {pad_n_graph}')

  # The first padding graph is the dummy one. The rest are empty.
  pad_n_empty_graph = pad_n_graph - 1

  tree_nodes_pad = (
      lambda leaf: np.zeros((pad_n_node,) + leaf.shape[1:], dtype=leaf.dtype))
  tree_edges_pad = (
      lambda leaf: np.zeros((pad_n_edge,) + leaf.shape[1:], dtype=leaf.dtype))
  tree_globs_pad = (
      lambda leaf: np.zeros((pad_n_graph,) + leaf.shape[1:], dtype=leaf.dtype))

  padding_graph = gn_graph.GraphsTuple(
      n_node=np.concatenate([
          np.array([pad_n_node]),
          np.zeros(pad_n_empty_graph, dtype=np.int32)
      ]),
      n_edge=np.concatenate([
          np.array([pad_n_edge]),
          np.zeros(pad_n_empty_graph, dtype=np.int32)
      ]),
      nodes=tree.tree_map(tree_nodes_pad, graph.nodes),
      edges=tree.tree_map(tree_edges_pad, graph.edges),
      globals=tree.tree_map(tree_globs_pad, graph.globals),
      # Index 0 of the padding graph. After batching it is shifted to the
      # first padding node.
      senders=np.zeros(pad_n_edge, dtype=np.int32),
      receivers=np.zeros(pad_n_edge, dtype=np.int32),
  )
  return _batch([graph, padding_graph], np_=np)
def _ApplyGraphNet(graph):
  """Applies a configured GraphNetwork to a graph.

  This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261

  There is one difference. For the nodes update the class aggregates over the
  sender edges and receiver edges separately. This is a bit more general
  than the algorithm described in the paper. The original behaviour can be
  recovered by using only the receiver edge aggregations for the update.

  In addition this implementation supports softmax attention over incoming
  edge features.

  Many popular Graph Neural Networks can be implemented as special cases of
  GraphNets, for more information please see the paper.

  Each of the edge, attention, node and global stages runs only when its
  corresponding function is truthy.

  NOTE(review): ``update_edge_fn``, ``attention_logit_fn``,
  ``attention_normalize_fn``, ``attention_reduce_fn``, ``update_node_fn``,
  ``update_global_fn`` and the ``aggregate_*_fn`` names are not defined here.
  Presumably they are closed over from an enclosing GraphNetwork factory;
  confirm there.

  Args:
    graph: a `GraphsTuple` containing the graph.

  Returns:
    Updated `GraphsTuple`.

  Raises:
    ValueError: if the leaves of ``nodes`` disagree on the number of nodes.
  """
  # pylint: disable=g-long-lambda
  # NOTE(review): positional unpacking assumes GraphsTuple's field order is
  # (nodes, edges, receivers, senders, globals, n_node, n_edge). Confirm
  # against the GraphsTuple definition.
  nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
  # Equivalent to jnp.sum(n_node), but jittable: it reads the static leading
  # dimension of the first node leaf instead of summing traced values.
  # NOTE(review): presumably raises IndexError if `nodes` has no leaves.
  sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
  sum_n_edge = senders.shape[0]
  if not tree.tree_all(
      tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
    raise ValueError(
        'All node arrays in nest must contain the same number of nodes.')

  # Gather the features of each edge's sender and receiver node.
  sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
  received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
  # Here we scatter the global features to the corresponding edges,
  # giving us tensors of shape [num_edges, global_feat].
  # total_repeat_length gives jnp.repeat a static output size under jit.
  global_edge_attributes = tree.tree_map(lambda g: jnp.repeat(
      g, n_edge, axis=0, total_repeat_length=sum_n_edge), globals_)

  if update_edge_fn:
    edges = update_edge_fn(edges, sent_attributes, received_attributes,
                           global_edge_attributes)

  if attention_logit_fn:
    logits = attention_logit_fn(edges, sent_attributes, received_attributes,
                                global_edge_attributes)
    # Attention weights are normalised per receiving node (segments are
    # receiver indices).
    tree_calculate_weights = functools.partial(
        attention_normalize_fn,
        segment_ids=receivers,
        num_segments=sum_n_node)
    weights = tree.tree_map(tree_calculate_weights, logits)
    edges = attention_reduce_fn(edges, weights)

  if update_node_fn:
    # Aggregate the (possibly updated) edges onto their sender and receiver
    # nodes separately. This replaces the gathered node features computed
    # above.
    sent_attributes = tree.tree_map(
        lambda e: aggregate_edges_for_nodes_fn(e, senders, sum_n_node), edges)
    received_attributes = tree.tree_map(
        lambda e: aggregate_edges_for_nodes_fn(e, receivers, sum_n_node),
        edges)
    # Here we scatter the global features to the corresponding nodes,
    # giving us tensors of shape [num_nodes, global_feat].
    global_attributes = tree.tree_map(lambda g: jnp.repeat(
        g, n_node, axis=0, total_repeat_length=sum_n_node), globals_)
    nodes = update_node_fn(nodes, sent_attributes,
                           received_attributes, global_attributes)

  if update_global_fn:
    n_graph = n_node.shape[0]
    graph_idx = jnp.arange(n_graph)
    # To aggregate nodes and edges from each graph to global features,
    # we first construct tensors that map the node to the corresponding graph.
    # For example, if you have `n_node=[1,2]`, we construct the tensor
    # [0, 1, 1]. We then do the same for edges.
    node_gr_idx = jnp.repeat(
        graph_idx, n_node, axis=0, total_repeat_length=sum_n_node)
    edge_gr_idx = jnp.repeat(
        graph_idx, n_edge, axis=0, total_repeat_length=sum_n_edge)
    # We use the aggregation function to pool the nodes/edges per graph.
    node_attributes = tree.tree_map(
        lambda n: aggregate_nodes_for_globals_fn(n, node_gr_idx, n_graph),
        nodes)
    edge_attribtutes = tree.tree_map(
        lambda e: aggregate_edges_for_globals_fn(e, edge_gr_idx, n_graph),
        edges)
    # These pooled nodes are the inputs to the global update fn.
    globals_ = update_global_fn(node_attributes, edge_attribtutes, globals_)
  # pylint: enable=g-long-lambda
  # Connectivity and counts are passed through unchanged. Only features are
  # updated.
  return gn_graph.GraphsTuple(
      nodes=nodes,
      edges=edges,
      receivers=receivers,
      senders=senders,
      globals=globals_,
      n_node=n_node,
      n_edge=n_edge)