Example #1
  def test_to_graph_with_globals(self):

    def test_fn(x):
      global testing_global_numeric
      testing_global_numeric = x + testing_global_numeric
      return testing_global_numeric

    # TODO(b/122368197)
    with self.assertRaisesRegex(
        errors.AutoGraphError, 'global keyword is not yet supported'):
      api.to_graph(test_fn)
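Note: the `global` statement is the specific construct rejected above. A minimal workaround sketch (illustrative names, not from the test suite): keep the mutable state in a `tf.Variable` captured by closure, which AutoGraph can convert.

import tensorflow as tf

# Hypothetical rewrite of test_fn: the running total lives in a tf.Variable,
# so no `global` statement is needed.
testing_state = tf.Variable(0)

def accumulate(x):
  testing_state.assign_add(x)        # update in place instead of rebinding a global
  return testing_state.read_value()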
Example #2
  def test_to_graph_basic(self):

    def test_fn(x, s):
      while tf.reduce_sum(x) > s:
        x //= 2
      return x

    compiled_fn = api.to_graph(test_fn)

    with self.cached_session() as sess:
      x = compiled_fn(constant_op.constant([4, 8]), 4)
      self.assertListEqual([1, 2], sess.run(x).tolist())
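Note: the same conversion is available through the public TF 2.x API; a minimal eager-mode sketch (not part of the original test):

import tensorflow as tf

def halve_until(x, s):
  while tf.reduce_sum(x) > s:
    x //= 2
  return x

converted = tf.autograph.to_graph(halve_until)  # public counterpart of api.to_graph
print(converted(tf.constant([4, 8]), 4))        # expected result: [1 2], per the test above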
Example #3
  def test_to_graph_caching_different_options(self):

    def called_fn():
      pass

    def test_fn():
      return called_fn()

    converted_recursive = api.to_graph(test_fn, recursive=True)
    converted_non_recursive = api.to_graph(test_fn, recursive=False)

    self.assertNotEqual(converted_recursive.ag_module,
                        converted_non_recursive.ag_module)
    self.assertIn('internal_convert_user_code=True',
                  tf_inspect.getsource(converted_recursive))
    self.assertNotIn('internal_convert_user_code=False',
                     tf_inspect.getsource(converted_recursive))
    self.assertIn('internal_convert_user_code=False',
                  tf_inspect.getsource(converted_non_recursive))
    self.assertNotIn('internal_convert_user_code=True',
                     tf_inspect.getsource(converted_non_recursive))
Example #4
  def test_to_graph_with_kwargs_clashing_converted_call(self):

    def called_fn(**kwargs):
      return kwargs['f'] + kwargs['owner']

    def test_fn():
      # These arg names intentionally match converted_call's
      return called_fn(f=1, owner=2)

    compiled_fn = api.to_graph(test_fn)

    self.assertEqual(compiled_fn(), 3)
Example #5
  def test_to_graph_basic(self):

    def test_fn(x, s):
      while tf.reduce_sum(x) > s:
        x //= 2
      return x

    compiled_fn = api.to_graph(test_fn)

    with tf.Graph().as_default():
      x = compiled_fn(constant_op.constant((4, 8)), 4)
      self.assertAllEqual(self.evaluate(x), (1, 2))
Example #6
  def test_to_graph_preserves_bindings(self):
    y = 3

    def test_fn():
      return y

    converted = api.to_graph(test_fn)

    self.assertEqual(converted(), 3)

    y = 7

    self.assertEqual(converted(), 7)
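Note: the test above passes because to_graph preserves the original function's closure, so the free variable `y` is re-read on every call. The same late-binding behavior in plain Python, as a minimal sketch:

def outer():
  y = 3

  def read():
    return y          # closure cell, dereferenced at call time

  assert read() == 3
  y = 7               # rebinding the cell is visible to the next call
  assert read() == 7

outer()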
Example #7
  def test_converted_call_already_converted(self):

    def f(x):
      return x == 0

    x = api.converted_call(f, None, converter.ConversionOptions(),
                           (constant_op.constant(0),), {})
    self.assertTrue(self.evaluate(x))

    converted_f = api.to_graph(f)
    x = api.converted_call(converted_f, None, converter.ConversionOptions(),
                           (constant_op.constant(0),), {})
    self.assertTrue(self.evaluate(x))
Example #8
  def test_converted_call_already_converted(self):

    def f(x):
      return x == 0

    with self.cached_session() as sess:
      x = api.converted_call(f, None, converter.ConversionOptions(),
                             constant_op.constant(0))
      self.assertTrue(sess.run(x))

      converted_f = api.to_graph(f)
      x = api.converted_call(converted_f, None, converter.ConversionOptions(),
                             constant_op.constant(0))
      self.assertTrue(sess.run(x))
Example #9
  def test_to_graph_with_defaults(self):

    foo = 4

    def test_fn(x, s=foo):
      while tf.reduce_sum(x) > s:
        x //= 2
      return x

    compiled_fn = api.to_graph(test_fn)

    with self.cached_session():
      x = compiled_fn(constant_op.constant([4, 8]))
      self.assertListEqual([1, 2], self.evaluate(x).tolist())
Example #10
  def test_converted_call_already_converted(self):

    def f(x):
      return x == 0

    with self.test_session() as sess:
      x = api.converted_call(f, False, False, False, {},
                             constant_op.constant(0))
      self.assertTrue(sess.run(x))

      converted_f = api.to_graph(f)
      x = api.converted_call(converted_f, False, False, False, {},
                             constant_op.constant(0))
      self.assertTrue(sess.run(x))
Example #11
  def test_converted_call_already_converted(self):

    def f(x):
      return x == 0

    x = api.converted_call(f, None, converter.ConversionOptions(recursive=True),
                           (constant_op.constant(0),), {})
    self.assertTrue(self.evaluate(x))

    converted_f = api.to_graph(
        f, experimental_optional_features=converter.Feature.ALL)
    x = api.converted_call(converted_f, None,
                           converter.ConversionOptions(recursive=True),
                           (constant_op.constant(0),), {})
    self.assertTrue(self.evaluate(x))
Example #12
  def test_to_graph_caching(self):

    def test_fn(x):
      if x > 0:
        return x
      else:
        return -x

    converted_functions = tuple(api.to_graph(test_fn) for _ in (-1, 0, 1))

    # All outputs are from the same module. We can't use __module__ because
    # that's reset when we instantiate the function (see conversion.py).
    # TODO(mdan): Can and should we overwrite __module__ instead?
    module_names = frozenset(f.ag_module for f in converted_functions)
    self.assertEqual(len(module_names), 1)
    self.assertNotIn('__main__', module_names)

    self.assertEqual(len(frozenset(id(f) for f in converted_functions)), 3)
Example #13
def f():
    converted_g = api.to_graph(g)  # g is a function defined in the surrounding scope
    converted_g()
Example #14
  def test_source_map_attribute_present(self):

    def test_fn(y):
      return y**2

    self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
            """Scales the loss, computes the gradients, and unscales the gradients."""
            loss_scale_val = loss_scale()
            with gradient_tape:  # re-enter gradient tape so it sees the loss scaling
                scaled_target = nest.map_structure(
                    lambda t: t * loss_scale_val, target)
            old_grads = super(LossScaleGradientTape,
                              gradient_tape).gradient(scaled_target, sources,
                                                      output_gradients,
                                                      unconnected_gradients)
            inv_loss_scale = 1.0 / loss_scale_val
            grads = nest.map_structure(lambda g: inv_loss_scale * g, old_grads)
            return grads

        # Switch to a replica-context to compute gradients once per replica.
        grads = distribution.experimental_run_v2(
            replica_fn,
            args=(loss_scale_gradient_tapes, target, sources,
                  output_gradients))
        # Check for non-finite gradients possibly resulting from scaling
        _, ready_to_update = loss_scale.update(grads)

    return grads


# For some reason, AutoGraph does not convert _compute_gradients_until_finite
# automatically inside a tf.function, so we convert it manually.
# TODO(b/143572314): Determine why AutoGraph does not do the conversion
# automatically
_compute_gradients_until_finite_autograph = api.to_graph(
    _compute_gradients_until_finite)
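Note: the same manual-conversion pattern in miniature; a self-contained sketch with illustrative names, assuming the public tf.autograph API:

import tensorflow as tf

def _double(x):  # stand-in for a helper that must be converted explicitly
  return x * 2

# Convert manually, mirroring _compute_gradients_until_finite_autograph above.
_double_autograph = tf.autograph.to_graph(_double)

@tf.function
def apply_double(x):
  return _double_autograph(x)  # use the pre-converted function inside tf.function

print(apply_double(tf.constant(3)))  # tf.Tensor(6, shape=(), dtype=int32)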
Example #16
    def test_source_map_attribute_present(self):
        def test_fn(y):
            return y**2

        self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
Example #17
    def test_to_graph_source_map(self):
        def test_fn(y):
            return y**2

        self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
Example #18
  def test_to_graph_source_map(self):

    def test_fn(y):
      return y**2

    self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))