def test_make_constant_like(self):
    """Verify the constant graph keeps shapes and holds the constant values.

    Builds a constant copy of the fixture graphs from the tensorizer's null
    node/edge vectors, then checks that the node and edge arrays keep their
    shapes and that every entry yielded by iterating them is close to the
    corresponding null vector.
    """
    graphs, _, tensorizer = self._setup_graphs()
    null_node, null_edge = tensorizer.get_null_vectors()
    constant = graph_utils.make_constant_like(graphs, null_node, null_edge)

    # The constant graph must mirror the shapes of the original.
    self.assertEqual(graphs.nodes.shape, constant.nodes.shape)
    self.assertEqual(graphs.edges.shape, constant.edges.shape)

    # Every node entry must match the null node vector, likewise for edges.
    nodes_match = [np.allclose(null_node, entry) for entry in constant.nodes]
    self.assertTrue(all(nodes_match))
    edges_match = [np.allclose(null_edge, entry) for entry in constant.edges]
    self.assertTrue(all(edges_match))
def test_interpolate_graphs_tuple_endpoints(self, n_steps):
    """Verify the first and last interpolated graphs equal the endpoints.

    Each graph of the fixture is used as the interpolation end point, with a
    constant graph built from the tensorizer's null vectors as the start.

    Args:
        n_steps: Number of interpolation steps passed to
            `graph_utils.interpolate_graphs_tuple`.
    """
    batched_ends, _, tensorizer = self._setup_graphs()
    first_index = np.array([0])
    last_index = np.array([n_steps - 1])

    for target in graph_utils.split_graphs_tuple(batched_ends):
        null_vectors = tensorizer.get_null_vectors()
        baseline = graph_utils.make_constant_like(target, *null_vectors)
        interpolated, _, _ = graph_utils.interpolate_graphs_tuple(
            baseline, target, n_steps)

        # Step 0 must reproduce the start, step n_steps - 1 the end.
        first_graph = graph_utils.get_graphs_tf(interpolated, first_index)
        last_graph = graph_utils.get_graphs_tf(interpolated, last_index)
        self.assertEqualGraphsTuple(baseline, first_graph)
        self.assertEqualGraphsTuple(target, last_graph)
def test_interp_array(self, n_steps):
    """Verify array interpolation at its start, middle and end, plus its shape.

    Interpolates between the node array of a constant (null-vector) graph and
    the node array of the fixture graphs.

    Args:
        n_steps: Number of interpolation steps passed to
            `graph_utils._interp_array`.
    """
    graphs, _, tensorizer = self._setup_graphs()
    null_vectors = tensorizer.get_null_vectors()
    baseline = graph_utils.make_constant_like(graphs, *null_vectors)
    first, last = baseline.nodes, graphs.nodes

    result = graph_utils._interp_array(first, last, n_steps)  # pylint:disable=protected-access

    # NOTE(review): comparing this index against the element-wise mean
    # presumably relies on n_steps being odd -- confirm against the test's
    # parameterization.
    middle_index = int((n_steps - 1) / 2)
    midpoint = np.mean([first, last], axis=0)

    np.testing.assert_allclose(result[0], first)
    np.testing.assert_allclose(result[middle_index], midpoint)
    np.testing.assert_allclose(result[-1], last)

    # One leading axis of length n_steps in front of the 2-D node array.
    expected_shape = (n_steps, first.shape[0], first.shape[1])
    self.assertEqual(result.shape, expected_shape)
def test_interpolate_graphs_tuple_batch(self, n_steps):
    """Check that interpolated graphs are the same, irrespective of batching.

    Interpolates the whole fixture batch at once, then compares the graphs
    selected for each batch element against an interpolation run on that
    element alone.

    Args:
        n_steps: Number of interpolation steps passed to
            `graph_utils.interpolate_graphs_tuple`.
    """
    batch_end, _, tensorizer = self._setup_graphs()
    n_graphs = graph_utils.get_num_graphs(batch_end)
    batch_start = graph_utils.make_constant_like(
        batch_end, *tensorizer.get_null_vectors())
    batch_interp, _, _ = graph_utils.interpolate_graphs_tuple(
        batch_start, batch_end, n_steps)

    # Indices 0, n_graphs, 2 * n_graphs, ...; adding the graph position picks
    # one graph per step (presumably the batched output is laid out step by
    # step -- confirm against interpolate_graphs_tuple).
    step_offsets = np.arange(0, n_steps * n_graphs, n_graphs)

    single_pairs = zip(
        graph_utils.split_graphs_tuple(batch_start),
        graph_utils.split_graphs_tuple(batch_end),
    )
    for position, (single_start, single_end) in enumerate(single_pairs):
        from_batch = graph_utils.get_graphs_tf(
            batch_interp, step_offsets + position)
        from_single, _, _ = graph_utils.interpolate_graphs_tuple(
            single_start, single_end, n_steps)
        self.assertEqualGraphsTuple(from_single, from_batch)
def test_interpolate_graphs_tuple_differences(self):
    """Verify consecutive interpolation steps differ by a constant amount.

    For each fixture graph, the node and edge differences between neighbouring
    steps must equal (end - start) / (n_steps - 1) within an absolute
    tolerance of 1e-7.
    """
    n_steps = 8
    n_intervals = n_steps - 1
    batched_ends, _, tensorizer = self._setup_graphs()

    for target in graph_utils.split_graphs_tuple(batched_ends):
        null_vectors = tensorizer.get_null_vectors()
        baseline = graph_utils.make_constant_like(target, *null_vectors)
        interpolated, _, _ = graph_utils.interpolate_graphs_tuple(
            baseline, target, n_steps)
        steps = list(graph_utils.split_graphs_tuple(interpolated))

        node_increment = tf.divide(target.nodes - baseline.nodes, n_intervals)
        edge_increment = tf.divide(target.edges - baseline.edges, n_intervals)

        # Pair each step with its successor.
        for current, following in zip(steps, steps[1:]):
            node_delta = following.nodes - current.nodes
            edge_delta = following.edges - current.edges
            # NOTE(review): the expected value is passed in assert_allclose's
            # `actual` position; kept as-is to preserve tolerance behaviour.
            np.testing.assert_allclose(
                node_increment, node_delta, atol=1e-7)
            np.testing.assert_allclose(
                edge_increment, edge_delta, atol=1e-7)