Пример #1
0
def _get_random_graph(max_n_graph=10):
    """Returns a random batched ``GraphsTuple`` for tests.

    Draws between 1 and ``max_n_graph`` graphs, each with 0-9 nodes and 0-19
    edges (graphs without nodes get no edges). Nodes carry 4 random features,
    edges 3 and globals 5. Edge endpoints are drawn uniformly among the nodes
    of the edge's own graph and are expressed as indices into the batched node
    array.

    Args:
      max_n_graph: maximum number of graphs in the batch.

    Returns:
      A ``graph.GraphsTuple`` with random features and topology.
    """
    n_graph = np.random.randint(1, max_n_graph + 1)
    n_node = np.random.randint(0, 10, n_graph)
    n_edge = np.random.randint(0, 20, n_graph)
    # We cannot have any edges if there are no nodes.
    n_edge[n_node == 0] = 0

    senders = []
    receivers = []
    offset = 0
    for n_node_in_graph, n_edge_in_graph in zip(n_node, n_edge):
        if n_edge_in_graph != 0:
            senders += list(
                np.random.randint(0, n_node_in_graph, n_edge_in_graph) +
                offset)
            receivers += list(
                np.random.randint(0, n_node_in_graph, n_edge_in_graph) +
                offset)
        offset += n_node_in_graph

    # When the whole batch has no edges the lists are empty, and
    # jnp.asarray([]) would be float32 -- unusable as an index array. Pin an
    # integer dtype (the one numpy used for the counts above).
    index_dtype = n_node.dtype
    return graph.GraphsTuple(
        n_node=jnp.asarray(n_node),
        n_edge=jnp.asarray(n_edge),
        nodes=jnp.asarray(np.random.random(size=(np.sum(n_node), 4))),
        edges=jnp.asarray(np.random.random(size=(np.sum(n_edge), 3))),
        globals=jnp.asarray(np.random.random(size=(n_graph, 5))),
        senders=jnp.asarray(np.asarray(senders, dtype=index_dtype)),
        receivers=jnp.asarray(np.asarray(receivers, dtype=index_dtype)))
Пример #2
0
 def test_pad_with_graphs(self):
     """Checks that pad_with_graphs appends a dummy graph and an empty graph.

     The batched test graph has 7 graphs, 7 nodes and 8 edges, so padding it
     to 10 nodes, 12 edges and 9 graphs appends a dummy graph holding the 3
     padding nodes and 4 padding edges, followed by one empty graph.
     """
     _, graphs_tuple = _get_list_and_batched_graph()
     padded_graphs_tuple = utils.pad_with_graphs(graphs_tuple, 10, 12, 9)

     def append_zeros(nest, shape):
         # Appends a block of zeros of `shape` to every leaf of `nest`.
         return tree.tree_map(
             lambda leaf: jnp.concatenate(
                 [leaf, jnp.zeros(shape, dtype=leaf.dtype)]),
             nest)

     # Every padding edge points at the first padding node, i.e. index 7.
     padding_endpoints = jnp.array([7, 7, 7, 7])
     expected_padded_graph = graph.GraphsTuple(
         n_node=jnp.concatenate([graphs_tuple.n_node, jnp.array([3, 0])]),
         n_edge=jnp.concatenate([graphs_tuple.n_edge, jnp.array([4, 0])]),
         nodes=append_zeros(graphs_tuple.nodes, (3, 2)),
         edges=append_zeros(graphs_tuple.edges, (4, 3)),
         globals=append_zeros(graphs_tuple.globals, (2, 2)),
         senders=jnp.concatenate([graphs_tuple.senders, padding_endpoints]),
         receivers=jnp.concatenate(
             [graphs_tuple.receivers, padding_endpoints]),
     )
     self.assertAllClose(padded_graphs_tuple, expected_padded_graph,
                         check_dtypes=True)
Пример #3
0
def unpad_with_graphs(
        padded_graph: gn_graph.GraphsTuple) -> gn_graph.GraphsTuple:
    """Unpads the given graph by removing the dummy graph and empty graphs.

    This function assumes that the given graph was padded with the
    ``pad_with_graphs`` function.

    This function does not support jax.jit, because the shape of the output
    is data-dependent!

    Args:
      padded_graph: ``GraphsTuple`` padded with a dummy graph
        and empty graphs.

    Returns:
      The unpadded graph.
    """
    n_padding_graph = get_number_of_padding_with_graphs_graphs(padded_graph)
    n_padding_node = get_number_of_padding_with_graphs_nodes(padded_graph)
    n_padding_edge = get_number_of_padding_with_graphs_edges(padded_graph)

    def _strip(n_padding):
        """Returns a function dropping the last ``n_padding`` rows of an array."""
        # A plain ``x[:-n_padding]`` is ``x[:0]`` (empty) for a zero count.
        # Zero padding edges is a legitimate result of ``pad_with_graphs``
        # (it only rejects negative edge padding), so handle it explicitly.
        if n_padding == 0:
            return lambda x: x
        return lambda x: x[:-n_padding]

    strip_graphs = _strip(n_padding_graph)
    strip_nodes = _strip(n_padding_node)
    strip_edges = _strip(n_padding_edge)

    unpadded_graph = gn_graph.GraphsTuple(
        n_node=strip_graphs(padded_graph.n_node),
        n_edge=strip_graphs(padded_graph.n_edge),
        nodes=tree.tree_map(strip_nodes, padded_graph.nodes),
        edges=tree.tree_map(strip_edges, padded_graph.edges),
        globals=tree.tree_map(strip_graphs, padded_graph.globals),
        senders=strip_edges(padded_graph.senders),
        receivers=strip_edges(padded_graph.receivers),
    )
    return unpadded_graph
Пример #4
0
def get_fully_connected_graph(n_node_per_graph: int,
                              n_graph: int,
                              node_features: Optional[ArrayTree] = None,
                              global_features: Optional[ArrayTree] = None,
                              add_self_edges: bool = True):
    """Gets a fully connected graph given n_node_per_graph and n_graph.

    This method is jittable as long as ``n_node_per_graph`` and ``n_graph`` are
    Python ints, since they determine the shapes of the output arrays.

    Args:
      n_node_per_graph: The number of nodes in each graph.
      n_graph: The number of graphs in the `jraph.GraphsTuple`.
      node_features: Optional node features.
      global_features: Optional global features.
      add_self_edges: Whether to add self edges to the graph.

    Returns:
      `jraph.GraphsTuple`

    Raises:
      ValueError: if the leading dimension of ``node_features`` is not
        ``n_node_per_graph * n_graph`` or that of ``global_features`` is not
        ``n_graph``.
    """
    if node_features is not None:
        num_node_features = jax.tree_util.tree_leaves(
            node_features)[0].shape[0]
        if n_node_per_graph * n_graph != num_node_features:
            raise ValueError(
                'Number of nodes is not equal to num_nodes_per_graph * n_graph.'
            )
    if global_features is not None:
        if n_graph != jax.tree_util.tree_leaves(global_features)[0].shape[0]:
            raise ValueError('The number of globals is not equal to n_graph.')

    node_range = jnp.arange(n_node_per_graph)
    # tmp_senders[i, j] = j and tmp_receivers[i, j] = i, so row i lists every
    # edge into receiver i.
    tmp_senders, tmp_receivers = jnp.meshgrid(node_range, node_range)
    if not add_self_edges:
        # Reorder row i to senders (i, i + 1, ..., i + n - 1) mod n so that the
        # self edge i -> i sits in column 0, then drop that column. (Rolling
        # row i by +i instead puts (n - i) mod n first, which leaves self edges
        # in place for n >= 3.)
        tmp_senders = ((tmp_senders + tmp_receivers) %
                       n_node_per_graph)[:, 1:]
        tmp_receivers = tmp_receivers[:, 1:]
    # Flatten the per-graph senders and receivers.
    tmp_senders = tmp_senders.flatten()
    tmp_receivers = tmp_receivers.flatten()
    n_edge_per_graph = len(tmp_senders)

    # Nodes of graph g occupy indices [g * n, (g + 1) * n) in the batch; the
    # row-major reshape lists the edges graph by graph.
    offsets = (jnp.arange(n_graph) * n_node_per_graph)[:, None]
    senders = (tmp_senders[None, :] + offsets).reshape(-1)
    receivers = (tmp_receivers[None, :] + offsets).reshape(-1)
    return gn_graph.GraphsTuple(
        nodes=node_features,
        edges=None,
        n_node=jnp.array([n_node_per_graph] * n_graph),
        n_edge=(jnp.array([n_edge_per_graph] * n_graph)
                if n_graph > 0 else jnp.array([0])),
        senders=senders,
        receivers=receivers,
        globals=global_features)
Пример #5
0
 def test_unpad(self):
     """Checks that unpadding keeps only the leading, non-padding graphs.

     The expected result is the first four graphs of the batched test graph:
     5 nodes, 7 edges and their features and indices, unchanged.
     """
     _, padded = _get_list_and_batched_graph()
     result = utils.unpad_with_graphs(padded)
     n_nodes, n_edges, n_graphs = 5, 7, 4
     want = graph.GraphsTuple(
         n_node=jnp.array([1, 3, 1, 0]),
         n_edge=jnp.array([2, 5, 0, 0]),
         nodes=_make_nest(jnp.arange(n_nodes * 2).reshape(n_nodes, 2)),
         edges=_make_nest(jnp.arange(n_edges * 3).reshape(n_edges, 3)),
         globals=_make_nest(jnp.arange(n_graphs * 2).reshape(n_graphs, 2)),
         senders=jnp.array([0, 0, 1, 1, 2, 3, 3]),
         receivers=jnp.array([0, 0, 2, 1, 3, 2, 1]))
     self.assertAllClose(result, want, check_dtypes=True)
Пример #6
0
 def test_connect_gnns(self, network_fn):
     """Runs ``network_fn`` eagerly and under jax.jit on a fixed batch.

     ``network_fn`` returns an ``(out, expected_out)`` pair; the two pytrees
     must agree leaf-by-leaf in both execution modes.
     """
     batch = graph.GraphsTuple(
         n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]),
         n_edge=jnp.array([1, 7, 1, 0, 3, 0, 0]),
         nodes=jnp.arange(14).reshape(7, 2),
         edges=jnp.arange(36).reshape(12, 3),
         globals=jnp.arange(14).reshape(7, 2),
         senders=jnp.array([0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 3, 6]),
         receivers=jnp.array([0, 1, 2, 3, 4, 5, 6, 2, 3, 2, 1, 5]))
     for mode, use_jit in (('nojit', False), ('jit', True)):
         with self.subTest(mode):
             fn = jax.jit(network_fn) if use_jit else network_fn
             out, expected_out = fn(batch)
             jax.tree_util.tree_map(np.testing.assert_allclose, out,
                                    expected_out)
Пример #7
0
def _get_random_graph(max_n_graph=10,
                      include_node_features=True,
                      include_edge_features=True,
                      include_globals=True):
    """Returns a random batched ``GraphsTuple`` for tests.

    Draws between 1 and ``max_n_graph`` graphs, each with 0-9 nodes and 0-19
    edges (graphs without nodes get no edges). Edge endpoints are drawn
    uniformly among the nodes of the edge's own graph and are expressed as
    indices into the batched node array.

    Args:
      max_n_graph: maximum number of graphs in the batch.
      include_node_features: if True, nodes get 4 random features, else None.
      include_edge_features: if True, edges get 3 random features, else None.
      include_globals: if True, graphs get 5 random globals, else None.

    Returns:
      A ``graph.GraphsTuple`` with random features and topology.
    """
    n_graph = np.random.randint(1, max_n_graph + 1)
    n_node = np.random.randint(0, 10, n_graph)
    n_edge = np.random.randint(0, 20, n_graph)
    # We cannot have any edges if there are no nodes.
    n_edge[n_node == 0] = 0

    senders = []
    receivers = []
    offset = 0
    for n_node_in_graph, n_edge_in_graph in zip(n_node, n_edge):
        if n_edge_in_graph != 0:
            senders += list(
                np.random.randint(0, n_node_in_graph, n_edge_in_graph) +
                offset)
            receivers += list(
                np.random.randint(0, n_node_in_graph, n_edge_in_graph) +
                offset)
        offset += n_node_in_graph
    # The feature draws below keep their original order (globals, nodes,
    # edges) so seeded runs produce the same graphs as before.
    if include_globals:
        global_features = jnp.asarray(np.random.random(size=(n_graph, 5)))
    else:
        global_features = None
    if include_node_features:
        nodes = jnp.asarray(np.random.random(size=(np.sum(n_node), 4)))
    else:
        nodes = None

    if include_edge_features:
        edges = jnp.asarray(np.random.random(size=(np.sum(n_edge), 3)))
    else:
        edges = None

    # When the whole batch has no edges the lists are empty, and
    # jnp.asarray([]) would be float32 -- unusable as an index array. Pin an
    # integer dtype (the one numpy used for the counts above).
    index_dtype = n_node.dtype
    return graph.GraphsTuple(
        n_node=jnp.asarray(n_node),
        n_edge=jnp.asarray(n_edge),
        nodes=nodes,
        edges=edges,
        globals=global_features,
        senders=jnp.asarray(np.asarray(senders, dtype=index_dtype)),
        receivers=jnp.asarray(np.asarray(receivers, dtype=index_dtype)))
Пример #8
0
def _batch(graphs, np_):
    """Returns batched graph given a list of graphs and a numpy-like module.

    Node, edge and global features are concatenated leaf-wise, as are the
    ``n_node``/``n_edge`` counts. The sender and receiver indices of each graph
    are shifted by the total number of nodes in the graphs before it, so they
    index into the concatenated node arrays.

    Args:
      graphs: sequence of ``GraphsTuple``s whose feature nests share the same
        structure.
      np_: numpy-like module providing ``array``, ``sum``, ``cumsum`` and
        ``concatenate`` (e.g. ``numpy``).

    Returns:
      A single ``GraphsTuple`` containing all input graphs.
    """
    # Calculates offsets for sender and receiver arrays, caused by concatenating
    # the nodes arrays.
    offsets = np_.cumsum(
        np_.array([0] + [np_.sum(g.n_node) for g in graphs[:-1]]))

    def _map_concat(nests):
        concat = lambda *args: np_.concatenate(args)
        # tree_map accepts several trees at once; tree_multimap was deprecated
        # and later removed from jax.tree_util.
        return tree.tree_map(concat, *nests)

    return gn_graph.GraphsTuple(
        n_node=np_.concatenate([g.n_node for g in graphs]),
        n_edge=np_.concatenate([g.n_edge for g in graphs]),
        nodes=_map_concat([g.nodes for g in graphs]),
        edges=_map_concat([g.edges for g in graphs]),
        globals=_map_concat([g.globals for g in graphs]),
        senders=np_.concatenate(
            [g.senders + o for g, o in zip(graphs, offsets)]),
        receivers=np_.concatenate(
            [g.receivers + o for g, o in zip(graphs, offsets)]))
Пример #9
0
    def test_connect_graphnetwork_nones(self, network_fn):
        """Checks that ``network_fn`` passes graphs with missing fields through.

        Globals and edges are each replaced by ``None`` and by ``[]``; the
        output must equal the input, both eagerly and under jax.jit.
        """
        batched_graphs_tuple = graph.GraphsTuple(
            n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]),
            n_edge=jnp.array([2, 5, 0, 0, 1, 0, 0]),
            nodes=self._make_nest(jnp.arange(14).reshape(7, 2)),
            edges=self._make_nest(jnp.arange(24).reshape(8, 3)),
            globals=self._make_nest(jnp.arange(14).reshape(7, 2)),
            senders=jnp.array([0, 0, 1, 1, 2, 3, 3, 6]),
            receivers=jnp.array([0, 0, 2, 1, 3, 2, 1, 5]))

        variants = (
            ('no_globals', {'globals': None}),
            ('empty_globals', {'globals': []}),
            ('no_edges', {'edges': None}),
            ('empty_edges', {'edges': []}),
        )
        for name, overrides in variants:
            graphs_tuple = batched_graphs_tuple._replace(**overrides)
            for suffix, use_jit in (('_nojit', False), ('_jit', True)):
                with self.subTest(name + suffix):
                    fn = jax.jit(network_fn) if use_jit else network_fn
                    out = fn(graphs_tuple)
                    jax.tree_util.tree_map(np.testing.assert_allclose, out,
                                           graphs_tuple)
Пример #10
0
def _get_list_and_batched_graph():
    """Returns a list of individual graphs and a batched version.

    Batching the list elements (shifting sender/receiver indices by the node
    counts of the preceding graphs) reproduces the batched graph exactly.

    The test case covers these corner cases:
      - single node,
      - multiple nodes,
      - no nodes,
      - no edges,
      - single edge,
      - multiple edges,
      - a ``GraphsTuple`` holding no graphs at all (last list element).
    """
    batched_graph = graph.GraphsTuple(
        n_node=jnp.array([1, 3, 1, 0, 2, 0, 0]),
        n_edge=jnp.array([2, 5, 0, 0, 1, 0, 0]),
        nodes=_make_nest(jnp.arange(14).reshape(7, 2)),
        edges=_make_nest(jnp.arange(24).reshape(8, 3)),
        globals=_make_nest(jnp.arange(14).reshape(7, 2)),
        senders=jnp.array([0, 0, 1, 1, 2, 3, 3, 6]),
        receivers=jnp.array([0, 0, 2, 1, 3, 2, 1, 5]))

    def make_graph(n_node, n_edge, nodes, edges, globals_, senders,
                   receivers):
        # Wraps count/index lists in arrays and feature arrays in a nest.
        return graph.GraphsTuple(n_node=jnp.array(n_node),
                                 n_edge=jnp.array(n_edge),
                                 nodes=_make_nest(nodes),
                                 edges=_make_nest(edges),
                                 globals=_make_nest(globals_),
                                 senders=jnp.array(senders),
                                 receivers=jnp.array(receivers))

    list_graphs = [
        make_graph([1], [2],
                   jnp.array([[0, 1]]),
                   jnp.array([[0, 1, 2], [3, 4, 5]]),
                   jnp.array([[0, 1]]),
                   [0, 0], [0, 0]),
        make_graph([3], [5],
                   jnp.array([[2, 3], [4, 5], [6, 7]]),
                   jnp.array([[6, 7, 8], [9, 10, 11], [12, 13, 14],
                              [15, 16, 17], [18, 19, 20]]),
                   jnp.array([[2, 3]]),
                   [0, 0, 1, 2, 2], [1, 0, 2, 1, 0]),
        make_graph([1], [0],
                   jnp.array([[8, 9]]),
                   jnp.zeros((0, 3)),
                   jnp.array([[4, 5]]),
                   [], []),
        make_graph([0], [0],
                   jnp.zeros((0, 2)),
                   jnp.zeros((0, 3)),
                   jnp.array([[6, 7]]),
                   [], []),
        make_graph([2], [1],
                   jnp.array([[10, 11], [12, 13]]),
                   jnp.array([[21, 22, 23]]),
                   jnp.array([[8, 9]]),
                   [1], [0]),
        make_graph([0], [0],
                   jnp.zeros((0, 2)),
                   jnp.zeros((0, 3)),
                   jnp.array([[10, 11]]),
                   [], []),
        make_graph([0], [0],
                   jnp.zeros((0, 2)),
                   jnp.zeros((0, 3)),
                   jnp.array([[12, 13]]),
                   [], []),
        # A GraphsTuple that contains zero graphs.
        make_graph([], [],
                   jnp.zeros((0, 2)),
                   jnp.zeros((0, 3)),
                   jnp.zeros((0, 2)),
                   [], []),
    ]

    return list_graphs, batched_graph
Пример #11
0
def pad_with_graphs(graph: gn_graph.GraphsTuple,
                    n_node: int,
                    n_edge: int,
                    n_graph: int = 2) -> gn_graph.GraphsTuple:
    """Pads a ``GraphsTuple`` to size by adding computation preserving graphs.

    The ``GraphsTuple`` is padded by first adding a dummy graph which contains
    the padding nodes and edges, and then empty graphs without nodes or edges.

    The empty graphs and the dummy graph do not interfere with the graphnet
    calculations on the original graph, and so are computation preserving.

    The padding graph requires at least one node and one graph.

    This function does not support jax.jit, because the shape of the output
    is data-dependent.

    Args:
      graph: ``GraphsTuple`` to pad.
      n_node: the number of nodes in the padded ``GraphsTuple``.
      n_edge: the number of edges in the padded ``GraphsTuple``.
      n_graph: the number of graphs in the padded ``GraphsTuple``. Default is 2,
        which is the lowest possible value, because we always have at least one
        graph in the original ``GraphsTuple`` and we need one dummy graph for the
        padding.

    Raises:
      ValueError: if the passed ``n_graph`` is smaller than 2.
      RuntimeError: if the given ``GraphsTuple`` is too large for the given
        padding, i.e. it leaves no room for at least one padding node and one
        padding graph, or it already has more than ``n_edge`` edges.

    Returns:
      A padded ``GraphsTuple`` made of host numpy arrays. Padding features are
      zeros, and every padding edge connects the dummy graph's first node to
      itself.
    """
    if n_graph < 2:
        raise ValueError(
            f'n_graph is {n_graph}, which is smaller than minimum value of 2.')
    graph = jax.device_get(graph)
    pad_n_node = int(n_node - np.sum(graph.n_node))
    pad_n_edge = int(n_edge - np.sum(graph.n_edge))
    pad_n_graph = int(n_graph - graph.n_node.shape[0])
    if pad_n_node <= 0 or pad_n_edge < 0 or pad_n_graph <= 0:
        raise RuntimeError(
            'Given graph is too large for the given padding. difference: '
            f'n_node {pad_n_node}, n_edge {pad_n_edge}, n_graph {pad_n_graph}')

    pad_n_empty_graph = pad_n_graph - 1

    def _zero_rows(n_rows):
        """Returns a leaf -> zeros function with ``n_rows`` rows of that leaf."""
        return lambda leaf: np.zeros((n_rows,) + leaf.shape[1:],
                                     dtype=leaf.dtype)

    # Counts list the dummy graph first, then the empty graphs. All padding
    # counts and indices use int32 explicitly; np.array([int]) would pick the
    # platform-dependent default integer instead.
    padding_graph = gn_graph.GraphsTuple(
        n_node=np.concatenate([
            np.array([pad_n_node], dtype=np.int32),
            np.zeros(pad_n_empty_graph, dtype=np.int32)
        ]),
        n_edge=np.concatenate([
            np.array([pad_n_edge], dtype=np.int32),
            np.zeros(pad_n_empty_graph, dtype=np.int32)
        ]),
        nodes=tree.tree_map(_zero_rows(pad_n_node), graph.nodes),
        edges=tree.tree_map(_zero_rows(pad_n_edge), graph.edges),
        globals=tree.tree_map(_zero_rows(pad_n_graph), graph.globals),
        # Index 0 of the padding graph; _batch shifts it to the first padding
        # node of the combined graph.
        senders=np.zeros(pad_n_edge, dtype=np.int32),
        receivers=np.zeros(pad_n_edge, dtype=np.int32),
    )
    return _batch([graph, padding_graph], np_=np)
Пример #12
0
    def _ApplyGraphNet(graph):
        """Applies a configured GraphNetwork to a graph.

        This implementation follows Algorithm 1 in
        https://arxiv.org/abs/1806.01261

        There is one difference. For the nodes update the class aggregates over
        the sender edges and receiver edges separately. This is a bit more
        general than the algorithm described in the paper. The original
        behaviour can be recovered by using only the receiver edge aggregations
        for the update.

        In addition this implementation supports softmax attention over incoming
        edge features.

        Many popular Graph Neural Networks can be implemented as special cases
        of GraphNets; for more information please see the paper.

        The update, attention and aggregation callables (``update_edge_fn``,
        ``attention_logit_fn``, ``update_node_fn``, ``update_global_fn``, ...)
        come from the enclosing scope. Each optional step below is skipped when
        its function is falsy (e.g. None).

        Args:
          graph: a `GraphsTuple` containing the graph.

        Returns:
          Updated `GraphsTuple`.

        Raises:
          ValueError: if the leaves of ``graph.nodes`` disagree on the number
            of nodes.
        """
        # pylint: disable=g-long-lambda
        # Positional unpacking: relies on the GraphsTuple field order.
        nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
        # Equivalent to jnp.sum(n_node), but jittable
        # NOTE(review): assumes `nodes` has at least one leaf; with nodes=None
        # the [0] index would raise IndexError.
        sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
        sum_n_edge = senders.shape[0]
        if not tree.tree_all(
                tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
            raise ValueError(
                'All node arrays in nest must contain the same number of nodes.'
            )

        # Gather the features of each edge's sender and receiver node.
        sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
        received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
        # Here we broadcast (repeat) each graph's global features to its edges,
        # giving us tensors of shape [num_edges, global_feat].
        global_edge_attributes = tree.tree_map(
            lambda g: jnp.repeat(
                g, n_edge, axis=0, total_repeat_length=sum_n_edge), globals_)

        if update_edge_fn:
            edges = update_edge_fn(edges, sent_attributes, received_attributes,
                                   global_edge_attributes)

        if attention_logit_fn:
            logits = attention_logit_fn(edges, sent_attributes,
                                        received_attributes,
                                        global_edge_attributes)
            # Attention weights are normalized over the edges that share a
            # receiver node (segment_ids=receivers).
            tree_calculate_weights = functools.partial(attention_normalize_fn,
                                                       segment_ids=receivers,
                                                       num_segments=sum_n_node)
            weights = tree.tree_map(tree_calculate_weights, logits)
            edges = attention_reduce_fn(edges, weights)

        if update_node_fn:
            # Reuses the names above, now for per-node aggregates of the
            # (possibly updated) edges, grouped by sender and by receiver.
            sent_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(e, senders, sum_n_node),
                edges)
            received_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(e, receivers, sum_n_node
                                                       ), edges)
            # Here we broadcast (repeat) each graph's global features to its
            # nodes, giving us tensors of shape [num_nodes, global_feat].
            global_attributes = tree.tree_map(
                lambda g: jnp.repeat(
                    g, n_node, axis=0, total_repeat_length=sum_n_node),
                globals_)
            nodes = update_node_fn(nodes, sent_attributes, received_attributes,
                                   global_attributes)

        if update_global_fn:
            n_graph = n_node.shape[0]
            graph_idx = jnp.arange(n_graph)
            # To aggregate nodes and edges from each graph to global features,
            # we first construct tensors that map the node to the corresponding graph.
            # For example, if you have `n_node=[1,2]`, we construct the tensor
            # [0, 1, 1]. We then do the same for edges.
            node_gr_idx = jnp.repeat(graph_idx,
                                     n_node,
                                     axis=0,
                                     total_repeat_length=sum_n_node)
            edge_gr_idx = jnp.repeat(graph_idx,
                                     n_edge,
                                     axis=0,
                                     total_repeat_length=sum_n_edge)
            # We use the aggregation function to pool the nodes/edges per graph.
            node_attributes = tree.tree_map(
                lambda n: aggregate_nodes_for_globals_fn(
                    n, node_gr_idx, n_graph), nodes)
            # `edge_attribtutes` (sic): per-graph pooled edge features.
            edge_attribtutes = tree.tree_map(
                lambda e: aggregate_edges_for_globals_fn(
                    e, edge_gr_idx, n_graph), edges)
            # The pooled node and edge features, together with the current
            # globals, are the inputs to the global update fn.
            globals_ = update_global_fn(node_attributes, edge_attribtutes,
                                        globals_)
        # pylint: enable=g-long-lambda
        return gn_graph.GraphsTuple(nodes=nodes,
                                    edges=edges,
                                    receivers=receivers,
                                    senders=senders,
                                    globals=globals_,
                                    n_node=n_node,
                                    n_edge=n_edge)