コード例 #1
0
  def test_graph_runner(self):
    """Checks TFGraphRunner results and the growth of its graph-run cache."""
    runner = tf_utils.TFGraphRunner()

    def check_cache_size(graph_mode_size):
      # Ideally graph and eager modes would be covered by two separately
      # decorated tests rather than branching here, but no suitable
      # decorators were found. In eager mode the cache is expected to stay
      # empty; in graph mode it should hold `graph_mode_size` entries.
      expected_size = 0 if tf.executing_eagerly() else graph_mode_size
      self.assertEqual(len(runner._graph_run_cache), expected_size)

    relu_result = runner.run(tf.nn.relu, [1, 1, -1, -1, 1])
    self.assertAllEqual(relu_result, [1, 1, 0, 0, 1])

    relu_result = runner.run(tf.nn.relu, [-1, -1, -1, 1, 1])
    self.assertAllEqual(relu_result, [0, 0, 0, 1, 1])

    # Both calls share one signature, so the cached GraphRun should have been
    # re-used and the cache should contain a single entry.
    check_cache_size(1)

    # A different signature (different shape) should create a new GraphRun.
    relu_result = runner.run(tf.nn.relu, [-1, 1, 1])
    self.assertAllEqual(relu_result, [0, 1, 1])
    check_cache_size(2)
コード例 #2
0
ファイル: image_utils.py プロジェクト: thisiscam/datasets
def _get_runner():
    """Create and return a new `tf_utils.TFGraphRunner` instance."""
    runner = tf_utils.TFGraphRunner()
    return runner