def test_compute_stacked_offsets(self):
  """Checks `_compute_stacked_offsets` against hand-computed offsets.

  The expected result is the exclusive cumulative sum of `sizes`, with each
  entry repeated the number of times given by the matching `repeats` entry.
  The check is done twice: once with `repeats` as a Python list and once
  with it as a numpy array.
  """
  sizes = np.array([5, 4, 3, 1, 2, 0, 3, 0, 4, 7])
  repeat_counts = [2, 2, 0, 2, 1, 3, 2, 0, 3, 2]
  want = [0, 0, 5, 5, 12, 12, 13, 15, 15, 15, 15, 15, 18, 18, 18, 22, 22]
  # Compute both variants first, then compare each against the same target.
  results = [
      utils_np._compute_stacked_offsets(sizes, reps)
      for reps in (repeat_counts, np.array(repeat_counts))
  ]
  for got in results:
    self.assertAllEqual(want, got.tolist())
def populate_test_data(self, max_size):
  """Populates the class fields with data used for the tests.

  This creates a batch of graphs covering every combination of a number of
  nodes from 1 to `max_size` with a number of edges from 0 to `max_size`,
  preceded by an empty graph with no nodes and no edges. The total number of
  graphs is therefore 1 + max_size * (max_size + 1). Combinations with edges
  but no nodes are skipped, since edges must reference existing nodes.

  The nodes states, edges states and global states of the graphs are created
  to have different types and shapes.

  Those graphs are stored both as dictionaries (in `self.graphs_dicts_in`,
  without `n_node` and `n_edge` information, and in `self.graphs_dicts_out`
  with these two fields filled), and a corresponding numpy
  `graphs.GraphsTuple` is stored in `self.reference_graph`.

  Args:
    max_size: The maximum number of nodes and edges (inclusive).
  """
  # Keep every (n_node, n_edge) pair with at least one node, plus (0, 0).
  sizes = [
      (num_nodes, num_edges)
      for num_nodes, num_edges in itertools.product(
          range(max_size + 1), repeat=2)
      if num_nodes > 0 or num_edges == 0
  ]
  n_node, n_edge = zip(*sizes)

  graphs_dicts = []
  nodes = []
  edges = []
  receivers = []
  senders = []
  globals_ = []

  def _make_default_state(shape, dtype):
    return np.arange(np.prod(shape)).reshape(shape).astype(dtype)

  for i, (n_node_, n_edge_) in enumerate(zip(n_node, n_edge)):
    # Offsets by `i` make every graph's features distinguishable.
    n = _make_default_state([n_node_, 7, 11], "f4") + i * 100.
    e = _make_default_state([n_edge_, 13, 14], np.float64) + i * 100. + 1000.
    # Endpoints are wrapped modulo the node count so they index nodes of this
    # graph; n_edge_ is always 0 when n_node_ is 0 (see the filter above).
    r = _make_default_state([n_edge_], np.int32) % n_node_
    s = (_make_default_state([n_edge_], np.int32) + 1) % n_node_
    g = _make_default_state([5, 3], "f4") - i * 100. - 1000.

    nodes.append(n)
    edges.append(e)
    receivers.append(r)
    senders.append(s)
    globals_.append(g)
    graphs_dict = dict(nodes=n, edges=e, receivers=r, senders=s, globals=g)
    graphs_dicts.append(graphs_dict)

  # Graphs dicts without n_node / n_edge (to be used as inputs).
  self.graphs_dicts_in = graphs_dicts
  # Graphs dicts with n_node / n_edge (to be checked against outputs).
  # These are shallow copies: the arrays are shared with graphs_dicts_in.
  self.graphs_dicts_out = [
      dict(dict_,
           n_node=dict_["nodes"].shape[0],
           n_edge=dict_["edges"].shape[0])
      for dict_ in self.graphs_dicts_in
  ]

  # Per-edge offsets: the number of nodes in all preceding graphs, so that
  # receivers/senders index into the concatenated `nodes` array.
  # pylint: disable=protected-access
  offset = utils_np._compute_stacked_offsets(n_node, n_edge)
  # pylint: enable=protected-access
  self.reference_graph = graphs.GraphsTuple(
      nodes=np.concatenate(nodes, axis=0),
      edges=np.concatenate(edges, axis=0),
      receivers=np.concatenate(receivers, axis=0) + offset,
      senders=np.concatenate(senders, axis=0) + offset,
      globals=np.stack(globals_),
      n_node=np.array(n_node),
      n_edge=np.array(n_edge))