# Example 1
  def test_runtime_error_rewriting(self):
    """Runtime errors in converted code surface with a rewritten traceback.

    Running the converted `test_fn` is expected to raise `ag.TfRuntimeError`.
    The custom traceback attached to that error must reference this test file
    rather than generated `/tmp/` modules or `control_flow.py`, must not
    expose `ag__.` helper names, and must contain exactly one frame each for
    `test_fn` and `g`, pointing at their original source lines.
    """

    def g(x, s):
      # NOTE(review): the integer floor-division by zero is presumably what
      # makes execution fail; the test only relies on something raising here.
      while tf.reduce_sum(x) > s:
        x //= 0
      return x

    def test_fn(x):
      return g(x, 10)

    compiled_fn = ag.to_graph(test_fn)

    with self.assertRaises(ag.TfRuntimeError) as error:
      with self.cached_session() as sess:
        x = compiled_fn(tf.constant([4, 8]))
        # compiled_fn(...) only yields a graph tensor; the failing op executes
        # in sess.run. That call must happen inside improved_errors so the
        # resulting error is rewritten into a TfRuntimeError.
        with ag.improved_errors(compiled_fn):
          sess.run(x)
    expected = error.exception
    custom_traceback = expected.custom_traceback
    found_correct_filename = False
    num_test_fn_frames = 0
    num_g_frames = 0
    for frame in custom_traceback:
      filename, _, fn_name, source_code = frame
      # Frames must map back to user code, not generated/internal modules.
      self.assertNotIn('/tmp/', filename)
      self.assertNotIn('control_flow.py', filename)
      self.assertNotIn('ag__.', fn_name)
      found_correct_filename |= __file__ in filename
      num_test_fn_frames += int('test_fn' == fn_name and
                                'return g(x, 10)' in source_code)
      num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
    # At least one frame must point back at this test file.
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_frames, 1)
    self.assertEqual(num_g_frames, 1)
# Example 2
  def test_graph_construction_error_rewriting_call_tree(self):
    """Graph-construction errors report original names along a call chain.

    The chain `caller` -> `inner_caller` -> `test_fn` is converted with
    `ag.to_graph`. Invoking the converted function is expected to raise
    `ag.GraphConstructionError` from the `tf.random_normal` call in `test_fn`.
    The attached custom traceback must reference this test file (never a
    generated `/tmp/` module) and must not use the `tf__`-prefixed generated
    names. It must also contain exactly one frame for each of the three
    functions.
    """

    def test_fn():
      # NOTE(review): a float mean combined with an int32 dtype is presumably
      # what makes op construction fail; the test only relies on it raising.
      return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

    def inner_caller():
      return test_fn()

    def caller():
      return inner_caller()

    with self.assertRaises(ag.GraphConstructionError) as error:
      graph = ag.to_graph(caller)
      # ag.to_graph returns the converted callable. tf.random_normal is only
      # reached (and only fails) when that callable is invoked.
      graph()
    expected = error.exception
    custom_traceback = expected.custom_traceback
    found_correct_filename = False
    num_test_fn_names = 0
    num_inner_caller_names = 0
    num_caller_names = 0
    for frame in custom_traceback:
      filename, _, fn_name, _ = frame
      self.assertNotIn('/tmp/', filename)
      found_correct_filename |= __file__ in filename
      # Generated (tf__-prefixed) names must be mapped back to the originals.
      self.assertNotEqual('tf__test_fn', fn_name)
      num_test_fn_names += int('test_fn' == fn_name)
      self.assertNotEqual('tf__inner_caller', fn_name)
      num_inner_caller_names += int('inner_caller' == fn_name)
      self.assertNotEqual('tf__caller', fn_name)
      num_caller_names += int('caller' == fn_name)
    # At least one frame must point back at this test file.
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_names, 1)
    self.assertEqual(num_inner_caller_names, 1)
    self.assertEqual(num_caller_names, 1)
# Example 4
  def test_runtime_error_rewriting_nested(self):
    """Runtime errors raised from a nested function inside converted code.

    Running the converted `test_fn` is expected to raise `ag.TfRuntimeError`.
    Because nested functions are not yet rewritten (see the TODO below), the
    traceback is expected to show the generated name `tf__g` exactly once and
    never the original name `g`.
    """

    def test_fn(x):

      def g(y):
        # NOTE(review): the integer floor-division by zero is presumably what
        # makes execution fail; the test only relies on something raising.
        return y**2 // 0

      s = 0
      for xi in x:
        s += g(xi)
      return s

    compiled_fn = ag.to_graph(test_fn)

    # TODO(b/111408261): Nested functions currently do not rewrite correctly,
    # when they do we should change this test to check for the same traceback
    # properties as the other tests.  This should throw a runtime error with a
    # frame with "g" as the function name but because we don't yet add
    # try/except blocks to inner functions the name is "tf__g".
    with self.assertRaises(ag.TfRuntimeError) as error:
      with self.cached_session() as sess:
        x = compiled_fn(tf.constant([4, 8]))
        # compiled_fn(...) only yields a graph tensor; the failing op executes
        # in sess.run, which must run inside improved_errors for the error to
        # be rewritten into a TfRuntimeError.
        with ag.improved_errors(compiled_fn):
          sess.run(x)
    expected = error.exception
    custom_traceback = expected.custom_traceback
    num_tf_g_frames = 0
    for frame in custom_traceback:
      _, _, fn_name, _ = frame
      self.assertNotEqual('g', fn_name)
      num_tf_g_frames += int('tf__g' == fn_name)
    self.assertEqual(num_tf_g_frames, 1)
# Example 7
    def test_graph_construction_error_rewriting_class(self):
        """Errors raised through a converted class keep their raw type.

        `ag.to_graph` is applied to a whole class. Instantiating the converted
        class and calling `caller` reaches the failing `tf.random_normal` call
        in `test_fn` via `inner_caller`. The error is expected to surface as a
        plain `TypeError` rather than a rewritten error.
        """
        class TestClass(object):
            def test_fn(self):
                # NOTE(review): a float mean with an int32 dtype is presumably
                # what triggers the TypeError; confirm against tf.random_normal.
                return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

            def inner_caller(self):
                return self.test_fn()

            def caller(self):
                return self.inner_caller()

        # Note we expect a TypeError here because the traceback will not be
        # rewritten for classes.
        with self.assertRaises(TypeError):
            graph = ag.to_graph(TestClass)
            # Converting alone builds no ops. Instantiate the converted class
            # and invoke the call chain so the failing op is actually built.
            graph().caller()
  def test_basic(self):
    """Converting `list_used_as_tuple` yields a result evaluating to [1, 2, 3]."""
    graph_fn = ag.to_graph(list_used_as_tuple)
    output = graph_fn()

    # Evaluate inside the cached test session; the session object itself is
    # not needed because self.evaluate is used.
    with self.cached_session():
      self.assertAllEqual(self.evaluate(output), [1, 2, 3])
