def test_to_graph_with_globals(self):
  """Converting a function that uses the `global` keyword is rejected."""

  def test_fn(increment):
    global testing_global_numeric
    testing_global_numeric = increment + testing_global_numeric
    return testing_global_numeric

  expected_message = 'global keyword is not yet supported'
  # TODO(b/122368197)
  with self.assertRaisesRegex(errors.AutoGraphError, expected_message):
    api.to_graph(test_fn)
def test_to_graph_basic(self):
  """Converts a tensor-dependent while loop and runs it in a session."""

  def test_fn(values, threshold):
    # Halve every element until the sum is no longer above the threshold.
    while tf.reduce_sum(values) > threshold:
      values //= 2
    return values

  compiled_fn = api.to_graph(test_fn)
  with self.cached_session() as sess:
    halved = compiled_fn(constant_op.constant([4, 8]), 4)
    self.assertListEqual([1, 2], sess.run(halved).tolist())
def test_to_graph_caching_different_options(self):
  """Conversions with different `recursive` settings are cached separately."""

  def called_fn():
    pass

  def test_fn():
    return called_fn()

  converted_recursive = api.to_graph(test_fn, recursive=True)
  converted_non_recursive = api.to_graph(test_fn, recursive=False)

  # The two option sets must not resolve to the same generated module.
  self.assertNotEqual(converted_recursive.ag_module,
                      converted_non_recursive.ag_module)

  # Each generated source must carry only its own user-code flag.
  recursive_source = tf_inspect.getsource(converted_recursive)
  self.assertIn('internal_convert_user_code=True', recursive_source)
  self.assertNotIn('internal_convert_user_code=False', recursive_source)

  non_recursive_source = tf_inspect.getsource(converted_non_recursive)
  self.assertIn('internal_convert_user_code=False', non_recursive_source)
  self.assertNotIn('internal_convert_user_code=True', non_recursive_source)
def test_to_graph_with_kwargs_clashing_converted_call(self):
  """Keyword names that clash with converted_call's arguments still work."""

  def called_fn(**named_args):
    return named_args['f'] + named_args['owner']

  def test_fn():
    # `f` and `owner` are chosen on purpose to match argument names used by
    # converted_call, so the conversion must not confuse the two.
    return called_fn(f=1, owner=2)

  compiled_fn = api.to_graph(test_fn)
  self.assertEqual(compiled_fn(), 3)
def test_to_graph_basic(self):
  """Converts a tensor-dependent while loop and evaluates it in a new graph."""

  def test_fn(values, threshold):
    # Halve every element until the sum is no longer above the threshold.
    while tf.reduce_sum(values) > threshold:
      values //= 2
    return values

  compiled_fn = api.to_graph(test_fn)

  with tf.Graph().as_default():
    halved = compiled_fn(constant_op.constant((4, 8)), 4)
    self.assertAllEqual(self.evaluate(halved), (1, 2))
def test_to_graph_preserves_bindings(self):
  """The converted function observes later rebindings of a closure variable."""
  captured = 3

  def test_fn():
    return captured

  converted = api.to_graph(test_fn)
  self.assertEqual(converted(), 3)

  # Rebinding after conversion must be visible through the closure.
  captured = 7
  self.assertEqual(converted(), 7)
def test_converted_call_already_converted(self):
  """converted_call gives the same result for a function and its conversion."""

  def f(x):
    return x == 0

  def check(fn):
    # A fresh options object per call, as in the straight-line version.
    result = api.converted_call(fn, None, converter.ConversionOptions(),
                                (constant_op.constant(0),), {})
    self.assertTrue(self.evaluate(result))

  check(f)
  check(api.to_graph(f))
def test_converted_call_already_converted(self):
  """converted_call gives the same result for a function and its conversion."""

  def f(x):
    return x == 0

  with self.cached_session() as sess:

    def check(fn):
      result = api.converted_call(fn, None, converter.ConversionOptions(),
                                  constant_op.constant(0))
      self.assertTrue(sess.run(result))

    check(f)
    check(api.to_graph(f))
def test_to_graph_with_defaults(self):
  """Default argument values of the original function survive conversion."""
  foo = 4

  def test_fn(x, s=foo):
    # Halve every element until the sum is no longer above `s`.
    while tf.reduce_sum(x) > s:
      x //= 2
    return x

  compiled_fn = api.to_graph(test_fn)

  # The result is read through self.evaluate, so the session handle itself is
  # not needed; the context only provides the default session.
  with self.cached_session():
    x = compiled_fn(constant_op.constant([4, 8]))
    self.assertListEqual([1, 2], self.evaluate(x).tolist())
def test_converted_call_already_converted(self):
  """converted_call gives the same result for a function and its conversion."""

  def f(x):
    return x == 0

  # cached_session replaces the deprecated test_session helper.
  with self.cached_session() as sess:
    x = api.converted_call(f, False, False, False, {},
                           constant_op.constant(0))
    self.assertTrue(sess.run(x))

    converted_f = api.to_graph(f)
    x = api.converted_call(converted_f, False, False, False, {},
                           constant_op.constant(0))
    self.assertTrue(sess.run(x))
def test_converted_call_already_converted(self):
  """converted_call gives the same result for a function and its conversion."""

  def f(x):
    return x == 0

  def check(fn):
    result = api.converted_call(
        fn, None, converter.ConversionOptions(recursive=True),
        (constant_op.constant(0),), {})
    self.assertTrue(self.evaluate(result))

  check(f)
  check(
      api.to_graph(f, experimental_optional_features=converter.Feature.ALL))
def test_to_graph_caching(self):
  """Repeated conversions share one module but yield distinct functions."""

  def test_fn(x):
    if x > 0:
      return x
    else:
      return -x

  converted_functions = [api.to_graph(test_fn) for _ in range(3)]

  # Every conversion should come from the same generated module. __module__
  # cannot be used for this, because instantiating the function resets it
  # (see conversion.py).
  # TODO(mdan): Can and should we overwrite __module__ instead?
  modules = {fn.ag_module for fn in converted_functions}
  self.assertEqual(len(modules), 1)
  # NOTE(review): if ag_module holds module objects rather than names, this
  # string membership check can never fail -- confirm the intended comparison.
  self.assertNotIn('__main__', modules)

  # The function objects themselves are still distinct instances.
  self.assertEqual(len({id(fn) for fn in converted_functions}), 3)
def f():
  """Converts `g` (bound in a scope not shown here) and calls the result."""
  graph_g = api.to_graph(g)
  graph_g()
def test_source_map_attribute_present(self):
  """Functions produced by to_graph carry an `ag_source_map` attribute."""

  def test_fn(y):
    return y**2

  converted = api.to_graph(test_fn)
  self.assertTrue(hasattr(converted, 'ag_source_map'))
"""Scales the loss, computes the gradients, and unscales the gradients.""" loss_scale_val = loss_scale() with gradient_tape: # re-enter gradient tape so it sees the loss scaling scaled_target = nest.map_structure( lambda t: t * loss_scale_val, target) old_grads = super(LossScaleGradientTape, gradient_tape).gradient(scaled_target, sources, output_gradients, unconnected_gradients) inv_loss_scale = 1.0 / loss_scale_val grads = nest.map_structure(lambda g: inv_loss_scale * g, old_grads) return grads # Switch to a replica-context to compute gradients once per replica. grads = distribution.experimental_run_v2( replica_fn, args=(loss_scale_gradient_tapes, target, sources, output_gradients)) # Check for non-finite gradients possibly resulting from scaling _, ready_to_update = loss_scale.update(grads) return grads # For some reason, AutoGraph does not convert _compute_gradients_until_finite # automatically inside a tf.function, so we convert it manually. # TODO(b/143572314): Determine why AutoGraph does not do the conversion # automatically _compute_gradients_until_finite_autograph = api.to_graph( _compute_gradients_until_finite)
def test_to_graph_source_map(self):
  """Functions produced by to_graph carry an `ag_source_map` attribute."""

  def test_fn(y):
    return y**2

  converted = api.to_graph(test_fn)
  self.assertTrue(hasattr(converted, 'ag_source_map'))