def test_graph_runner(self):
  """Checks that TFGraphRunner reuses one GraphRun per input signature.

  In graph mode, repeated calls with the same input shape must reuse the
  cached GraphRun. A call with a new shape must add a new cache entry. In
  eager mode, the cache is expected to stay empty.
  """
  graph_runner = tf_utils.TFGraphRunner()

  def expected_cache_size(graph_mode_size):
    # Ideally this test would be split into separate graph-only and
    # eager-only variants (e.g. via dedicated test decorators) instead of
    # branching on the execution mode here, but no suitable decorators
    # have been found.
    return 0 if tf.executing_eagerly() else graph_mode_size

  output = graph_runner.run(tf.nn.relu, [1, 1, -1, -1, 1])
  self.assertAllEqual(output, [1, 1, 0, 0, 1])
  output = graph_runner.run(tf.nn.relu, [-1, -1, -1, 1, 1])
  self.assertAllEqual(output, [0, 0, 0, 1, 1])
  # Same signature both times, so the cache should have been reused and
  # should contain only one GraphRun.
  self.assertEqual(len(graph_runner._graph_run_cache), expected_cache_size(1))

  # Different signature (different shape), so a new GraphRun is created.
  output = graph_runner.run(tf.nn.relu, [-1, 1, 1])
  self.assertAllEqual(output, [0, 1, 1])
  self.assertEqual(len(graph_runner._graph_run_cache), expected_cache_size(2))
def _get_runner():
  """Creates and returns a new `tf_utils.TFGraphRunner` instance."""
  runner = tf_utils.TFGraphRunner()
  return runner