def test_apply_random_rotation(self):
  """Checks that apply_random_rotation changes only the edge features.

  Nodes, senders, receivers and globals must come back unchanged, while the
  edges must match the originals up to a rotation (as judged by the
  `_is_equal_up_to_rotation` helper).
  """
  original = self._get_graphs_tuple()
  rotated = self.evaluate(graph_model.apply_random_rotation(original))

  # Fields the rotation must not touch, paired with their expected values,
  # checked in the same order as before.
  untouched = (
      (rotated.nodes, self._nodes),
      (rotated.senders, self._senders),
      (rotated.receivers, self._receivers),
      (rotated.globals, np.array([[0.0]])),
  )
  for actual, expected in untouched:
    np.testing.assert_almost_equal(actual, expected)

  # Edges are the only field expected to change, and only by a rotation.
  self.assertTrue(self._is_equal_up_to_rotation(rotated.edges, self._edges))
def _apply_random_rotation(graph, targets, types):
  """Randomly rotates `graph`; `targets` and `types` pass through untouched.

  Args:
    graph: Graph handed to `graph_model.apply_random_rotation`.
    targets: Returned as-is.
    types: Returned as-is.

  Returns:
    A 3-tuple `(rotated_graph, targets, types)`.
  """
  rotated_graph = graph_model.apply_random_rotation(graph)
  return rotated_graph, targets, types