def test_interpolate_graphs_tuple_endpoints(self, n_steps):
   """Check that the interpolation path begins at start and finishes at end.

   Every graph split out of the fixture is used as an interpolation target.
   Its baseline is built with `make_constant_like` from the tensorizer's
   null vectors. Steps 0 and `n_steps - 1` of the interpolated result must
   equal the baseline and the target respectively.

   Args:
     n_steps: Step count passed to `interpolate_graphs_tuple`; the final
       step is read back at index `n_steps - 1`.
   """
   batched_targets, _, tensorizer = self._setup_graphs()
   for target in graph_utils.split_graphs_tuple(batched_targets):
     null_vectors = tensorizer.get_null_vectors()
     baseline = graph_utils.make_constant_like(target, *null_vectors)
     path, _, _ = graph_utils.interpolate_graphs_tuple(
         baseline, target, n_steps)

     # Pull out the first and the last step of the interpolation path.
     first_step = np.array([0])
     path_head = graph_utils.get_graphs_tf(path, first_step)
     last_step = np.array([n_steps - 1])
     path_tail = graph_utils.get_graphs_tf(path, last_step)

     self.assertEqualGraphsTuple(baseline, path_head)
     self.assertEqualGraphsTuple(target, path_tail)
 def test_get_graphs_tf(self, indices):
   """Check that `get_graphs_tf` selects the graphs at the given indices.

   The reference result is built by fetching each index individually with
   `graph_nets.utils_tf.get_graph` and concatenating the pieces along
   axis 0.

   Args:
     indices: Graph positions to extract from the fixture.
   """
   fixture, _, _ = self._setup_graphs()
   selected = graph_utils.get_graphs_tf(fixture, np.array(indices))

   # Build the expected GraphsTuple one graph at a time.
   utils_tf = graph_nets.utils_tf
   singles = []
   for position in indices:
     singles.append(utils_tf.get_graph(fixture, position))
   reference = utils_tf.concat(singles, axis=0)

   self.assertEqualGraphsTuple(selected, reference)
 def test_get_graphs_tf_noedges(self):
   """Check that `get_graphs_tf` works when the edges field is None.

   The fixture's edges are replaced with None before selecting graphs
   1, 2, 5 and 6. The reference result is assembled from individual
   `graph_nets.utils_tf.get_graph` calls concatenated along axis 0.
   """
   wanted = [1, 2, 5, 6]
   fixture, _, _ = self._setup_graphs()
   edgeless = fixture.replace(edges=None)
   selected = graph_utils.get_graphs_tf(edgeless, np.array(wanted))

   # Build the expected GraphsTuple one graph at a time.
   utils_tf = graph_nets.utils_tf
   singles = []
   for position in wanted:
     singles.append(utils_tf.get_graph(edgeless, position))
   reference = utils_tf.concat(singles, axis=0)

   self.assertEqualGraphsTuple(selected, reference)
 def test_perturb_graphs_tuple_zero(self, n_samples):
   """Check that perturbing with sigma 0.0 leaves nodes and edges unchanged.

   The perturbed result is read back as `n_samples` consecutive groups,
   each holding as many graphs as the fixture. The nodes and edges of
   every group are compared against the unperturbed fixture.

   Args:
     n_samples: Sample count passed to `perturb_graphs_tuple`.
   """
   clean, _, _ = self._setup_graphs()
   zero_sigma = 0.0
   batch_size = graph_utils.get_num_graphs(clean)
   perturbed = graph_utils.perturb_graphs_tuple(clean, n_samples, zero_sigma)

   for sample in range(n_samples):
     # Index range [lo, hi) covers the graphs belonging to this sample.
     lo = sample * batch_size
     hi = (sample + 1) * batch_size
     group = graph_utils.get_graphs_tf(perturbed, np.arange(lo, hi))
     np.testing.assert_allclose(clean.nodes, group.nodes)
     np.testing.assert_allclose(clean.edges, group.edges)
 def test_interpolate_graphs_tuple_batch(self, n_steps):
   """Check that interpolation gives the same result batched or per graph.

   The whole fixture is interpolated at once from a baseline built with
   `make_constant_like`. Then, for graph `i` of the batch, the entries at
   indices `i, i + n_graphs, i + 2 * n_graphs, ...` of the batched result
   are compared with interpolating that single start/end pair on its own.

   Args:
     n_steps: Step count passed to `interpolate_graphs_tuple`.
   """
   targets, _, tensorizer = self._setup_graphs()
   batch_size = graph_utils.get_num_graphs(targets)
   null_vectors = tensorizer.get_null_vectors()
   baselines = graph_utils.make_constant_like(targets, *null_vectors)
   batched_path, _, _ = graph_utils.interpolate_graphs_tuple(
       baselines, targets, n_steps)

   single_pairs = zip(graph_utils.split_graphs_tuple(baselines),
                      graph_utils.split_graphs_tuple(targets))
   for graph_idx, pair in enumerate(single_pairs):
     single_start, single_end = pair
     # One index per interpolation step, strided by the batch size.
     step_offsets = np.arange(0, n_steps * batch_size, batch_size)
     from_batch = graph_utils.get_graphs_tf(
         batched_path, step_offsets + graph_idx)
     from_single, _, _ = graph_utils.interpolate_graphs_tuple(
         single_start, single_end, n_steps)
     self.assertEqualGraphsTuple(from_single, from_batch)