def test_tf_convert_wrapped(self):
  # Checks that tf_convert, given a tf_decorator-wrapped function, converts
  # the wrapped code only when the supplied context enables AutoGraph.

  def f():
    if tf.reduce_sum([1, 2]) > 0:
      return -1
    return 1

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    return wrapper.__wrapped__(*args, **kwargs)

  # tf.function's own autograph flag has no bearing on this test; it is
  # switched off so the two mechanisms are not confused.
  @def_function.function(autograph=False)
  def test_fn(ctx):
    return api.tf_convert(decorated_f, ctx)()

  decorated_f = tf_decorator.make_decorator(f, wrapper)
  enabled_result = test_fn(
      ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
  self.assertEqual(self.evaluate(enabled_result), -1)

  # tf_convert mutates the decorator it is handed, so a fresh one is built
  # before exercising the disabled case.
  decorated_f = tf_decorator.make_decorator(f, wrapper)
  disabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
  with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
    # Unconverted, `f` uses a Tensor as a Python bool, which must fail.
    test_fn(disabled_ctx)
def test_tf_convert_overrides_current_context(self):
  # The context passed to tf_convert decides whether `f` runs as generated
  # code, even though the caller is itself marked do_not_convert.

  def f(expect_converted):
    self.assertEqual(converter_testing.is_inside_generated_code(),
                     expect_converted)

  @api.do_not_convert
  def test_fn(ctx, expect_converted):
    return api.tf_convert(f, ctx)(expect_converted)

  cases = (
      (ag_ctx.Status.ENABLED, True),
      (ag_ctx.Status.DISABLED, False),
  )
  for status, expected in cases:
    test_fn(ag_ctx.ControlStatusCtx(status=status), expected)
def test_tf_convert_whitelisted_method(self):
  # tf_convert on the model's bound `call` method should return a wrapper
  # whose unwrapped target shares the same underlying function.
  model = sequential.Sequential([core.Dense(2)])
  enabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
  wrapped_call = api.tf_convert(model.call, enabled)

  _, target = tf_decorator.unwrap(wrapped_call)
  self.assertIs(target.__func__, model.call.__func__)
def test_fn():
  """Calls `inner_fn` inside a DISABLED AutoGraph context.

  The context also carries ConversionOptions with recursive=True.
  """

  def inner_fn():
    inner_fn_callee()

  disabled_recursive = ag_ctx.ControlStatusCtx(
      ag_ctx.Status.DISABLED, converter.ConversionOptions(recursive=True))
  with disabled_recursive:
    inner_fn()
def test_tf_convert_direct(self):
  # tf_convert applied straight to a plain function converts it only when
  # the supplied context enables AutoGraph.

  def f():
    if tf.reduce_sum([1, 2]) > 0:
      return -1
    return 1

  # tf.function's own autograph flag is unrelated to this test, so it is
  # turned off to keep the two mechanisms apart.
  @def_function.function(autograph=False)
  def test_fn(ctx):
    return api.tf_convert(f, ctx)()

  enabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
  self.assertEqual(self.evaluate(test_fn(enabled_ctx)), -1)

  disabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
  with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
    # Only the AutoGraph-converted form of `f` accepts a Tensor condition.
    test_fn(disabled_ctx)
def test_tf_convert_unspecified_not_converted_by_default(self):
  # With convert_by_default=False and an UNSPECIFIED context, `f` must run
  # outside generated code and still observe the UNSPECIFIED status.

  def f():
    self.assertEqual(ag_ctx.control_status_ctx().status,
                     ag_ctx.Status.UNSPECIFIED)
    self.assertFalse(converter_testing.is_inside_generated_code())

  @def_function.function
  def test_fn(ctx):
    return api.tf_convert(f, ctx, convert_by_default=False)()

  unspecified = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)
  test_fn(unspecified)
def test_tf_convert_allowlisted_method(self):
  # For an allowlisted method, tf_convert's result must unwrap to the same
  # underlying function as the original bound method.

  class TestClass(object):

    def method(self):
      return converter_testing.is_inside_generated_code()

  converter_testing.allowlist(TestClass.method)

  instance = TestClass()
  enabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
  wrapped = api.tf_convert(instance.method, enabled)

  _, unwrapped = tf_decorator.unwrap(wrapped)
  self.assertIs(unwrapped.__func__, instance.method.__func__)
def test_tf_convert_tf_decorator_unwrapping_context_disabled(self):
  # Under a DISABLED context, a tf_decorator-wrapped function must execute
  # outside generated code.

  def f():
    self.assertFalse(converter_testing.is_inside_generated_code())

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    return wrapper.__wrapped__(*args, **kwargs)

  def test_fn(ctx):
    return api.tf_convert(decorated_f, ctx)()

  decorated_f = tf_decorator.make_decorator(f, wrapper)
  disabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
  test_fn(disabled)
def test_tf_convert_unspecified_not_converted_by_default(self):
  # convert_by_default=False with an UNSPECIFIED context leaves `f`
  # unconverted, so its Tensor-valued `if` is rejected with a TypeError.

  def f():
    self.assertEqual(ag_ctx.control_status_ctx().status,
                     ag_ctx.Status.UNSPECIFIED)
    if tf.reduce_sum([1, 2]) > 0:
      return -1
    return 1

  @def_function.function
  def test_fn(ctx):
    return api.tf_convert(f, ctx, convert_by_default=False)()

  unspecified = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)
  with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
    # `f` is only valid once AutoGraph has converted it.
    test_fn(unspecified)
def test_tf_convert_whitelisted_method(self):
  # For a whitelisted method, tf_convert's result must unwrap to the same
  # underlying function as the original bound method.
  if six.PY2:
    # Fixed typo in the skip reason ("comptible" -> "compatible").
    self.skipTest('Test bank not compatible with Python 2.')

  class TestClass(object):

    def method(self):
      return converter_testing.is_inside_generated_code()

  converter_testing.whitelist(TestClass.method)

  obj = TestClass()
  converted_call = api.tf_convert(
      obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
  _, converted_target = tf_decorator.unwrap(converted_call)
  self.assertIs(converted_target.__func__, obj.method.__func__)
def wrapper(*args, **kwargs):
  """Calls the AutoGraph-converted version of `f` with the given arguments.

  The call runs with the AutoGraph status explicitly ENABLED and with
  `force_conversion=True`. An exception carrying `ag_error_metadata` is
  replaced by the exception that metadata builds; any other exception is
  re-raised as is.
  """
  with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED):
    try:
      conversion_options = converter.ConversionOptions(
          recursive=recursive,
          force_conversion=True,
          optional_features=optional_features,
      )
      return converted_call(f, None, conversion_options, args, kwargs)
    except Exception as err:  # pylint:disable=broad-except
      if not hasattr(err, 'ag_error_metadata'):
        raise
      raise err.ag_error_metadata.to_exception(type(err))
def __init__(self, function_name, scope_name, options):
  # All settings are read from `options` once, here. The name scope and the
  # automatic-control-dependencies scope objects are only created when the
  # corresponding feature is requested.
  self.name = scope_name
  self.options = options

  # NOTE(review): `autograph_ctx` is only assigned when the user requested
  # conversion; presumably every later use is guarded the same way — verify.
  if options.user_requested:
    self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
                                                 options)
  self.callopts = options.call_options()

  self.use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
  if self.use_name_scope:
    self.name_scope = ops.name_scope(self._sanitize(function_name))

  self.use_auto_deps = options.uses(converter.Feature.AUTO_CONTROL_DEPS)
  if self.use_auto_deps:
    self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()

  self._return_value_marked = False
def test_tf_convert_tf_decorator_allowlist_method(self):
  # An allowlisted method that is itself wrapped by a tf_decorator should,
  # after tf_convert, run as generated code when called normally.

  def wrap(target):

    def wrapper(*args, **kwargs):
      return wrapper.__wrapped__(*args, **kwargs)

    return tf_decorator.make_decorator(target, wrapper)

  class TestClass(object):

    @wrap
    def method(self):
      return converter_testing.is_inside_generated_code()

  converter_testing.allowlist(TestClass.method)

  instance = TestClass()
  # tf_convert is deliberately allowed to modify the original method here;
  # that is undesirable, but the available options are limited.
  enabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
  api.tf_convert(instance.method, enabled)
  self.assertTrue(instance.method())
def call_in_default_context():
  """Runs `test_fn(True)` through converted_call with AutoGraph ENABLED."""
  enabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
  with enabled:
    return api.converted_call(
        test_fn, (True,), None, options=DEFAULT_RECURSIVE)
def call_in_disabled_context():
  """Runs `test_fn(False)` through converted_call with AutoGraph DISABLED."""
  disabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
  with disabled:
    return api.converted_call(
        test_fn, (False,), None, options=DEFAULT_RECURSIVE)
def graph_wrapper(*args, **kwargs):
  """Forwards the call to `func` with the AutoGraph status set to DISABLED."""
  disabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
  with disabled:
    return func(*args, **kwargs)
def wrapper(*args, **kwargs):
  """Forwards the call to `func` with the AutoGraph status UNSPECIFIED."""
  unspecified = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)
  with unspecified:
    return func(*args, **kwargs)