def test_name_scope(self):
  """Checks that `utils_tf.identity` creates its outputs under the active name scope.

  Every field tensor of the identity graph is expected to have a name whose
  first "/"-separated component is "test", the scope opened around the call.
  """
  source_graph = self._graph
  with tf.name_scope("test"):
    copied_graph = utils_tf.identity(source_graph)
  checked_fields = ("nodes", "edges", "globals", "receivers", "senders",
                    "n_node", "n_edge")
  for field_name in checked_fields:
    field_tensor = getattr(copied_graph, field_name)
    outermost_scope = field_tensor.name.split("/")[0]
    self.assertEqual("test", outermost_scope)
def test_output(self, none_fields):
  """Tests that `utils_tf.identity` reproduces its input graph.

  Args:
    none_fields: names of graph fields that are replaced with `None` (via
      `self._graph.map` with a function returning `None`) before the
      identity is applied. Those fields must still be `None` in the output;
      every other field must match the input numerically (within 1e-4).
  """
  graph = self._graph.map(lambda _: None, none_fields)
  with tf.name_scope("test"):
    graph_id = utils_tf.identity(graph)
  expected_out = utils_tf.nest_to_numpy(graph)
  actual_out = utils_tf.nest_to_numpy(graph_id)
  for field in [
      "nodes", "edges", "globals", "receivers", "senders", "n_node", "n_edge"
  ]:
    if field in none_fields:
      # assertIsNone fails cleanly if an array leaked through, whereas
      # assertEqual(None, arr) may raise an ambiguous-truth ValueError.
      self.assertIsNone(getattr(actual_out, field), msg=field)
    else:
      self.assertNDArrayNear(
          getattr(expected_out, field),
          getattr(actual_out, field),
          err=1e-4,
          msg=field)
def test_output(self, none_fields):
  """Tests that `utils_tf.identity` reproduces its input graph (session mode).

  Args:
    none_fields: names of graph fields that are replaced with `None` (via
      `self._graph.map` with a function returning `None`) before the
      identity is applied. Those fields must still be `None` in the output;
      every other field must match the input numerically (within 1e-4).
  """
  graph = self._graph.map(lambda _: None, none_fields)
  with tf.compat.v1.name_scope("test"):
    graph_id = utils_tf.make_runnable_in_session(graph) if False else None
    graph_id = utils_tf.identity(graph)
  # Both graphs are made runnable before being fetched together with
  # sess.run; presumably this handles the None fields -- confirm in utils_tf.
  graph = utils_tf.make_runnable_in_session(graph)
  graph_id = utils_tf.make_runnable_in_session(graph_id)
  with self.test_session() as sess:
    expected_out, actual_out = sess.run([graph, graph_id])
  for field in [
      "nodes", "edges", "globals", "receivers", "senders", "n_node", "n_edge"
  ]:
    if field in none_fields:
      # assertIsNone fails cleanly if an array leaked through, whereas
      # assertEqual(None, arr) may raise an ambiguous-truth ValueError.
      self.assertIsNone(getattr(actual_out, field), msg=field)
    else:
      self.assertNDArrayNear(
          getattr(expected_out, field),
          getattr(actual_out, field),
          err=1e-4,
          msg=field)