def test_compute_stacked_offsets(self):
  """Checks `_compute_stacked_offsets` for both list and ndarray `repeats`.

  Each size's running start offset (0, 5, 9, 12, ...) is expected to appear
  as many times as the matching entry of `repeats` says; a repeat of 0 drops
  that offset entirely.
  """
  sizes = np.array([5, 4, 3, 1, 2, 0, 3, 0, 4, 7])
  repeats = [2, 2, 0, 2, 1, 3, 2, 0, 3, 2]
  # Compute the offsets for both accepted forms of `repeats` before checking.
  computed = [
      utils_np._compute_stacked_offsets(sizes, repeats_variant)
      for repeats_variant in (repeats, np.array(repeats))
  ]
  expected_offsets = [
      0, 0, 5, 5, 12, 12, 13, 15, 15, 15, 15, 15, 18, 18, 18, 22, 22
  ]
  for offsets in computed:
    self.assertAllEqual(expected_offsets, offsets.tolist())
Example No. 2
0
  def populate_test_data(self, max_size):
    """Populates the class fields with data used for the tests.

    This creates a batch of graphs containing one graph for every pair
    `(n_node, n_edge)` with `1 <= n_node <= max_size` and
    `0 <= n_edge <= max_size`, plus an empty graph with no nodes and no edges.
    Pairs with edges but no nodes are excluded, because receiver/sender
    indices are taken modulo the number of nodes. The total number of graphs
    is therefore `1 + max_size * (max_size + 1)`, enumerated with the node
    count varying slowest (the empty graph comes first).

    The nodes states, edges states and global states of the graphs are
    created to have different dtypes and shapes.

    Those graphs are stored both as dictionaries (in `self.graphs_dicts_in`,
    without `n_node` and `n_edge` information, and in `self.graphs_dicts_out`
    with these two fields filled), and a corresponding numpy
    `graphs.GraphsTuple` is stored in `self.reference_graph`.

    Args:
      max_size: The maximum number of nodes and edges (inclusive).
    """
    sizes = [(num_nodes, num_edges)
             for num_nodes, num_edges in itertools.product(
                 range(max_size + 1), range(max_size + 1))
             if num_nodes > 0 or num_edges == 0]
    n_node, n_edge = zip(*sizes)

    def _make_default_state(shape, dtype):
      """Returns `0, 1, 2, ...` reshaped to `shape` and cast to `dtype`."""
      return np.arange(np.prod(shape)).reshape(shape).astype(dtype)

    graphs_dicts = []
    for i, (n_node_, n_edge_) in enumerate(zip(n_node, n_edge)):
      # Per-graph shifts by `i` so that different graphs carry different
      # values; each field gets its own shape and dtype.
      nodes = _make_default_state([n_node_, 7, 11], "f4") + i * 100.
      edges = (_make_default_state([n_edge_, 13, 14], np.float64)
               + i * 100. + 1000.)
      # Edge k connects node (k + 1) % n_node_ -> node k % n_node_, so every
      # index stays within this graph's own nodes.
      edge_ids = _make_default_state([n_edge_], np.int32)
      receivers = edge_ids % n_node_
      senders = (edge_ids + 1) % n_node_
      globals_ = _make_default_state([5, 3], "f4") - i * 100. - 1000.
      graphs_dicts.append(dict(nodes=nodes, edges=edges, receivers=receivers,
                               senders=senders, globals=globals_))

    # Graphs dicts without n_node / n_edge (to be used as inputs).
    self.graphs_dicts_in = graphs_dicts
    # Graphs dicts with n_node / n_edge (to be checked against outputs).
    self.graphs_dicts_out = [
        dict(dict_,
             n_node=dict_["nodes"].shape[0],
             n_edge=dict_["edges"].shape[0])
        for dict_ in graphs_dicts
    ]

    # Receivers/senders index into each graph's own nodes. They are shifted
    # by the number of nodes in all preceding graphs, so that they index into
    # the concatenated nodes array.
    # pylint: disable=protected-access
    offset = utils_np._compute_stacked_offsets(n_node, n_edge)
    # pylint: enable=protected-access

    def _concat(field):
      """Concatenates `field` of every graph dict along the leading axis."""
      return np.concatenate([dict_[field] for dict_ in graphs_dicts], axis=0)

    self.reference_graph = graphs.GraphsTuple(
        nodes=_concat("nodes"),
        edges=_concat("edges"),
        receivers=_concat("receivers") + offset,
        senders=_concat("senders") + offset,
        globals=np.stack([dict_["globals"] for dict_ in graphs_dicts]),
        n_node=np.array(n_node),
        n_edge=np.array(n_edge))