def test_make_constant_like(self):
   """Verify the constant graph mirrors the input shapes and holds constants.

   Every node entry must equal the tensorizer's null node vector and every
   edge entry must equal its null edge vector.
   """
   graphs, _, tensorizer = self._setup_graphs()
   null_node, null_edge = tensorizer.get_null_vectors()
   constant = graph_utils.make_constant_like(graphs, null_node, null_edge)
   # Shapes must be carried over unchanged from the source graphs.
   self.assertEqual(graphs.nodes.shape, constant.nodes.shape)
   self.assertEqual(graphs.edges.shape, constant.edges.shape)
   # Each entry along the first axis must match the matching null vector.
   nodes_match = all(np.allclose(null_node, row) for row in constant.nodes)
   self.assertTrue(nodes_match)
   edges_match = all(np.allclose(null_edge, row) for row in constant.edges)
   self.assertTrue(edges_match)
 def test_interpolate_graphs_tuple_endpoints(self, n_steps):
   """Verify the first and last interpolation steps equal start and end.

   Each graph of the fixture batch is interpolated on its own, from a
   constant (null-vector) baseline to the graph itself.
   """
   batch, _, tensorizer = self._setup_graphs()
   first_step = 0
   last_step = n_steps - 1
   for target in graph_utils.split_graphs_tuple(batch):
     null_node, null_edge = tensorizer.get_null_vectors()
     baseline = graph_utils.make_constant_like(target, null_node, null_edge)
     path, _, _ = graph_utils.interpolate_graphs_tuple(
         baseline, target, n_steps)
     head = graph_utils.get_graphs_tf(path, np.array([first_step]))
     tail = graph_utils.get_graphs_tf(path, np.array([last_step]))
     self.assertEqualGraphsTuple(baseline, head)
     self.assertEqualGraphsTuple(target, tail)
 def test_interp_array(self, n_steps):
   """Check that our interpolation works on arrays.

   Verifies the endpoints, the middle step and the output shape. Only an odd
   `n_steps` has an exact midpoint step. For an even `n_steps` the floored
   middle index is compared against its linear fraction along the path
   instead of the plain mean, so the test is valid for either parity.
   """
   graphs, _, tensorizer = self._setup_graphs()
   ref = graph_utils.make_constant_like(graphs, *tensorizer.get_null_vectors())
   start = ref.nodes
   end = graphs.nodes
   interp = graph_utils._interp_array(start, end, n_steps)  # pylint:disable=protected-access
   mid = (n_steps - 1) // 2
   np.testing.assert_allclose(interp[0], start)
   if n_steps % 2:
     # Odd step count: index `mid` is the exact midpoint of the path.
     np.testing.assert_allclose(interp[mid], np.mean([start, end], axis=0))
   else:
     # Even step count: index `mid` lies mid / (n_steps - 1) of the way from
     # start to end. This assumes evenly spaced steps with both endpoints
     # included, the same property test_interpolate_graphs_tuple_differences
     # checks for graphs tuples.
     frac = mid / (n_steps - 1)
     expected_mid = (1.0 - frac) * np.asarray(start) + frac * np.asarray(end)
     np.testing.assert_allclose(
         interp[mid], expected_mid, rtol=1e-6, atol=1e-7)
   np.testing.assert_allclose(interp[-1], end)
   self.assertEqual(interp.shape, (n_steps, start.shape[0], start.shape[1]))
 def test_interpolate_graphs_tuple_batch(self, n_steps):
   """Verify batched interpolation matches interpolating each graph alone.

   Interpolated graphs must be the same irrespective of whether they were
   produced as part of a batch or individually.
   """
   batched_end, _, tensorizer = self._setup_graphs()
   num_graphs = graph_utils.get_num_graphs(batched_end)
   batched_start = graph_utils.make_constant_like(
       batched_end, *tensorizer.get_null_vectors())
   batched_interp, _, _ = graph_utils.interpolate_graphs_tuple(
       batched_start, batched_end, n_steps)
   per_graph = zip(graph_utils.split_graphs_tuple(batched_start),
                   graph_utils.split_graphs_tuple(batched_end))
   for graph_idx, (single_start, single_end) in enumerate(per_graph):
     # In the batched output, graph `graph_idx` recurs every `num_graphs`
     # entries, starting at offset `graph_idx`.
     positions = np.arange(0, n_steps * num_graphs, num_graphs) + graph_idx
     actual = graph_utils.get_graphs_tf(batched_interp, positions)
     expected, _, _ = graph_utils.interpolate_graphs_tuple(
         single_start, single_end, n_steps)
     self.assertEqualGraphsTuple(expected, actual)
 def test_interpolate_graphs_tuple_differences(self):
   """Check that our interpolation has constant differences between steps.

   For every graph in the fixture batch, interpolates from a constant
   (null-vector) baseline to the graph over `n_steps` steps. It then checks
   that consecutive steps differ by (end - start) / (n_steps - 1) in both
   nodes and edges.
   """
   n_steps = 8
   several_ends, _, tensorizer = self._setup_graphs()
   for end in graph_utils.split_graphs_tuple(several_ends):
     start = graph_utils.make_constant_like(end,
                                            *tensorizer.get_null_vectors())
     interp, _, _ = graph_utils.interpolate_graphs_tuple(start, end, n_steps)
     steps = list(graph_utils.split_graphs_tuple(interp))
     # Guard against a vacuous pass: with fewer than two steps the pairwise
     # loop below would not compare anything.
     self.assertEqual(len(steps), n_steps)
     # n_steps points with both endpoints included give n_steps - 1 equal
     # increments.
     expected_nodes_diff = tf.divide(end.nodes - start.nodes, n_steps - 1)
     expected_edges_diff = tf.divide(end.edges - start.edges, n_steps - 1)
     for x_cur, x_next in zip(steps[:-1], steps[1:]):
       actual_nodes_diff = x_next.nodes - x_cur.nodes
       actual_edges_diff = x_next.edges - x_cur.edges
       np.testing.assert_allclose(
           expected_nodes_diff, actual_nodes_diff, atol=1e-7)
       np.testing.assert_allclose(
           expected_edges_diff, actual_edges_diff, atol=1e-7)