def test_pad_graphs_by_device(self):
  """Each device batch is padded to 8 nodes, 4 edges and 3 graphs."""
  # Two device batches, each already holding two graphs. The largest batch has
  # 5 nodes and 3 edges, so every batch is padded to 8 nodes and 4 edges and
  # gains one placeholder graph (2 graphs -> 3 graphs).
  first_device = jraph.GraphsTuple(
      nodes=np.arange(5)[:, None],
      edges=np.arange(3)[:, None],
      senders=np.array([0, 1, 4]),
      receivers=np.array([1, 0, 2]),
      n_node=np.array([2, 3]),
      n_edge=np.array([2, 1]),
      globals=None)
  second_device = jraph.GraphsTuple(
      nodes=np.arange(4)[:, None],
      edges=np.arange(1)[:, None],
      senders=np.array([1]),
      receivers=np.array([0]),
      n_node=np.array([2, 2]),
      n_edge=np.array([1, 0]),
      globals=None)

  padded = gn.pad_graphs_by_device([first_device, second_device])

  # Padding features are zeros. Padding edges point at the first padding node
  # of their own device batch (5 for the first, 4 for the second): indices are
  # not offset across devices.
  expected = {
      'nodes': np.array([0, 1, 2, 3, 4, 0, 0, 0,
                         0, 1, 2, 3, 0, 0, 0, 0])[:, None],
      'edges': np.array([0, 1, 2, 0,
                         0, 0, 0, 0])[:, None],
      'senders': np.array([0, 1, 4, 5,
                           1, 4, 4, 4]),
      'receivers': np.array([1, 0, 2, 5,
                             0, 4, 4, 4]),
      'n_node': np.array([2, 3, 3,
                          2, 2, 4]),
      'n_edge': np.array([2, 1, 1,
                          1, 0, 3]),
  }
  for field, want in expected.items():
    np.testing.assert_array_equal(getattr(padded, field), want)
 def test_batch_graphs_by_device(self):
     # batch 4 graphs for 2 devices
     num_devices = 2
     graphs = [
         jraph.GraphsTuple(nodes=np.arange(2)[:, None],
                           edges=np.arange(2)[:, None],
                           senders=np.array([0, 1]),
                           receivers=np.array([1, 0]),
                           n_node=np.array([2]),
                           n_edge=np.array([2]),
                           globals=None),
         jraph.GraphsTuple(nodes=np.arange(3)[:, None],
                           edges=np.arange(1)[:, None],
                           senders=np.array([2]),
                           receivers=np.array([0]),
                           n_node=np.array([3]),
                           n_edge=np.array([1]),
                           globals=None),
         jraph.GraphsTuple(nodes=np.arange(4)[:, None],
                           edges=np.arange(2)[:, None],
                           senders=np.array([1, 0]),
                           receivers=np.array([2, 3]),
                           n_node=np.array([4]),
                           n_edge=np.array([2]),
                           globals=None),
         jraph.GraphsTuple(nodes=np.arange(5)[:, None],
                           edges=np.arange(3)[:, None],
                           senders=np.array([2, 1, 3]),
                           receivers=np.array([1, 4, 0]),
                           n_node=np.array([5]),
                           n_edge=np.array([3]),
                           globals=None),
     ]
     batched = gn.batch_graphs_by_device(graphs, num_devices)
     self.assertLen(batched, num_devices)
     np.testing.assert_array_equal(batched[0].nodes,
                                   np.array([0, 1, 0, 1, 2])[:, None])
     np.testing.assert_array_equal(batched[0].edges,
                                   np.array([0, 1, 0])[:, None])
     np.testing.assert_array_equal(batched[0].senders, np.array([0, 1, 4]))
     np.testing.assert_array_equal(batched[0].receivers, np.array([1, 0,
                                                                   2]))
     np.testing.assert_array_equal(batched[0].n_node, np.array([2, 3]))
     np.testing.assert_array_equal(batched[0].n_edge, np.array([2, 1]))
     np.testing.assert_array_equal(
         batched[1].nodes,
         np.array([0, 1, 2, 3, 0, 1, 2, 3, 4])[:, None])
     np.testing.assert_array_equal(batched[1].edges,
                                   np.array([0, 1, 0, 1, 2])[:, None])
     np.testing.assert_array_equal(batched[1].senders,
                                   np.array([1, 0, 6, 5, 7]))
     np.testing.assert_array_equal(batched[1].receivers,
                                   np.array([2, 3, 5, 8, 4]))
     np.testing.assert_array_equal(batched[1].n_node, np.array([4, 5]))
     np.testing.assert_array_equal(batched[1].n_edge, np.array([2, 3]))
Exemplo n.º 3
0
def conway_graph(size) -> jraph.GraphsTuple:
    """Returns a graph representing the game field of Conway's game of life.

    The field is a ``size`` x ``size`` grid of cells stored row-major
    (cell ``r * size + c``). Boundaries are toroidal: rows and columns wrap
    around independently, so a cell in the first column neighbours the last
    column of the *same* row.

    Args:
      size: side length of the square field. Must be at least 3, because the
        initial glider occupies a 3x3 corner.

    Returns:
      A single-graph ``GraphsTuple`` with one node per cell and eight incoming
      edges per cell, one from each neighbour. Node features are 1.0 for the
      live cells of a glider in the top-left corner and 0.0 elsewhere. Edge
      features are all zero.
    """
    # Creates nodes: each node represents a cell in the game.
    n_node = size**2
    nodes = np.zeros((n_node, 1))
    node_indices = jnp.arange(n_node)
    rows = node_indices // size
    cols = node_indices % size
    # Relative (row, column) positions of the 8 neighbouring cells.
    neighbour_offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1),
                         (1, 0), (1, 1)]
    n_edge = len(neighbour_offsets) * n_node
    edges = jnp.zeros((n_edge, 1))
    # Wrap row and column separately. Wrapping the flat index instead would
    # send a border cell's left/right neighbour into the adjacent row.
    senders = jnp.stack([((rows + d_row) % size) * size + (cols + d_col) % size
                         for d_row, d_col in neighbour_offsets],
                        axis=1)
    # Row i of `senders` lists the neighbours of cell i. Flatten it so the 8
    # incoming edges of each receiver are contiguous, matching `receivers`.
    senders = senders.reshape(-1)
    receivers = jnp.repeat(node_indices, len(neighbour_offsets))
    # Adds a glider to the game.
    nodes[0, 0] = 1.0
    nodes[1, 0] = 1.0
    nodes[2, 0] = 1.0
    nodes[2 + size, 0] = 1.0
    nodes[1 + 2 * size, 0] = 1.0
    return jraph.GraphsTuple(n_node=jnp.array([n_node]),
                             n_edge=jnp.array([n_edge]),
                             nodes=jnp.asarray(nodes),
                             edges=edges,
                             globals=None,
                             senders=senders,
                             receivers=receivers)
Exemplo n.º 4
0
def pad_graphs(graphs: jraph.GraphsTuple,
               pad_n_nodes: Optional[int] = None,
               pad_n_edges: Optional[int] = None) -> jraph.GraphsTuple:
    """Pad graphs to have a canonical number of nodes and edges.

    The padding is appended as a single placeholder graph at the end of the
    batch, so the number of graphs grows by exactly one.

    By default, the target sizes come from ``pad_size``. The node target is
    computed from ``n_nodes + 1``, so the placeholder graph always owns at
    least one node. Explicit targets may be passed instead.

    Padding edges connect node index ``n_nodes`` (the placeholder's first
    node) to itself. They are therefore only valid when the placeholder graph
    has at least one node.

    Args:
      graphs: a batch of graphs with 2-D ``nodes`` and ``edges`` arrays.
      pad_n_nodes: (optional) total number of nodes after padding.
      pad_n_edges: (optional) total number of edges after padding.

    Returns:
      padded: the input batch padded to the requested sizes. ``globals`` is
      set to None.

    Raises:
      ValueError: if a target is smaller than the current size, or if padding
        edges are requested while the placeholder graph would have no node
        for them to point at.
    """
    n_nodes, node_dim = graphs.nodes.shape
    n_edges, edge_dim = graphs.edges.shape
    # Add at least one extra node to the placeholder graph.
    if pad_n_nodes is None:
        pad_n_nodes = pad_size(n_nodes + 1)
    if pad_n_edges is None:
        pad_n_edges = pad_size(n_edges)

    extra_nodes = pad_n_nodes - n_nodes
    extra_edges = pad_n_edges - n_edges
    if extra_nodes < 0 or extra_edges < 0:
        raise ValueError(
            f'Cannot pad {n_nodes} nodes / {n_edges} edges to '
            f'{pad_n_nodes} nodes / {pad_n_edges} edges.')
    if extra_edges > 0 and extra_nodes == 0:
        # Padding edges would reference node index n_nodes, which would not
        # exist in the padded output.
        raise ValueError(
            'pad_n_nodes must exceed the current number of nodes when padding '
            'edges are added, so that padding edges have a valid endpoint.')

    nodes = np.concatenate([
        graphs.nodes,
        np.zeros((extra_nodes, node_dim), dtype=graphs.nodes.dtype)
    ],
                           axis=0)
    edges = np.concatenate([
        graphs.edges,
        np.zeros((extra_edges, edge_dim), dtype=graphs.edges.dtype)
    ],
                           axis=0)
    # Add padding edges: all of them are self-loops on the first padding node.
    senders = np.concatenate([
        graphs.senders,
        np.full(extra_edges, n_nodes, dtype=graphs.senders.dtype)
    ],
                             axis=0)
    receivers = np.concatenate([
        graphs.receivers,
        np.full(extra_edges, n_nodes, dtype=graphs.receivers.dtype)
    ],
                               axis=0)
    # Keep the count arrays' dtype instead of promoting to numpy's default int.
    n_node = np.concatenate(
        [graphs.n_node,
         np.full(1, extra_nodes, dtype=graphs.n_node.dtype)], axis=0)
    n_edge = np.concatenate(
        [graphs.n_edge,
         np.full(1, extra_edges, dtype=graphs.n_edge.dtype)], axis=0)
    return jraph.GraphsTuple(nodes=nodes,
                             edges=edges,
                             senders=senders,
                             receivers=receivers,
                             n_node=n_node,
                             n_edge=n_edge,
                             globals=None)
Exemplo n.º 5
0
def get_voting_problem(min_n_voters: int, max_n_voters: int) -> Problem:
    """Creates set of one-hot vectors representing a randomly generated election.

    Each voter becomes an edgeless node whose feature is a one-hot vote for
    one of 20 candidates. The graph is padded with ``jraph.pad_with_graphs``
    to ``max_n_voters + 1`` nodes (and zero edges), so every problem has the
    same static shape for jit compilation.

    Args:
      min_n_voters: minimum number of voters in the election.
      max_n_voters: maximum number of voters in the election.

    Returns:
      A ``Problem`` whose graph holds the votes and whose labels are the
      one-hot winner. Ties go to the lowest candidate index (``np.argmax``).
    """
    n_candidates = 20
    n_voters = random.randint(min_n_voters, max_n_voters)
    votes = np.random.randint(0, n_candidates, size=(n_voters, ))
    one_hot = np.eye(n_candidates)
    one_hot_votes = one_hot[votes]
    vote_counts = one_hot_votes.sum(axis=0)
    one_hot_winner = one_hot[np.argmax(vote_counts)]

    padded_graph = jraph.pad_with_graphs(
        jraph.GraphsTuple(
            n_node=np.asarray([n_voters]),
            n_edge=np.asarray([0]),
            nodes=one_hot_votes,
            edges=None,
            globals=np.zeros((1, n_candidates)),
            # The election graph has no edges at all.
            senders=np.array([], dtype=np.int32),
            receivers=np.array([], dtype=np.int32)),
        max_n_voters + 1, 0)
    return Problem(graph=padded_graph, labels=one_hot_winner)
Exemplo n.º 6
0
  def test_graph_conditioned_transformer_learns(self):
    """Graph2TextTransformer fits two short sequences conditioned on graphs.

    Trains for 500 Adam steps on a tiny fixed batch and checks that the final
    loss drops below 1.0.
    """
    # Two graphs with 2 nodes each: graph 0 has edge 0->1, graph 1 has edges
    # 2->3 and 3->2 (indices are global across the batch).
    graphs = jraph.GraphsTuple(
        nodes=np.ones((4, 3), dtype=np.float32),
        edges=np.ones((3, 1), dtype=np.float32),
        senders=np.array([0, 2, 3], dtype=np.int32),
        receivers=np.array([1, 3, 2], dtype=np.int32),
        n_node=np.array([2, 2], dtype=np.int32),
        n_edge=np.array([1, 2], dtype=np.int32),
        globals=None,
        )
    # One token sequence per graph. Label 0 is excluded from the loss by the
    # mask built in `forward` (presumably 0 is padding).
    seqs = np.array([[1, 2, 2, 0],
                     [1, 3, 3, 3]], dtype=np.int32)
    vocab_size = seqs.max() + 1
    embed_dim = 8
    max_graph_size = graphs.n_node.max()

    logging.info('Training seqs: %r', seqs)

    # Next-token prediction: inputs are all but the last token, labels are
    # the sequence shifted left by one.
    x = seqs[:, :-1]
    y = seqs[:, 1:]

    def model_fn(vocab_size, embed_dim):
      return models.Graph2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=embed_dim,
          num_layers=2,
          num_heads=4,
          cutoffs=[],
          gnn_embed_dim=embed_dim,
          gnn_num_layers=2)

    def forward(graphs, inputs, labels, max_graph_size):
      input_mask = (labels != 0).astype(jnp.float32)
      # NOTE(review): the positional `False` presumably is an is_training
      # flag (e.g. disabling dropout) -- confirm against
      # Graph2TextTransformer.loss.
      return model_fn(vocab_size, embed_dim).loss(
          graphs, max_graph_size, False, inputs, labels, mask=input_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    rng = hk.PRNGSequence(8)
    params, state = init_fn(next(rng), graphs, x, y, max_graph_size)

    # Adapts apply_fn for jax.value_and_grad(has_aux=True): the first output
    # element is differentiated (the loss); the second element (metrics) and
    # the updated haiku state travel as auxiliary data.
    def apply(*args, **kwargs):
      out, state = apply_fn(*args, **kwargs)
      return out[0], (out[1], state)
    # Positional arg 6 is max_graph_size (see the call below); it is marked
    # static because it is a plain size, not a traced array.
    apply = jax.jit(apply, static_argnums=6)

    # Adam with a fixed step size of 1e-3 (negative scale => gradient descent).
    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1e-3))
    opt_state = optimizer.init(params)
    for i in range(500):
      (loss, model_state), grad = jax.value_and_grad(apply, has_aux=True)(
          params, state, next(rng), graphs, x, y, max_graph_size)
      metrics, state = model_state
      updates, opt_state = optimizer.update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)
      if (i + 1) % 100 == 0:
        logging.info(
            'Step %d, %r', i + 1, {k: float(v) for k, v in metrics.items()})
    logging.info('Loss: %.8f', loss)
    # `loss` is the value from the last training step.
    self.assertLess(loss, 1.0)
Exemplo n.º 7
0
  def test_graph_embedding_model_runs(self):
    """GraphEmbeddingModel maps features to `embed_dim` and keeps topology."""
    graph = jraph.GraphsTuple(
        nodes=np.array([[0, 1, 1],
                        [1, 2, 0],
                        [0, 3, 0],
                        [0, 4, 4]], dtype=np.float32),
        edges=np.array([[1, 1],
                        [2, 2],
                        [3, 3]], dtype=np.float32),
        senders=np.array([0, 1, 2], dtype=np.int32),
        receivers=np.array([1, 2, 3], dtype=np.int32),
        n_node=np.array([4], dtype=np.int32),
        n_edge=np.array([3], dtype=np.int32),
        globals=None)
    embed_dim = 3

    def forward(graph):
      # Use the same `embed_dim` the shape assertions below check, so the two
      # cannot drift apart.
      return embedding.GraphEmbeddingModel(
          embed_dim=embed_dim, num_layers=2)(graph)

    init_fn, apply_fn = hk.without_apply_rng(hk.transform(forward))
    key = hk.PRNGSequence(8)
    params = init_fn(next(key), graph)
    out = apply_fn(params, graph)

    # Features are projected to embed_dim; graph structure passes through.
    self.assertEqual(out.nodes.shape, (graph.nodes.shape[0], embed_dim))
    self.assertEqual(out.edges.shape, (graph.edges.shape[0], embed_dim))
    np.testing.assert_array_equal(out.senders, graph.senders)
    np.testing.assert_array_equal(out.receivers, graph.receivers)
    np.testing.assert_array_equal(out.n_node, graph.n_node)
Exemplo n.º 8
0
def get_zacharys_karate_club() -> jraph.GraphsTuple:
    """Returns GraphsTuple representing Zachary's karate club.

    The 34 members are nodes with one-hot features. Each friendship is stored
    as two directed edges, one per direction. There are no edge or global
    features.
    """
    friendships = [
        (1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2), (4, 0),
        (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1), (7, 2),
        (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4), (10, 5),
        (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2),
        (13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0),
        (19, 1), (21, 0), (21, 1), (25, 23), (25, 24), (27, 2),
        (27, 23), (27, 24), (28, 2), (29, 23), (29, 26), (30, 1),
        (30, 8), (31, 0), (31, 24), (31, 25), (31, 28), (32, 2),
        (32, 8), (32, 14), (32, 15), (32, 18), (32, 20), (32, 22),
        (32, 23), (32, 29), (32, 30), (32, 31), (33, 8), (33, 9),
        (33, 13), (33, 14), (33, 15), (33, 18), (33, 19), (33, 20),
        (33, 22), (33, 23), (33, 26), (33, 27), (33, 28), (33, 29),
        (33, 30), (33, 31), (33, 32),
    ]
    # Make the graph undirected by appending every edge reversed.
    directed_edges = friendships + [(dst, src) for src, dst in friendships]
    n_club_members = 34
    senders, receivers = zip(*directed_edges)

    return jraph.GraphsTuple(
        n_node=jnp.asarray([n_club_members]),
        n_edge=jnp.asarray([len(directed_edges)]),
        nodes=jnp.eye(n_club_members),
        edges=None,
        globals=None,
        senders=jnp.asarray(senders),
        receivers=jnp.asarray(receivers))
Exemplo n.º 9
0
def get_dummy_graph(add_self_loops, symmetrize_edges, adjacency_normalization):
    """Returns a small dummy GraphsTuple.

    The base graph has 3 nodes and edges 0 -> 1 and 2 -> 1.

    - If ``symmetrize_edges`` is set, every edge is mirrored.
    - If ``add_self_loops`` is set, each node gets a self-loop (added after
      symmetrization).

    All edge features start at 1. The graph is then passed through
    ``normalizations.normalize_edges_with_mask`` with no mask and the given
    ``adjacency_normalization``.
    """
    num_nodes = 3
    node_features = np.array([[2.], [1.], [1.]], dtype=np.float32)
    senders = np.array([0, 2])
    receivers = np.array([1, 1])

    if symmetrize_edges:
        senders, receivers = (np.concatenate([senders, receivers], axis=0),
                              np.concatenate([receivers, senders], axis=0))

    if add_self_loops:
        loops = np.arange(num_nodes)
        senders = np.concatenate([senders, loops], axis=0)
        receivers = np.concatenate([receivers, loops], axis=0)

    num_edges = len(senders)
    graph = jraph.GraphsTuple(
        n_node=np.asarray([num_nodes]),
        n_edge=np.asarray([num_edges]),
        senders=senders,
        receivers=receivers,
        nodes=node_features,
        edges=np.ones((num_edges, 1)),
        globals=np.zeros((1, 1)),
    )
    return normalizations.normalize_edges_with_mask(
        graph, mask=None, adjacency_normalization=adjacency_normalization)
Exemplo n.º 10
0
def get_2sat_problem(min_n_literals: int, max_n_literals: int) -> Problem:
  """Creates bipartite-graph representing a randomly generated 2-sat problem.

  Literal nodes (indices ``0 .. n_literals - 1``) come first, followed by one
  constraint node per unordered pair of literals. Each constraint node
  receives one edge from each of its two literals.

  In the planted solution, the first ``n_literals_true`` literals are true.
  Each edge feature records the value its literal takes in that solution, so
  every constraint is satisfied by it.

  Args:
    min_n_literals: minimum number of literals in the 2-sat problem. Must be
      at least 2, so that at least one literal is true and one is false.
    max_n_literals: maximum number of literals in the 2-sat problem. Must be
      >= min_n_literals.

  Returns:
    A ``Problem`` holding the bipartite graph (padded to a static size derived
    from ``max_n_literals``), the one-hot node labels of the planted solution,
    and a node mask that selects only the real literal nodes.

  Raises:
    ValueError: if ``min_n_literals < 2`` or ``max_n_literals < min_n_literals``.
  """
  if min_n_literals < 2 or max_n_literals < min_n_literals:
    raise ValueError(
        'Expected 2 <= min_n_literals <= max_n_literals, got '
        f'min_n_literals={min_n_literals}, max_n_literals={max_n_literals}.')
  n_literals = random.randint(min_n_literals, max_n_literals)
  # At least one literal is true and at least one is false.
  n_literals_true = random.randint(1, n_literals - 1)
  n_constraints = n_literals * (n_literals - 1) // 2

  n_node = n_literals + n_constraints
  # 0 indicates a literal node
  # 1 indicates a constraint node.
  nodes = [0 if i < n_literals else 1 for i in range(n_node)]
  edges = []
  senders = []
  for literal_node1 in range(n_literals):
    for literal_node2 in range(literal_node1 + 1, n_literals):
      senders.append(literal_node1)
      senders.append(literal_node2)
      # 1 indicates that the literal must be true for this constraint.
      # 0 indicates that the literal must be false for this constraint.
      # I.e. with literals a and b, we have the following possible constraints:
      # 0, 0 -> not a or not b
      # 1, 0 -> a or not b
      # 0, 1 -> not a or b
      # 1, 1 -> a or b
      # Edge values come from the planted solution, so it satisfies every
      # constraint.
      edges.append(1 if literal_node1 < n_literals_true else 0)
      edges.append(1 if literal_node2 < n_literals_true else 0)

  graph = jraph.GraphsTuple(
      n_node=np.asarray([n_node]),
      n_edge=np.asarray([2 * n_constraints]),
      # One-hot encoding for nodes and edges.
      nodes=np.eye(2)[nodes],
      edges=np.eye(2)[edges],
      globals=None,
      senders=np.asarray(senders),
      # Constraint node k (index n_literals + k) receives the two edges
      # appended for the k-th literal pair above.
      receivers=np.repeat(np.arange(n_constraints) + n_literals, 2))

  # In order to jit compile our code, we have to pad the nodes and edges of
  # the GraphsTuple to a static shape. The +1 node is for the padding graph.
  max_n_constraints = max_n_literals * (max_n_literals - 1) // 2
  max_nodes = max_n_literals + max_n_constraints + 1
  max_edges = 2 * max_n_constraints
  graph = jraph.pad_with_graphs(graph, max_nodes, max_edges)

  # The ground truth solution for the 2-sat problem.
  labels = (np.arange(max_nodes) < n_literals_true).astype(np.int32)
  labels = np.eye(2)[labels]

  # For the loss calculation we create a mask for the nodes, which masks the
  # constraint nodes and the padding nodes.
  mask = (np.arange(max_nodes) < n_literals).astype(np.int32)
  return Problem(graph=graph, labels=labels, mask=mask)
Exemplo n.º 11
0
  def _loss(
      self, **graph: Mapping[str, chex.ArrayTree]) -> chex.ArrayTree:
    """Builds the model and returns ``(loss, scalars)`` for one graph batch.

    The keyword arguments are the fields of a ``jraph.GraphsTuple`` and are
    reassembled into one before the loss is computed.
    """
    batch = jraph.GraphsTuple(**graph)
    network = model.GraphPropertyEncodeProcessDecode(
        loss_config=self._construct_loss_config(), **self.config.model)
    # Unpack explicitly so the result is always a 2-tuple.
    loss, scalars = network.get_loss(batch)
    return loss, scalars
Exemplo n.º 12
0
 def _forward(self, graph: jraph.GraphsTuple, is_training: bool):
     """Runs encoder, processor and decoder, each in its own name scope.

     The decoder receives both the processed graph and a shallow copy of the
     input graph (a fresh GraphsTuple built from the input's fields before
     encoding).
     """
     original_graph = jraph.GraphsTuple(*graph)
     # Attributes are looked up inside their scope, as each stage is applied.
     for scope_name, stage_attr in (("encoder_scope", "_encoder"),
                                    ("processor_scope", "_processor")):
         with hk.experimental.name_scope(scope_name):
             graph = getattr(self, stage_attr)(graph, is_training)
     with hk.experimental.name_scope("decoder_scope"):
         return self._decoder(graph, original_graph, is_training)
Exemplo n.º 13
0
def get_random_graph() -> jraph.GraphsTuple:
    """Returns one random graph with NUM_NODES nodes and NUM_EDGES edges.

    Node and edge features are standard-normal with EMBEDDING_SIZE columns.
    Endpoints are drawn uniformly from the node range, so self-loops and
    duplicate edges can occur. There are no global features.
    """
    # Draw order (node features, edge features, senders, receivers) fixes the
    # result for a given numpy seed.
    node_features = np.random.normal(size=[NUM_NODES, EMBEDDING_SIZE])
    edge_features = np.random.normal(size=[NUM_EDGES, EMBEDDING_SIZE])
    senders = np.random.randint(0, NUM_NODES, [NUM_EDGES])
    receivers = np.random.randint(0, NUM_NODES, [NUM_EDGES])
    return jraph.GraphsTuple(
        n_node=np.asarray([NUM_NODES]),
        n_edge=np.asarray([NUM_EDGES]),
        nodes=node_features,
        edges=edge_features,
        globals=None,
        senders=senders,
        receivers=receivers)
Exemplo n.º 14
0
def convert_to_graphstuple(graph):
    """Converts a dataset to one entire jraph.GraphsTuple, extracting labels.

    Args:
      graph: dataset object exposing ``node_features``, ``senders``,
        ``receivers``, ``node_labels``, ``num_nodes()`` and ``num_edges()``.

    Returns:
      A ``(graphs_tuple, labels)`` pair. Every edge has feature 1, and the
      graph has a single zero global.
    """
    nodes = np.asarray(graph.node_features)
    edge_features = np.ones_like(graph.senders)
    senders = np.asarray(graph.senders)
    receivers = np.asarray(graph.receivers)
    n_node = np.asarray([graph.num_nodes()])
    n_edge = np.asarray([graph.num_edges()])
    graphs_tuple = jraph.GraphsTuple(
        nodes=nodes,
        edges=edge_features,
        senders=senders,
        receivers=receivers,
        globals=np.zeros(1),
        n_node=n_node,
        n_edge=n_edge,
    )
    return graphs_tuple, np.asarray(graph.node_labels)
Exemplo n.º 15
0
 def get_fake_batch(self, hps):
   """Returns a one-node, one-edge placeholder batch for model init.

   Feature shapes come from ``hps.input_node_shape``,
   ``hps.input_edge_shape`` and ``hps.output_shape`` (the last one is used
   for globals). The single edge is a self-loop on node 0.
   """
   assert 'input_node_shape' in hps and 'input_edge_shape' in hps
   one = (1,)
   fake_graph = jraph.GraphsTuple(
       n_node=jnp.asarray([1]),
       n_edge=jnp.asarray([1]),
       nodes=jnp.ones(one + hps.input_node_shape),
       edges=jnp.ones(one + hps.input_edge_shape),
       globals=jnp.zeros(one + hps.output_shape),
       senders=jnp.asarray([0]),
       receivers=jnp.asarray([0]))
   # Returned inside a list so callers can splat it as *inputs into the
   # model init function.
   return [fake_graph]
Exemplo n.º 16
0
def _to_jraph(example, add_bidirectional_edges, add_virtual_node,
              add_self_loops):
    """Converts an example graph to jraph.GraphsTuple.

    Args:
      example: a dict-like example with ``edge_feat``, ``node_feat``,
        ``edge_index``, ``labels`` and ``num_nodes`` entries. It is converted
        to numpy via ``data_utils.tf_to_numpy``. ``edge_index`` is indexed as
        ``[:, 0]`` (senders) and ``[:, 1]`` (receivers).
      add_bidirectional_edges: if True, append a reversed copy of every edge,
        duplicating its features.
      add_self_loops: if True, append one zero-feature self-loop per original
        node.
      add_virtual_node: if True, append one zero-feature node plus a
        zero-feature edge from every original node into it.

    Returns:
      A single-graph GraphsTuple. The labels are stored in ``globals`` with a
      leading axis of size 1 so they are kept through batching.
    """
    example = data_utils.tf_to_numpy(example)
    edge_feat = example['edge_feat']
    node_feat = example['node_feat']
    edge_index = example['edge_index']
    labels = example['labels']
    num_nodes = np.squeeze(example['num_nodes'])
    num_edges = len(edge_index)

    senders = edge_index[:, 0]
    receivers = edge_index[:, 1]

    # np.concatenate always allocates, so the augmentations below never modify
    # the arrays taken from `example`. No defensive copy is needed.
    new_senders, new_receivers = senders, receivers

    if add_bidirectional_edges:
        new_senders = np.concatenate([senders, receivers])
        new_receivers = np.concatenate([receivers, senders])
        edge_feat = np.concatenate([edge_feat, edge_feat])
        num_edges *= 2

    if add_self_loops:
        new_senders = np.concatenate([new_senders, np.arange(num_nodes)])
        new_receivers = np.concatenate([new_receivers, np.arange(num_nodes)])
        # Match the existing edge-feature dtype. np.zeros defaults to float64,
        # which would silently upcast (or convert) the whole edge array.
        edge_feat = np.concatenate([
            edge_feat,
            np.zeros((num_nodes, edge_feat.shape[-1]), dtype=edge_feat.dtype)
        ])
        num_edges += num_nodes

    if add_virtual_node:
        # The virtual node takes index num_nodes and receives one edge from
        # every original node.
        node_feat = np.concatenate(
            [node_feat, np.zeros_like(node_feat[0, None])])
        new_senders = np.concatenate([new_senders, np.arange(num_nodes)])
        new_receivers = np.concatenate(
            [new_receivers, np.full((num_nodes, ), num_nodes)])
        edge_feat = np.concatenate([
            edge_feat,
            np.zeros((num_nodes, edge_feat.shape[-1]), dtype=edge_feat.dtype)
        ])
        num_edges += num_nodes
        num_nodes += 1

    return jraph.GraphsTuple(
        n_node=np.array([num_nodes]),
        n_edge=np.array([num_edges]),
        nodes=node_feat,
        edges=edge_feat,
        senders=new_senders,
        receivers=new_receivers,
        # Keep the labels with the graph for batching. They will be removed
        # in the processed batch.
        globals=np.expand_dims(labels, axis=0))
Exemplo n.º 17
0
def broadcasted_sharded_graphs_tuple_to_graphs_tuple(sharded_graphs_tuple):
    """Converts a broadcasted ShardedGraphsTuple to a GraphsTuple.

    Replicated fields carry a leading axis of repeated copies, so only the
    first copy is kept. ``device_edges`` instead has its two leading axes
    merged, which recovers all of the edge features.
    """
    def first_copy(subtree):
        return tree.tree_map(lambda leaf: leaf[0], subtree)

    def merge_leading_axes(leaf):
        merged = (leaf.shape[0] * leaf.shape[1], )
        return jnp.reshape(leaf, merged + leaf.shape[2:])

    # TODO(jonathangodwin): check senders and receivers are consistent.
    return jraph.GraphsTuple(
        nodes=first_copy(sharded_graphs_tuple.nodes),
        edges=tree.tree_map(merge_leading_axes,
                            sharded_graphs_tuple.device_edges),
        n_node=sharded_graphs_tuple.n_node[0],
        n_edge=sharded_graphs_tuple.n_edge[0],
        globals=first_copy(sharded_graphs_tuple.globals),
        senders=sharded_graphs_tuple.senders[0],
        receivers=sharded_graphs_tuple.receivers[0])
Exemplo n.º 18
0
def _get_graphs_from_n_edge(n_edge):
    """Get a graphs tuple from n_edge.

    Builds one random graph per entry of ``n_edge``:

    - 128 nodes, each with 2 uniform features.
    - ``n_edge[i]`` edges with 2 uniform features and uniformly chosen
      endpoints.
    - The edge count stored as a (1, 1) global.

    The graphs are then merged with ``utils.batch_np``.
    """
    num_nodes = 128

    def random_graph(num_edges):
        # Draw order (nodes, edges, senders, receivers) fixes the result for a
        # given numpy seed.
        return jraph.GraphsTuple(
            nodes=np.random.uniform(size=(num_nodes, 2)),
            edges=np.random.uniform(size=(num_edges, 2)),
            senders=np.random.choice(num_nodes, num_edges),
            receivers=np.random.choice(num_nodes, num_edges),
            n_edge=np.array([num_edges]),
            n_node=np.array([num_nodes]),
            globals=np.array([[num_edges]]),
        )

    return utils.batch_np([random_graph(el) for el in n_edge])
Exemplo n.º 19
0
def create_jraph(data_path, dataset):
    """Creates a jraph graph for a dataset.

    Returns:
      A ``(graph, labels, labels.shape[1])`` triple:

      - ``graph`` is a single GraphsTuple with node features only (no edge or
        global features).
      - ``labels`` is a numpy array.
      - ``labels.shape[1]`` is presumably the number of classes for one-hot
        labels (confirm).
    """
    adj, features, labels = load_from_npz(data_path, dataset)
    edges, n_edge = get_graph_edges(adj, np.array(features))
    n_node = len(features)
    node_features = jnp.asarray(features)
    sender_ids = [pair[0] for pair in edges]
    receiver_ids = [pair[1] for pair in edges]
    graph = jraph.GraphsTuple(n_node=jnp.asarray([n_node]),
                              n_edge=jnp.asarray([n_edge]),
                              nodes=node_features,
                              edges=None,
                              globals=None,
                              senders=jnp.asarray(sender_ids),
                              receivers=jnp.asarray(receiver_ids))
    return graph, np.asarray(labels), labels.shape[1]
Exemplo n.º 20
0
def _convert_ogb_graph_to_graphs_tuple(ogb_graph):
    """Converts an OGB Graph to a GraphsTuple.

    Every field that is None is replaced by the scalar ``np.array(0.)``. This
    always applies to ``globals``, and to any OGB entry that is None.
    """
    edge_index = ogb_graph["edge_index"]
    senders, receivers = edge_index[0], edge_index[1]
    graph = jraph.GraphsTuple(nodes=ogb_graph["node_feat"],
                              edges=ogb_graph["edge_feat"],
                              senders=senders,
                              receivers=receivers,
                              n_node=np.array([ogb_graph["num_nodes"]]),
                              n_edge=np.array([len(senders)]),
                              globals=None)
    return tree.map_structure(
        lambda leaf: np.array(0.) if leaf is None else leaf, graph)
Exemplo n.º 21
0
def pad_graphs_by_device(graphs: List[jraph.GraphsTuple]) -> jraph.GraphsTuple:
  """Pad and concatenate the list of graphs.

  Every graph in the list is padded (via ``pad_graphs``) to the same node and
  edge counts, derived from the largest graph in the list. The padded graphs
  are then concatenated field by field.

  Sender/receiver indices are not offset, so each per-graph slice stays
  self-contained. This is needed for pmap.

  Args:
    graphs: a list of graphs.

  Returns:
    graph: a single padded and merged graph.
  """
  # Add at least one extra node to the placeholder graph.
  pad_n_nodes = pad_size(max(g.nodes.shape[0] for g in graphs) + 1)
  pad_n_edges = pad_size(max(g.edges.shape[0] for g in graphs))
  padded_graphs = [pad_graphs(g, pad_n_nodes, pad_n_edges) for g in graphs]

  # Every slice must have identical shapes before concatenation.
  reference = padded_graphs[0]
  for g in padded_graphs:
    assert g.nodes.shape[0] == pad_n_nodes
    assert g.edges.shape[0] == pad_n_edges
    assert g.senders.size == pad_n_edges
    assert g.receivers.size == pad_n_edges
    assert g.n_node.size == reference.n_node.size
    assert g.n_edge.size == reference.n_edge.size

  def _concat(field):
    return np.concatenate([getattr(g, field) for g in padded_graphs], axis=0)

  return jraph.GraphsTuple(
      nodes=_concat('nodes'),
      edges=_concat('edges'),
      senders=_concat('senders'),
      receivers=_concat('receivers'),
      n_node=_concat('n_node'),
      n_edge=_concat('n_edge'),
      globals=None)
Exemplo n.º 22
0
def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
  """Batches a sequence of graphs into a single GraphsTuple using numpy.

  Count arrays and (possibly nested) node, edge and global features are
  concatenated along the leading axis. Each graph's sender and receiver
  indices are shifted by the total number of nodes in the graphs before it.
  """
  # Calculates offsets for sender and receiver arrays, caused by concatenating
  # the nodes arrays.
  offsets = np.cumsum(np.array([0] + [np.sum(g.n_node) for g in graphs[:-1]]))

  def _map_concat(nests):
    concat = lambda *args: np.concatenate(args)
    # NOTE: assumes `tree` is `jax.tree_util`. Its multi-tree `tree_map`
    # replaces `tree_multimap`, which was deprecated and later removed.
    # Fields that are None (e.g. globals) map to None.
    return tree.tree_map(concat, *nests)

  return jraph.GraphsTuple(
      n_node=np.concatenate([g.n_node for g in graphs]),
      n_edge=np.concatenate([g.n_edge for g in graphs]),
      nodes=_map_concat([g.nodes for g in graphs]),
      edges=_map_concat([g.edges for g in graphs]),
      globals=_map_concat([g.globals for g in graphs]),
      senders=np.concatenate([g.senders + o for g, o in zip(graphs, offsets)]),
      receivers=np.concatenate(
          [g.receivers + o for g, o in zip(graphs, offsets)]))
Exemplo n.º 23
0
def build_hookes_particle_state_graph(num_particles: int) -> jraph.GraphsTuple:
    """Generates a graph representing a Hooke's system in a random state.

    Nodes are particles with a random mass, 2d position and 2d momentum. The
    mean momentum is subtracted, so the centre of mass does not move.

    Every ordered pair of distinct particles is joined by an edge carrying a
    spring constant. The constants are symmetric, so i->j and j->i share a
    value.
    """
    mass = np.random.uniform(0, 5, [num_particles])
    velocity = get_random_uniform_norm2d_vectors(0, 0.1, num_particles)
    position = get_random_uniform_norm2d_vectors(0, 1, num_particles)
    momentum = velocity * np.expand_dims(mass, axis=-1)
    # Zero the total momentum so the centre of mass stays still.
    momentum = momentum - momentum.mean(0, keepdims=True)

    # Fully connected: flat index k corresponds to (receiver, sender) =
    # divmod(k, num_particles).
    indices = np.arange(num_particles)
    sender_grid, receiver_grid = np.meshgrid(indices, indices)
    senders = sender_grid.flatten()
    receivers = receiver_grid.flatten()

    # Symmetric spring constants: sample a full matrix, keep its lower triangle
    # (diagonal included) and mirror the strictly-lower part above it.
    sampled = np.random.uniform(1e-2, 1e-1, [num_particles, num_particles])
    spring_constants = (np.tril(sampled) + np.tril(sampled, -1).T).flatten()

    # Drop self-interactions.
    off_diagonal = senders != receivers
    senders = senders[off_diagonal]
    receivers = receivers[off_diagonal]
    spring_constants = spring_constants[off_diagonal]

    return jraph.GraphsTuple(
        n_node=np.asarray([num_particles]),
        n_edge=np.asarray([receivers.shape[0]]),
        nodes={
            "mass": mass,  # Scalar mass for each particle.
            "position": position,  # 2d position for each particle.
            "momentum": momentum,  # 2d momentum for each particle.
        },
        edges={
            # Scalar spring constant for each interaction.
            "spring_constant": spring_constants,
        },
        globals={},
        senders=senders,
        receivers=receivers)
Exemplo n.º 24
0
 def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     """Initializes model parameters from a one-node, one-edge dummy graph.

     Args:
       rng: PRNG key, split into parameter and dropout keys.

     Returns:
       ``(params, None)``: the parameters replicated via
       ``jax_utils.replicate``, and no model state. As a side effect, the
       parameter shapes are stored in ``self._param_shapes``.
     """
     # The first key of the split is not needed.
     _, params_rng, dropout_rng = jax.random.split(rng, 3)
     init_fn = jax.jit(functools.partial(self._model.init, train=False))
     # NOTE(review): the feature widths (3 per node, 7 per edge) are
     # hard-coded; presumably they match the dataset -- confirm.
     fake_batch = jraph.GraphsTuple(n_node=jnp.asarray([1]),
                                    n_edge=jnp.asarray([1]),
                                    nodes=jnp.ones((1, 3)),
                                    edges=jnp.ones((1, 7)),
                                    globals=jnp.zeros(
                                        (1, self._num_outputs)),
                                    senders=jnp.asarray([0]),
                                    receivers=jnp.asarray([0]))
     params = init_fn({
         'params': params_rng,
         'dropout': dropout_rng
     }, fake_batch)
     params = params['params']
     # `jax.tree_map` is deprecated and removed in recent JAX releases;
     # `jax.tree_util.tree_map` is the stable spelling.
     self._param_shapes = jax.tree_util.tree_map(
         lambda x: spec.ShapeTuple(x.shape), params)
     return jax_utils.replicate(params), None
Exemplo n.º 25
0
def to_jraph(neighbor: NeighborList, mask: Array = None) -> jraph.GraphsTuple:
    """Convert a sparse neighbor list to a `jraph.GraphsTuple`.

    As in jraph, padding here is accomplished by adding a fictitious graph with
    a single node (index ``N``). The first graph holds the ``N`` real nodes and
    the kept edges. The fictitious graph holds every remaining (masked-out or
    padding) edge.

    Args:
      neighbor: A neighbor list that we will convert to the jraph format. Must be
        sparse.
      mask: An optional mask on the edges. It is combined (logical AND) with the
        mask returned by `neighbor_list_mask`. Any truthy entry keeps the edge.

    Returns:
      A `jraph.GraphsTuple` that contains the topology of the neighbor list.

    Raises:
      ValueError: If the neighbor list is not in a sparse format.
    """
    if not is_sparse(neighbor.format):
        raise ValueError(
            'Cannot convert a dense neighbor list to jraph format. '
            'Please use either NeighborListFormat.Sparse or '
            'NeighborListFormat.OrderedSparse.')

    receivers, senders = neighbor.idx
    N = len(neighbor.reference_position)

    _mask = neighbor_list_mask(neighbor)

    if mask is not None:
        # Coerce to bool so that `~_mask` below is a logical NOT. For an
        # integer/float mask a bitwise NOT would make the padding edge count
        # negative.
        _mask = _mask & jnp.asarray(mask, dtype=bool)
        # jraph assigns edges to graphs contiguously via `n_edge`, so the kept
        # edges must come first. Kept edge k goes to slot (#kept before it).
        # Dropped edges all write into the extra scratch slot at
        # len(receivers), which is sliced off with [:-1]. Unfilled slots keep
        # the value N, i.e. they point at the fictitious padding node.
        cumsum = jnp.cumsum(_mask)
        index = jnp.where(_mask, cumsum - 1, len(receivers))
        ordered = N * jnp.ones((len(receivers) + 1, ), i32)
        receivers = ordered.at[index].set(receivers)[:-1]
        senders = ordered.at[index].set(senders)[:-1]
    # NOTE(review): without a user mask the edges are not reordered; this
    # assumes `neighbor.idx` already lists valid edges before padding ones.

    return jraph.GraphsTuple(
        nodes=None,
        edges=None,
        receivers=receivers,
        senders=senders,
        globals=None,
        # Graph 0: the N real nodes; graph 1: the single padding node.
        n_node=jnp.array([N, 1]),
        n_edge=jnp.array([jnp.sum(_mask), jnp.sum(~_mask)]),
    )
Exemplo n.º 26
0
    def test_graph2text_sampler_runs(self):
        """Sampling twice from one prompt gives well-formed, differing output."""
        graphs = jraph.GraphsTuple(
            nodes=np.ones((4, 3), dtype=np.float32),
            edges=np.ones((3, 1), dtype=np.float32),
            senders=np.array([0, 2, 3], dtype=np.int32),
            receivers=np.array([1, 3, 2], dtype=np.int32),
            n_node=np.array([2, 2], dtype=np.int32),
            n_edge=np.array([1, 2], dtype=np.int32),
            globals=None,
        )
        # Two identical prompts. The -1 entries are presumably placeholders;
        # the assertions below require every one of them to be replaced.
        prompt = np.array([[0, 1, 2, -1, -1, -1]] * 2, dtype=np.int32)
        vocab_size = prompt.max() + 1
        bos_token = 0
        memory_size = 2
        params = None

        def model_fn(graphs, max_graph_size, x):
            model = models.Graph2TextTransformer(
                vocab_size=vocab_size,
                emb_dim=8,
                num_layers=2,
                num_heads=4,
                cutoffs=[],
                gnn_embed_dim=8,
                gnn_num_layers=2)
            return model(graphs, max_graph_size, True, x,
                         is_training=False, cache_steps=memory_size)

        graph_sampler = sampler.Graph2TextTransformerSampler(model_fn)
        samples = []
        for _ in range(2):
            drawn = graph_sampler.sample(params, prompt, graphs)
            self.assertTrue((drawn[:, 0] == bos_token).all())
            self.assertTrue((drawn != -1).all())
            self.assertEqual(drawn.shape, prompt.shape)
            samples.append(drawn)
        # Two independent draws are expected to differ somewhere.
        first, second = samples
        self.assertTrue((first != second).any())
Exemplo n.º 27
0
  def test_graph_conditioned_transformer_runs(self):
    """Smoke-test a TransformerXL loss conditioned on embedded graph nodes."""
    graphs = jraph.GraphsTuple(
        nodes=np.ones((4, 3), dtype=np.float32),
        edges=np.ones((3, 1), dtype=np.float32),
        senders=np.array([0, 2, 3], dtype=np.int32),
        receivers=np.array([1, 3, 2], dtype=np.int32),
        n_node=np.array([2, 2], dtype=np.int32),
        n_edge=np.array([1, 2], dtype=np.int32),
        globals=None,
        )
    seqs = np.array([[1, 1, 0], [2, 2, 2]], dtype=np.int32)
    vocab_size = seqs.max() + 1
    embed_dim = 8

    # Shifted pair: each target token is the next token of the input.
    tokens_in, tokens_out = seqs[:, :-1], seqs[:, 1:]

    def forward(graphs, inputs, labels):
      embedded = models.GraphEmbeddingModel(embed_dim=embed_dim,
                                            num_layers=2)(graphs)
      extra, extra_mask = models.unpack_and_pad(embedded.nodes,
                                                embedded.n_node,
                                                embedded.n_node.max())
      # Label positions equal to 0 get a mask of 0 (presumably padding).
      input_mask = (labels != 0).astype(jnp.float32)
      transformer = models.TransformerXL(vocab_size=vocab_size,
                                         emb_dim=embed_dim,
                                         num_layers=2,
                                         num_heads=4,
                                         cutoffs=[])
      return transformer.loss(inputs, labels, mask=input_mask, extra=extra,
                              extra_mask=extra_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    rng_seq = hk.PRNGSequence(8)
    params, state = init_fn(next(rng_seq), graphs, tokens_in, tokens_out)
    (loss, metrics), _ = apply_fn(params, state, next(rng_seq), graphs,
                                  tokens_in, tokens_out)

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)
Exemplo n.º 28
0
def make_graph_from_static_structure(positions, types, box, edge_threshold,
                                     n_node_pad=4097, n_edge_pad=200000):
    """Returns graph representing the static structure of the glass.

    Each particle is represented by a node in the graph. The particle type is
    stored as a node feature.
    Two particles at a distance less than the threshold are connected by an edge.
    The relative distance vector is stored as an edge feature.

    Args:
      positions: particle positions with shape [n_particles, 3].
      types: particle types with shape [n_particles].
      box: dimensions of the cubic box that contains the particles with shape [3].
      edge_threshold: particles at distance less than threshold are connected by
        an edge.
      n_node_pad: total node count after padding, forwarded as `n_node` to
        `jraph.pad_with_graphs`. Defaults to the previously hard-coded 4097.
      n_edge_pad: total edge count after padding, forwarded as `n_edge` to
        `jraph.pad_with_graphs`. Defaults to the previously hard-coded 200000.

    Returns:
      A `jraph.GraphsTuple` padded by `jraph.pad_with_graphs` to `n_node_pad`
      nodes and `n_edge_pad` edges. Node features are the float32 particle
      types with shape [n_particles, 1]. Edge features are the float32
      displacement vectors positions[receiver] - positions[sender]. For any
      positive threshold every particle also gets a self-loop, because its
      self-distance is 0.
    """
    # Pairwise displacements: cross_positions[i, j] = positions[j] - positions[i],
    # shape [n, n, 3].
    cross_positions = positions[None, :, :] - positions[:, None, :]
    # Enforces periodic boundary conditions by wrapping each component that
    # exceeds half the box length back by one box length.
    box_ = box[None, None, :]
    cross_positions += (cross_positions < -box_ / 2.).astype(np.float32) * box_
    cross_positions -= (cross_positions > box_ / 2.).astype(np.float32) * box_
    # Calculates adjacency matrix in a sparse format (indices), based on the given
    # distances and threshold.
    distances = np.linalg.norm(cross_positions, axis=-1)
    indices = np.where(distances < edge_threshold)
    # Defines graph.
    nodes = types[:, None]
    senders = indices[0]
    receivers = indices[1]
    edges = cross_positions[indices]

    # Padding to fixed sizes gives every graph the same static shape.
    return jraph.pad_with_graphs(jraph.GraphsTuple(
        nodes=nodes.astype(np.float32),
        n_node=np.reshape(nodes.shape[0], [1]),
        edges=edges.astype(np.float32),
        n_edge=np.reshape(edges.shape[0], [1]),
        globals=np.zeros((1, 1), dtype=np.float32),
        receivers=receivers.astype(np.int32),
        senders=senders.astype(np.int32)),
                                 n_node=n_node_pad,
                                 n_edge=n_edge_pad)
Exemplo n.º 29
0
def get_higgs_problem(min_n_photons: int, max_n_photons: int) -> Problem:
  """Creates fully connected graph containing the detected photons.

  The photon count is drawn uniformly from [min_n_photons, max_n_photons].
  In roughly half of the samples, the first two photons are replaced by the
  photons returned by `get_random_higgs_photons()`.

  Args:
    min_n_photons: minimum number of photons in the detector. Must be at least
      2, because a Higgs sample overwrites the first two photons.
    max_n_photons: maximum number of photons in the detector (inclusive). Must
      be >= min_n_photons.

  Returns:
    A `Problem` holding the graph, padded to a static shape, and a one-hot
    label: [1, 0] if a Higgs was present, [0, 1] otherwise.

  Raises:
    ValueError: if min_n_photons < 2 or max_n_photons < min_n_photons.
  """
  # Explicit checks instead of `assert`, which is stripped under `python -O`.
  if min_n_photons < 2:
    raise ValueError("Number of photons must be at least 2.")
  if max_n_photons < min_n_photons:
    raise ValueError(
        f"max_n_photons ({max_n_photons}) must be >= "
        f"min_n_photons ({min_n_photons}).")
  n_photons = random.randint(min_n_photons, max_n_photons)
  photons = np.stack([get_random_background_photon() for _ in range(n_photons)])

  # Add a higgs
  if random.random() > 0.5:
    label = np.eye(2)[0]
    photons[:2] = np.stack(get_random_higgs_photons())
  else:
    label = np.eye(2)[1]

  # The graph is fully connected, self-loops included: every (i, j) pair
  # appears once, giving n_photons**2 edges.
  senders = np.repeat(np.arange(n_photons), n_photons)
  receivers = np.tile(np.arange(n_photons), n_photons)
  graph = jraph.GraphsTuple(
      n_node=np.asarray([n_photons]),
      n_edge=np.asarray([len(senders)]),
      nodes=photons,
      edges=None,
      globals=None,
      senders=senders,
      receivers=receivers)

  # In order to jit compile our code, we have to pad the nodes and edges of
  # the GraphsTuple to a static shape. The extra node is for the padding
  # graph. max_n_photons**2 edges is exactly the largest possible edge count.
  graph = jraph.pad_with_graphs(graph, max_n_photons + 1,
                                max_n_photons * max_n_photons)

  return Problem(graph=graph, labels=label)
def _to_jraph(example):
    """Converts an example graph to jraph.GraphsTuple.

    Args:
      example: a mapping with keys 'edge_feat', 'node_feat', 'edge_index',
        'labels' and 'num_nodes'. Every leaf must expose `_numpy()`
        (presumably TensorFlow eager tensors -- TODO confirm).
        `edge_index` is indexed as [:, 0] / [:, 1], so it must be a
        [num_edges, 2] array of (sender, receiver) pairs.

    Returns:
      A `jraph.GraphsTuple` in which every input edge appears twice, once
      per direction, with its features duplicated. The labels are stored in
      `globals` with an added leading axis of size 1.
    """
    # `jax.tree_map` is deprecated and removed in recent JAX releases;
    # `jax.tree_util.tree_map` works on every JAX version.
    example = jax.tree_util.tree_map(lambda x: x._numpy(), example)  # pylint: disable=protected-access
    edge_feat = example['edge_feat']
    node_feat = example['node_feat']
    edge_index = example['edge_index']
    labels = example['labels']
    num_nodes = example['num_nodes']

    senders = edge_index[:, 0]
    receivers = edge_index[:, 1]

    return jraph.GraphsTuple(
        n_node=num_nodes,
        # Each input edge is emitted in both directions.
        n_edge=np.array([len(edge_index) * 2]),
        nodes=node_feat,
        # Make the edges bidirectional: the second copy of the features lines
        # up with the reversed (receiver -> sender) edges appended below.
        edges=np.concatenate([edge_feat, edge_feat]),
        senders=np.concatenate([senders, receivers]),
        receivers=np.concatenate([receivers, senders]),
        # Keep the labels with the graph for batching. They will be removed
        # in the processed batch.
        globals=np.expand_dims(labels, axis=0))