コード例 #1
0
  def test_tf_convert_wrapped(self):
    """tf_convert honors the ctx status for a TFDecorator-wrapped function.

    With an ENABLED ctx the wrapped `f` runs converted and yields -1; with a
    DISABLED ctx the tensor-valued `if` in `f` raises a TypeError.
    """

    def f():
      if tf.reduce_sum([1, 2]) > 0:
        return -1
      return 1

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
      return wrapper.__wrapped__(*args, **kwargs)

    decorated = tf_decorator.make_decorator(f, wrapper)

    # tf.function's own autograph flag is unrelated to what is tested here;
    # it is switched off only so the two mechanisms cannot be confused.
    @def_function.function(autograph=False)
    def test_fn(ctx):
      return api.tf_convert(decorated, ctx)()

    enabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
    converted_result = self.evaluate(test_fn(enabled_ctx))
    self.assertEqual(converted_result, -1)

    # tf_convert mutates the decorator it is given, so build a fresh one
    # before exercising the disabled path.
    decorated = tf_decorator.make_decorator(f, wrapper)
    with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
      # Branching on a tensor in `f` is only valid after AutoGraph conversion.
      test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
コード例 #2
0
    def test_tf_convert_overrides_current_context(self):
        """The ctx handed to tf_convert decides conversion, not the caller's.

        `test_fn` is marked do_not_convert, yet `f` must run converted under
        an ENABLED ctx and unconverted under a DISABLED one.
        """
        def f(expect_converted):
            self.assertEqual(converter_testing.is_inside_generated_code(),
                             expect_converted)

        @api.do_not_convert
        def test_fn(ctx, expect_converted):
            return api.tf_convert(f, ctx)(expect_converted)

        cases = (
            (ag_ctx.Status.ENABLED, True),
            (ag_ctx.Status.DISABLED, False),
        )
        for status, expect_converted in cases:
            test_fn(ag_ctx.ControlStatusCtx(status=status), expect_converted)
コード例 #3
0
  def test_tf_convert_whitelisted_method(self):
    """tf_convert on a whitelisted Keras `call` keeps the original target."""
    model = sequential.Sequential([core.Dense(2)])
    enabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
    wrapped = api.tf_convert(model.call, enabled_ctx)

    # Unwrapping must expose the very same underlying function object.
    _, target = tf_decorator.unwrap(wrapped)
    self.assertIs(target.__func__, model.call.__func__)
コード例 #4
0
ファイル: functions_test.py プロジェクト: zzwATdhu/tensorflow
        def test_fn():
            # Calls `inner_fn_callee` (defined outside this view) from a
            # nested function while an explicit DISABLED status is active.
            # NOTE(review): this function's source appears to be what the
            # enclosing test feeds to the converter under test, so keep the
            # nested-function + `with` structure unchanged — confirm against
            # the enclosing test.
            def inner_fn():
                inner_fn_callee()

            # Enter a DISABLED control-status context, carrying
            # ConversionOptions(recursive=True), around the nested call.
            with ag_ctx.ControlStatusCtx(
                    ag_ctx.Status.DISABLED,
                    converter.ConversionOptions(recursive=True)):
                inner_fn()
コード例 #5
0
  def test_tf_convert_direct(self):
    """tf_convert converts a plain function only when the ctx enables it.

    With an ENABLED ctx `f` yields -1; with a DISABLED ctx the tensor-valued
    `if` in `f` raises a TypeError.
    """

    def f():
      if tf.reduce_sum([1, 2]) > 0:
        return -1
      return 1

    # tf.function's own autograph flag is unrelated to what is tested here;
    # it is switched off only so the two mechanisms cannot be confused.
    @def_function.function(autograph=False)
    def test_fn(ctx):
      return api.tf_convert(f, ctx)()

    enabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
    converted_result = self.evaluate(test_fn(enabled_ctx))
    self.assertEqual(converted_result, -1)

    with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
      # Branching on a tensor in `f` is only valid after AutoGraph conversion.
      test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
コード例 #6
0
    def test_tf_convert_unspecified_not_converted_by_default(self):
        """An UNSPECIFIED ctx with convert_by_default=False leaves f as-is.

        Inside `f` the status must still read UNSPECIFIED and the code must
        not be running as AutoGraph-generated code.
        """
        def f():
            self.assertEqual(ag_ctx.control_status_ctx().status,
                             ag_ctx.Status.UNSPECIFIED)
            self.assertFalse(converter_testing.is_inside_generated_code())

        @def_function.function
        def test_fn(ctx):
            return api.tf_convert(f, ctx, convert_by_default=False)()

        unspecified_ctx = ag_ctx.ControlStatusCtx(
            status=ag_ctx.Status.UNSPECIFIED)
        test_fn(unspecified_ctx)
コード例 #7
0
    def test_tf_convert_allowlisted_method(self):
        """tf_convert on an allowlisted bound method keeps the original target."""
        class TestClass(object):
            def method(self):
                return converter_testing.is_inside_generated_code()

        converter_testing.allowlist(TestClass.method)

        instance = TestClass()
        enabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
        wrapped = api.tf_convert(instance.method, enabled_ctx)

        # Unwrapping must expose the very same underlying function object.
        _, target = tf_decorator.unwrap(wrapped)
        self.assertIs(target.__func__, instance.method.__func__)
コード例 #8
0
    def test_tf_convert_tf_decorator_unwrapping_context_disabled(self):
        """A DISABLED ctx runs a TFDecorator-wrapped function unconverted."""
        def f():
            self.assertFalse(converter_testing.is_inside_generated_code())

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return wrapper.__wrapped__(*args, **kwargs)

        decorated = tf_decorator.make_decorator(f, wrapper)

        # Convert-and-call directly; the result is not inspected, only the
        # assertion inside `f` matters.
        disabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
        api.tf_convert(decorated, disabled_ctx)()
コード例 #9
0
    def test_tf_convert_unspecified_not_converted_by_default(self):
        """An UNSPECIFIED ctx with convert_by_default=False skips conversion.

        Because `f` branches on a tensor, running it unconverted must raise
        a TypeError.
        """
        def f():
            self.assertEqual(ag_ctx.control_status_ctx().status,
                             ag_ctx.Status.UNSPECIFIED)
            if tf.reduce_sum([1, 2]) > 0:
                return -1
            return 1

        @def_function.function
        def test_fn(ctx):
            return api.tf_convert(f, ctx, convert_by_default=False)()

        unspecified_ctx = ag_ctx.ControlStatusCtx(
            status=ag_ctx.Status.UNSPECIFIED)
        with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
            # Branching on a tensor in `f` is only valid after AutoGraph
            # conversion.
            test_fn(unspecified_ctx)
コード例 #10
0
    def test_tf_convert_whitelisted_method(self):
        """tf_convert on a whitelisted bound method keeps the original target.

        Skipped on Python 2.
        """

        if six.PY2:
            # Fixed spelling in the skip reason ("comptible" -> "compatible").
            self.skipTest('Test bank not compatible with Python 2.')

        class TestClass(object):
            def method(self):
                return converter_testing.is_inside_generated_code()

        converter_testing.whitelist(TestClass.method)

        obj = TestClass()
        converted_call = api.tf_convert(
            obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
        # Unwrapping must expose the very same underlying function object.
        _, converted_target = tf_decorator.unwrap(converted_call)
        self.assertIs(converted_target.__func__, obj.method.__func__)
コード例 #11
0
 def wrapper(*args, **kwargs):
     """Runs the converted version of f under an ENABLED AutoGraph status.

     Conversion is forced, and `recursive` / `optional_features` are taken
     from the enclosing scope. An exception carrying `ag_error_metadata` is
     replaced by the exception that metadata builds; any other exception
     propagates unchanged.
     """
     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED):
         try:
             options = converter.ConversionOptions(
                 recursive=recursive,
                 force_conversion=True,
                 optional_features=optional_features,
             )
             return converted_call(f, None, options, args, kwargs)
         except Exception as exc:  # pylint:disable=broad-except
             if not hasattr(exc, 'ag_error_metadata'):
                 raise
             raise exc.ag_error_metadata.to_exception(type(exc))
コード例 #12
0
ファイル: function_wrappers.py プロジェクト: Harryi0/tinyML
  def __init__(self, function_name, scope_name, options):
    """Records the options and prepares the scopes they call for.

    Optional attributes are only created when their feature is active:
    `autograph_ctx` when `options.user_requested`, `name_scope` when
    NAME_SCOPES is used, and `autodeps_scope` / `_return_value_marked` when
    AUTO_CONTROL_DEPS is used.

    Args:
      function_name: function name; sanitized and used for the name scope
        when the NAME_SCOPES feature is enabled.
      scope_name: stored as `self.name`.
      options: conversion options, queried for `user_requested`,
        `call_options()` and the features in use.
    """
    self.name = scope_name
    self.options = options

    # User-requested conversions get an explicit ENABLED status context.
    if options.user_requested:
      self.autograph_ctx = ag_ctx.ControlStatusCtx(
          ag_ctx.Status.ENABLED, options)
    self.callopts = options.call_options()

    self.use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
    if self.use_name_scope:
      sanitized_name = self._sanitize(function_name)
      self.name_scope = ops.name_scope(sanitized_name)

    self.use_auto_deps = self.options.uses(
        converter.Feature.AUTO_CONTROL_DEPS)
    if self.use_auto_deps:
      self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
      self._return_value_marked = False
コード例 #13
0
ファイル: api_test.py プロジェクト: neosensory/ext_tensorflow
    def test_tf_convert_tf_decorator_allowlist_method(self):
        """tf_convert on an allowlisted, TFDecorator-wrapped method.

        After conversion, calling the method must report that it runs inside
        generated code.
        """
        def wrap(f):
            def wrapper(*args, **kwargs):
                return wrapper.__wrapped__(*args, **kwargs)

            return tf_decorator.make_decorator(f, wrapper)

        class TestClass(object):
            @wrap
            def method(self):
                return converter_testing.is_inside_generated_code()

        converter_testing.allowlist(TestClass.method)

        instance = TestClass()
        enabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
        # tf_convert is expected to modify the original method here. That is
        # undesirable, but the available options are limited.
        api.tf_convert(instance.method, enabled_ctx)
        self.assertTrue(instance.method())
コード例 #14
0
 def call_in_default_context():
   """Calls test_fn(True) through api.converted_call with status ENABLED."""
   enabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)
   with enabled:
     call_args = (True,)
     return api.converted_call(
         test_fn, call_args, None, options=DEFAULT_RECURSIVE)
コード例 #15
0
 def call_in_disabled_context():
   """Calls test_fn(False) through api.converted_call with status DISABLED."""
   disabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
   with disabled:
     call_args = (False,)
     return api.converted_call(
         test_fn, call_args, None, options=DEFAULT_RECURSIVE)
コード例 #16
0
 def graph_wrapper(*args, **kwargs):
     """Forwards all arguments to func with AutoGraph status DISABLED."""
     disabled = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
     with disabled:
         return func(*args, **kwargs)
コード例 #17
0
 def wrapper(*args, **kwargs):
   """Forwards all arguments to func with AutoGraph status UNSPECIFIED."""
   unspecified = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)
   with unspecified:
     return func(*args, **kwargs)