Ejemplo n.º 1
0
    def test_parser_compile_idempotent(self):
        """Parsing a function and compiling its AST reproduces the source."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        parsed_node, _, _ = parser.parse_entity(test_fn, future_imports=())

        original_src = textwrap.dedent(tf_inspect.getsource(test_fn))
        compiled_module = compiler.ast_to_object([parsed_node])[0]
        roundtrip_src = tf_inspect.getsource(compiled_module.test_fn)
        self.assertEqual(original_src, roundtrip_src)
Ejemplo n.º 2
0
    def test_parser_compile_identity(self):
        """A function compiled from its own parsed AST has identical source."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        parsed, _ = parser.parse_entity(test_fn, future_features=())
        compiled_module, _, _ = compiler.ast_to_object(parsed)

        want = textwrap.dedent(tf_inspect.getsource(test_fn))
        got = tf_inspect.getsource(compiled_module.test_fn)
        self.assertEqual(want, got)
Ejemplo n.º 3
0
    def test_parser_compile_idempotent(self):
        """Compiling every parsed node of a function reproduces its source."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        _, _, nodes = parser.parse_entity(test_fn)

        want = textwrap.dedent(tf_inspect.getsource(test_fn))
        compiled = compiler.ast_to_object(nodes)[0]
        self.assertEqual(want, tf_inspect.getsource(compiled.test_fn))
Ejemplo n.º 4
0
  def test_parser_compile_idempotent(self):
    """Parsing then compiling the function definition yields the same source."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    expected = textwrap.dedent(tf_inspect.getsource(test_fn))
    # The first statement of the parsed module is the function definition.
    fn_def = parser.parse_entity(test_fn)[0].body[0]
    compiled = compiler.ast_to_object(fn_def)[0]
    self.assertEqual(expected, tf_inspect.getsource(compiled.test_fn))
Ejemplo n.º 5
0
  def test_parser_compile_identity(self):
    """Round-tripping a function through parse and compile keeps its source."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    tree, _ = parser.parse_entity(test_fn, future_features=())
    compiled_module, _, _ = compiler.ast_to_object(tree)

    want = textwrap.dedent(tf_inspect.getsource(test_fn))
    got = tf_inspect.getsource(compiled_module.test_fn)
    self.assertEqual(want, got)
Ejemplo n.º 6
0
    def test_parse_load_identity(self):
        """The parsed AST matches both the loaded source and the original one."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = (x + 1)
            return b

        node, _ = parser.parse_entity(test_fn, future_features=())
        module, _, _ = loader.load_ast(node)
        loaded_src = tf_inspect.getsource(module.test_fn)
        original_src = textwrap.dedent(tf_inspect.getsource(test_fn))

        # Compare against the loaded source first, then the original.
        for candidate in (loaded_src, original_src):
            self.assertAstMatches(node, candidate)
Ejemplo n.º 7
0
    def test_simple_decorator(self):
        """A plain decorator is kept or removed according to the transform."""
        def simple_decorator(f):
            return lambda a: f(a) + 1

        # The Python parser does capture decorators into the AST.
        # However, the interpreter desugars them upon load, and referring to the
        # decorated function at runtime usually loses any trace of the decorator.
        # Below is an example when that doesn't happen.
        def static_wrapper():
            @simple_decorator
            def test_fn(a):  # pylint:disable=unused-variable
                return a

        wrapper_node = self.parse_and_analyze(
            static_wrapper, {'simple_decorator': simple_decorator})
        fn_node = wrapper_node.body[0].body[0]

        # Keep every decorator. Decorators execute when the generated module is
        # loaded, so the decorator's own source has to be prepended up front;
        # it cannot be added after the fact.
        fn_node = decorators.transform(fn_node, remove_decorators=())
        prefix = textwrap.dedent(tf_inspect.getsource(simple_decorator))
        result, _ = compiler.ast_to_object(fn_node, source_prefix=prefix)
        self.assertEqual(2, result.test_fn(1))

        # Strip simple_decorator; test_fn then simply returns its argument.
        fn_node = decorators.transform(
            fn_node, remove_decorators=(simple_decorator,))
        with self.compiled(fn_node) as result:
            self.assertEqual(1, result.test_fn(1))
Ejemplo n.º 8
0
    def test_function_decorator(self):
        """A decorator built by a factory call is kept or removed on demand."""
        def function_decorator():
            def decorator(f):
                return lambda a: f(a) + 1

            return decorator

        # The Python parser does capture decorators into the AST.
        # However, the interpreter desugars them on load, and referring to the
        # decorated function at runtime usually loses any trace of the decorator.
        # Below is an example when that doesn't happen.
        def static_wrapper():
            @function_decorator()
            def test_fn(a):  # pylint:disable=unused-variable
                return a

        wrapper_node = self.parse_and_analyze(
            static_wrapper, {'function_decorator': function_decorator})
        fn_node = wrapper_node.body[0].body[0]

        # With no decorator removed, the generated module also needs the
        # decorator factory's source so the decorator can run when loaded.
        fn_node = decorators.transform(fn_node, remove_decorators=())
        prefix = textwrap.dedent(tf_inspect.getsource(function_decorator))
        result = compiler.ast_to_object(fn_node, source_prefix=prefix)
        self.assertEqual(2, result.test_fn(1))

        # Once function_decorator is removed, test_fn returns its argument.
        fn_node = decorators.transform(
            fn_node, remove_decorators=(function_decorator,))
        result = compiler.ast_to_object(fn_node)
        self.assertEqual(1, result.test_fn(1))
Ejemplo n.º 9
0
  def test_parse_load_identity(self):
    """Loading a parsed function reproduces its (re-indented) source."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = (x + 1)
      return b

    node, _ = parser.parse_entity(test_fn, future_features=())
    module, _, _ = loader.load_ast(node)

    expected = textwrap.dedent(tf_inspect.getsource(test_fn))
    # astunparse emits fixed 4-space indentation; map it back to the 2-space
    # style test_fn is written in before comparing.
    actual = tf_inspect.getsource(module.test_fn).replace('    ', '  ')
    self.assertEqual(expected, actual)
Ejemplo n.º 10
0
def to_code(entity,
            recursive=True,
            experimental_optional_features=None):
  """Converts `entity` with AutoGraph and returns the result as source code.

  Also see: `tf.autograph.to_graph`.

  This calls `to_graph` and returns the Python source code of the object it
  produces; that code can be used to generate a TensorFlow graph that is
  functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    The converted code as a string, with common leading indentation removed.
  """
  converted = to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
  return textwrap.dedent(tf_inspect.getsource(converted))
Ejemplo n.º 11
0
  def test_simple_decorator(self):
    """A plain decorator is kept or removed according to the transform."""

    def simple_decorator(f):
      return lambda a: f(a) + 1

    # The Python parser does capture decorators into the AST.
    # However, the interpreter desugars them upon load, and referring to the
    # decorated function at runtime usually loses any trace of the decorator.
    # Below is an example when that doesn't happen.
    def static_wrapper():

      @simple_decorator
      def test_fn(a):  # pylint:disable=unused-variable
        return a

    wrapper_node = self.parse_and_analyze(
        static_wrapper, {'simple_decorator': simple_decorator})
    fn_node = wrapper_node.body[0].body[0]

    # Decorator retained: prepend its source so it exists at load time.
    fn_node = decorators.transform(fn_node, remove_decorators=())
    prefix = textwrap.dedent(tf_inspect.getsource(simple_decorator))
    result = compiler.ast_to_object(fn_node, source_prefix=prefix)
    self.assertEqual(2, result.test_fn(1))

    # Decorator removed: test_fn simply returns its argument.
    fn_node = decorators.transform(
        fn_node, remove_decorators=(simple_decorator,))
    result = compiler.ast_to_object(fn_node)
    self.assertEqual(1, result.test_fn(1))
Ejemplo n.º 12
0
    def _generic_test(self,
                      f_raw,
                      examples,
                      input_signature=None,
                      skip_modes=None):
        """Test a function `f_raw` against all tests `examples`.

        The function is checked directly, wrapped in `tf.function`, wrapped in
        an XLA-compiled `tf.function`, and after a SavedModel round trip,
        unless the corresponding mode is listed in `skip_modes`.

        Args:
          f_raw: a callable.
          examples: A list of `Example` named tuples.
          input_signature: Input signature to tf.function.
          skip_modes: A list of `RunMode` enums to entirely skip testing in the
            specified `RunMode`s. This is necessary when things fail in a
            certain `RunMode` even before executing the function (e.g. during
            saving or loading in `RunMode.SAVED` mode).
        """
        f_tf = None
        if not skip_modes:
            skip_modes = []

        # Record the function under test: its source for plain functions,
        # otherwise its docstring.
        if tf_inspect.isfunction(f_raw):
            self.recordProperty('f', tf_inspect.getsource(f_raw))
        else:
            self.recordProperty('f', tf_inspect.getdoc(f_raw))

        for arg, out, failure, bugs in examples:
            del out
            self.recordProperty('Input "{}"'.format(arg), {
                'not-working': failure,
                'bugs': bugs
            })

        # Run the function without tf.function
        if RunMode.RAW not in skip_modes:
            self._run_and_check(f_raw, RunMode.RAW, examples)

        # TF Function
        if RunMode.FUNCTION not in skip_modes:
            f_tf = tf.function(f_raw, input_signature=input_signature)
            self._run_and_check(f_tf, RunMode.FUNCTION, examples)

        # XLA Function
        if RunMode.XLA not in skip_modes:
            f_xla = tf.function(f_raw,
                                input_signature=input_signature,
                                experimental_compile=True)
            self._run_and_check(f_xla, RunMode.XLA, examples)

        # Write a saved model and try to run it
        if RunMode.SAVED not in skip_modes:
            module = tf.Module()
            # Reuse the FUNCTION-mode wrapper when it was built.
            if f_tf is not None:
                module.f = f_tf
            else:
                module.f = tf.function(f_raw, input_signature=input_signature)

            # Save into a fresh private directory. Writing straight into the
            # shared system temp dir lets tests overwrite each other's
            # SavedModel files and leaves them behind afterwards.
            with tempfile.TemporaryDirectory() as saved_model_dir:
                tf.saved_model.save(module, saved_model_dir)
                module_loaded = tf.saved_model.load(saved_model_dir)
                self._run_and_check(module_loaded.f, RunMode.SAVED, examples)
Ejemplo n.º 13
0
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function to convert.
    conversion_map: supplies the namer and `nocompile_decorators`, and has its
      name map updated with the names chosen here.
    arg_values: forwarded to the `context.EntityContext`.
    arg_types: forwarded to the `context.EntityContext`.
    owner_type: optional; when not None it is forwarded to
      `namer.compiled_function_name` (presumably the class owning `f` when it
      is a method -- TODO confirm).

  Returns:
    A tuple of the converted function AST node (renamed to its compiled name)
    and `conversion_map.name_map[f]`.
  """
  node = parser.parse_object(f).body[0]
  # NOTE(review): this is the function's actual globals dict, not a copy, so
  # the closure entries added below are written into f's module namespace.
  namespace = six.get_function_globals(f)

  # This is needed for non-global functions.
  closure = six.get_function_closure(f)
  if closure:
    for cell in closure:
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the enclosing variable has not been assigned yet, so
        # there is nothing to expose. Reading it would otherwise crash.
        continue
      if callable(value):
        namespace[value.__name__] = value

  namer = conversion_map.new_namer(namespace)
  ctx = context.EntityContext(
      namer=namer,
      source_code=tf_inspect.getsource(f),
      source_file=tf_inspect.getfile(f),
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types)
  node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  if owner_type is not None:
    new_name = namer.compiled_function_name(f.__name__, f, owner_type)
  else:
    new_name = namer.compiled_function_name(f.__name__, f)
  node.name = new_name
  conversion_map.update_name_map(namer)
  return node, conversion_map.name_map[f]
Ejemplo n.º 14
0
  def test_function_decorator(self):
    """A decorator built by a factory call is kept or removed on demand."""

    def function_decorator():

      def decorator(f):
        return lambda a: f(a) + 1

      return decorator

    # The Python parser does capture decorators into the AST.
    # However, the interpreter desugars them on load, and referring to the
    # decorated function at runtime usually loses any trace of the decorator.
    # Below is an example when that doesn't happen.
    def static_wrapper():

      @function_decorator()
      def test_fn(a):  # pylint:disable=unused-variable
        return a

    wrapper_node = self.parse_and_analyze(
        static_wrapper, {'function_decorator': function_decorator})
    fn_node = wrapper_node.body[0].body[0]

    # The decorator stays, so its source must be supplied as a prefix:
    # decorators execute on load, so it cannot be added afterwards.
    fn_node = decorators.transform(fn_node, remove_decorators=())
    prefix = textwrap.dedent(tf_inspect.getsource(function_decorator))
    result, _ = compiler.ast_to_object(fn_node, source_prefix=prefix)
    self.assertEqual(2, result.test_fn(1))

    # With the decorator removed, test_fn returns its argument unchanged.
    fn_node = decorators.transform(
        fn_node, remove_decorators=(function_decorator,))
    with self.compiled(fn_node) as result:
      self.assertEqual(1, result.test_fn(1))
Ejemplo n.º 15
0
def to_code(entity, recursive=True, experimental_optional_features=None):
    """Converts `entity` with AutoGraph and returns the result as source code.

    Also see: `tf.autograph.to_graph`.

    This runs `to_graph` and returns the Python source code of the object it
    produces; that code can be used to generate a TensorFlow graph that is
    functionally identical to the input Python code.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.

    Returns:
      The converted code as a string, with common leading indentation removed.
    """
    converted = to_graph(
        entity,
        recursive=recursive,
        experimental_optional_features=experimental_optional_features)
    generated_source = tf_inspect.getsource(converted)
    return textwrap.dedent(generated_source)
Ejemplo n.º 16
0
    def test_to_graph_caching_different_options(self):
        """Converting with different `recursive` values yields distinct code."""
        def called_fn():
            pass

        def test_fn():
            return called_fn()

        converted_recursive = api.to_graph(test_fn, recursive=True)
        converted_non_recursive = api.to_graph(test_fn, recursive=False)

        self.assertNotEqual(converted_recursive.ag_module,
                            converted_non_recursive.ag_module)
        # Escape the parentheses so the pattern matches the literal
        # `FunctionScope(...)` call rather than treating them as a regex group.
        self.assertRegex(tf_inspect.getsource(converted_recursive),
                         r'FunctionScope\(.*recursive=True.*\)')
        self.assertRegex(tf_inspect.getsource(converted_non_recursive),
                         r'FunctionScope\(.*recursive=False.*\)')
Ejemplo n.º 17
0
    def _generic_test(self, f_raw, examples):
        """Test a function `f_raw` against all tests `examples`.

        The function is checked directly, wrapped in `tf.function`, wrapped in
        an XLA-compiled `tf.function`, and after a SavedModel round trip.

        Args:
          f_raw: a callable.
          examples: A list of `Example` named tuples.
        """
        self.recordProperty('f', tf_inspect.getsource(f_raw))
        for arg, out, failure, bugs in examples:
            del out
            # Wrap `arg` in a 1-tuple: a bare tuple `arg` would otherwise be
            # spread over the single `%r` placeholder and raise TypeError.
            self.recordProperty('Input "%r"' % (arg,), {
                'not-working': failure,
                'bugs': bugs
            })

        # Run the function without tf.function
        self._run_and_check(f_raw, RunMode.RAW, examples)
        # TF Function
        f_tf = tf.function(f_raw)
        self._run_and_check(f_tf, RunMode.FUNCTION, examples)
        # XLA Function
        f_xla = tf.function(f_raw, experimental_compile=True)
        self._run_and_check(f_xla, RunMode.XLA, examples)
        # Write a saved model and try to run it. Use a fresh private directory
        # rather than the shared system temp dir, so tests cannot overwrite
        # each other's files and nothing is left behind.
        module = tf.Module()
        module.f = f_tf
        with tempfile.TemporaryDirectory() as saved_model_dir:
            tf.saved_model.save(module, saved_model_dir)
            module_loaded = tf.saved_model.load(saved_model_dir)
            self._run_and_check(module_loaded.f, RunMode.SAVED, examples)
Ejemplo n.º 18
0
  def testGetSource(self):
    expected = '''@test_decorator('decorator')
def test_decorated_function_with_defaults(a, b=2, c='Hello'):
  """Test Decorated Function With Defaults Docstring."""
  return [a, b, c]
'''
    self.assertEqual(
        expected, tf_inspect.getsource(test_decorated_function_with_defaults))
Ejemplo n.º 19
0
  def testGetSource(self):
    expected = '''@test_decorator('decorator')
def test_decorated_function_with_defaults(a, b=2, c='Hello'):
  """Test Decorated Function With Defaults Docstring."""
  return [a, b, c]
'''
    self.assertEqual(
        expected, tf_inspect.getsource(test_decorated_function_with_defaults))
Ejemplo n.º 20
0
def parse_entity(entity):
  """Returns the AST of given entity, together with the parsed source.

  Args:
    entity: the object whose source is obtained with `tf_inspect.getsource`.

  Returns:
    A tuple `(ast, source)`, where `source` is the dedented code that was
    actually parsed (for lambdas, possibly truncated at the syntax error).

  Raises:
    ValueError: if the source cannot be parsed (for lambdas, even after
      stripping trailing tokens at the reported error location).
    SyntaxError: re-raised unchanged for entities that are not lambdas.
  """
  source = tf_inspect.getsource(entity)

  def fail(comment):
    # `source` is read when fail() is called, so it reports the dedented text.
    raise ValueError(
        'Failed to parse source code of {}, which Python reported as:\n{}\n'
        '{}'.format(entity, source, comment))

  # Comments and multiline strings can appear at arbitrary indentation levels,
  # causing textwrap.dedent to not correctly dedent source code.
  # TODO(b/115884650): Automatic handling of comments/multiline strings.
  source = textwrap.dedent(source)

  try:
    return parse_str(source), source

  except IndentationError:
    # IndentationError subclasses SyntaxError, so it must be handled first.
    # The text below lists the causes of this error known to us. There may
    # be more.
    fail('This may be caused by multiline strings or comments not indented at'
         ' the same level as the code.')

  except SyntaxError as e:
    if not tf_inspect.isfunction(entity) or entity.__name__ != '<lambda>':
      raise

    # Certain entities, like lambdas, only hold the raw code lines which defined
    # them, which may include surrounding tokens and may be syntactically
    # invalid out of context. For example:
    #
    #     l = (
    #         lambda x: x,)[0]
    #
    # will have the dedented source "lambda x: x,)[0]"
    # Here we make an attempt to strip away the garbage by looking at the
    # information in the syntax error.
    lines = source.split('\n')
    lineno, offset = e.lineno, e.offset  # 1-based

    # Give up if the error carries no location, or if there's nothing we can
    # chip away.
    if (lineno is None or offset is None or
        (len(lines) == lineno and len(lines[-1]) == offset)):
      fail('If this is a lambda function, the error may be avoided by creating'
           ' the lambda in a standalone statement.')

    # Drop all lines following the error location
    # TODO(mdan): What's with the pylint errors?
    lines = lines[:lineno]  # pylint:disable=invalid-slice-index
    # Drop all characters following the error location
    lines[-1] = lines[-1][:offset - 1]  # pylint:disable=invalid-slice-index
    new_source = '\n'.join(lines)

    try:
      return parse_str(new_source), new_source
    except SyntaxError:
      fail('If this is a lambda function, the error may be avoided by creating'
           ' the lambda in a standalone statement. Tried to strip down the'
           ' source to:\n{}\nBut that did not work.'.format(new_source))
Ejemplo n.º 21
0
    def test_to_graph_caching_different_options(self):
        """Recursive and non-recursive conversions produce different modules."""
        def called_fn():
            pass

        def test_fn():
            return called_fn()

        converted_recursive = api.to_graph(test_fn, recursive=True)
        converted_non_recursive = api.to_graph(test_fn, recursive=False)

        self.assertNotEqual(converted_recursive.ag_module,
                            converted_non_recursive.ag_module)

        recursive_src = tf_inspect.getsource(converted_recursive)
        self.assertIn('ag__.STD', recursive_src)
        self.assertNotIn('internal_convert_user_code=False', recursive_src)

        non_recursive_src = tf_inspect.getsource(converted_non_recursive)
        self.assertIn('internal_convert_user_code=False', non_recursive_src)
        self.assertNotIn('internal_convert_user_code=True', non_recursive_src)
Ejemplo n.º 22
0
  def test_to_graph_caching_different_options(self):
    """Each `recursive` setting is reflected in its own generated module."""

    def called_fn():
      pass

    def test_fn():
      return called_fn()

    converted_recursive = api.to_graph(test_fn, recursive=True)
    converted_non_recursive = api.to_graph(test_fn, recursive=False)

    self.assertNotEqual(converted_recursive.ag_module,
                        converted_non_recursive.ag_module)

    # (converted function, text that must appear, text that must not appear)
    expectations = (
        (converted_recursive, 'internal_convert_user_code=True',
         'internal_convert_user_code=False'),
        (converted_non_recursive, 'internal_convert_user_code=False',
         'internal_convert_user_code=True'),
    )
    for converted, present, absent in expectations:
      generated = tf_inspect.getsource(converted)
      self.assertIn(present, generated)
      self.assertNotIn(absent, generated)
Ejemplo n.º 23
0
def parse_entity(entity):
  """Returns the AST of `entity` along with the dedented source it came from."""
  # Comments and multiline strings can sit at arbitrary indentation levels,
  # which keeps textwrap.dedent from removing all common indentation.
  # TODO(b/115884650): Automatic handling of comments/multiline strings.
  source = textwrap.dedent(tf_inspect.getsource(entity))
  try:
    tree = parse_str(source)
  except IndentationError:
    # The entity's code already parsed once when it was defined, so an
    # IndentationError here is attributed to incomplete dedenting.
    raise ValueError(
        'Failed to dedent prior to parsing source code. If you have comments '
        'or multiline strings in your code, try indenting them. '
        'Multiline strings can be rewritten using textwrap.dedent.\n'
        'Offending source code: \n %s' % source)
  return tree, source
Ejemplo n.º 24
0
def parse_entity(entity):
  """Parses the dedented source of `entity`; returns `(ast, source)`."""
  raw_source = tf_inspect.getsource(entity)
  # TODO(b/115884650): Automatic handling of comments/multiline strings.
  # Those can appear at any indentation and defeat textwrap.dedent.
  source = textwrap.dedent(raw_source)
  try:
    tree = parse_str(source)
  except IndentationError:
    # Source that already compiled once should only fail here because
    # dedenting was insufficient.
    raise ValueError(
        'Failed to dedent prior to parsing source code. If you have comments '
        'or multiline strings in your code, try indenting them. '
        'Multiline strings can be rewritten using textwrap.dedent.\n'
        'Offending source code: \n %s' % source)
  else:
    return tree, source
Ejemplo n.º 25
0
def run_user_main(wrapped_test_module):
    """Runs the "if __name__ == '__main__'" at the bottom of a module.

    TensorFlow practice is to have a main if at the bottom of the module which
    might call an API compat function before calling test.main().

    Since this is a statement, not a function, we can't cleanly reference it,
    but we can inspect it from the user module and run it in the context of
    that module so all imports and variables are available to it.

    Args:
      wrapped_test_module: The user-provided test code to run.

    Raises:
      NotImplementedError: If main block was not found in module. This should
        not be caught, as it is likely an error on the user's part -- absltest
        is all too happy to report a successful status (and zero tests
        executed) if a user forgets to end a class with "test.main()".
    """
    tree = ast.parse(tf_inspect.getsource(wrapped_test_module))

    # Get string representation of just the condition `__name__ == "__main__"`.
    target = ast.dump(
        ast.parse('if __name__ == "__main__": pass').body[0].test)

    # `tree.body` is a list of top-level statements in the module, like imports
    # and class definitions. We search for our main block, starting from the end.
    for expr in reversed(tree.body):
        if isinstance(expr, ast.If) and ast.dump(expr.test) == target:
            break
    else:
        raise NotImplementedError(
            'Could not find `if __name__ == "__main__":` block in '
            f'{wrapped_test_module.__name__}.')

    # expr is defined because we would have raised an error otherwise.
    new_ast = ast.Module(body=expr.body, type_ignores=[])  # pylint:disable=undefined-loop-variable
    # The block's statements run with this module's globals and the wrapped
    # module's namespace as locals, so its top-level names are looked up in
    # (and assigned into) the wrapped module first.
    exec(  # pylint:disable=exec-used
        compile(new_ast, '<ast>', 'exec'),
        globals(),
        wrapped_test_module.__dict__,
    )
Ejemplo n.º 26
0
def to_code(entity, recursive=True, experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string, with common leading indentation removed.
  """
  converted = to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
  generated_source = tf_inspect.getsource(converted)
  return textwrap.dedent(generated_source)
Ejemplo n.º 27
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL):
    """Converts `entity` with AutoGraph and returns the result as source code.

    Also see: `tf.autograph.to_graph`.

    This calls `to_graph` and returns the Python source code of the object it
    produces; that code can be used to generate a TensorFlow graph that is
    functionally identical to the input Python code.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional dict of value hints for symbols including
        function arguments mapping string names to actual values. For example,
        `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments. Type hints allow specifying just the type of a variable,
        rather than a specific value.
      indentation: Ignored; accepted only for backward compatibility.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.

    Returns:
      The converted code as a string, with common leading indentation removed.
    """
    # TODO(b/129431421): Remove this arg.
    del indentation
    converted = to_graph(
        entity,
        recursive=recursive,
        arg_values=arg_values,
        arg_types=arg_types,
        experimental_optional_features=experimental_optional_features)
    return textwrap.dedent(tf_inspect.getsource(converted))
Ejemplo n.º 28
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=None):
  """Converts `entity` with AutoGraph and returns the result as source code.

  Also see: `tf.autograph.to_graph`.

  This runs `to_graph` and returns the Python source code of the object it
  produces; that code can be used to generate a TensorFlow graph that is
  functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    indentation: Ignored; accepted only for backward compatibility.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.

  Returns:
    The converted code as a string, with common leading indentation removed.
  """
  # TODO(b/129431421): Remove this arg.
  del indentation
  converted = to_graph(
      entity,
      recursive=recursive,
      arg_values=arg_values,
      arg_types=arg_types,
      experimental_optional_features=experimental_optional_features)
  generated_source = tf_inspect.getsource(converted)
  return textwrap.dedent(generated_source)
Ejemplo n.º 29
0
def parse_entity(entity):
  """Return the AST of `entity` and the dedented source it was parsed from."""
  dedented = textwrap.dedent(tf_inspect.getsource(entity))
  return parse_str(dedented), dedented
Ejemplo n.º 30
0
def parse_object(obj):
  """Return the AST built from the source code of `obj`."""
  source = tf_inspect.getsource(obj)
  return parse_str(source)
Ejemplo n.º 31
0
def _generate_signature(func, reverse_index):
  """Given a function, returns a list of strings representing its args.

  This function produces a list of strings representing the arguments to a
  python function. It uses tf_inspect.getargspec, which
  does not generalize well to Python 3.x, which is more flexible in how *args
  and **kwargs are handled. This is not a problem in TF, since we have to remain
  compatible to Python 2.7 anyway.

  This function uses `__name__` for callables if it is available. This can lead
  to poor results for functools.partial and other callable objects.

  The returned string is Python code, so if it is included in a Markdown
  document, it should be typeset as code (using backticks), or escaped.

  Args:
    func: A function, method, or functools.partial to extract the signature for.
    reverse_index: A map from object ids to canonical full names to use.

  Returns:
    A list of strings representing the argument signature of `func` as python
    code.
  """

  args_list = []

  argspec = _get_arg_spec(func)
  # `args` / `defaults` may be None on the argspec; normalize them once so that
  # every use below (including the `self` check) handles that case.
  args = argspec.args or []
  defaults = argspec.defaults or []
  first_arg_with_default = len(args) - len(defaults)

  # Python documentation skips `self` when printing method signatures.
  # Note we cannot test for ismethod here since unbound methods do not register
  # as methods (in Python 3).
  first_arg = 1 if 'self' in args[:1] else 0

  # Add all args without defaults.
  args_list.extend(args[first_arg:first_arg_with_default])

  # Add all args with defaults.
  if defaults:
    try:
      source = _remove_first_line_indent(tf_inspect.getsource(func))
      func_ast = ast.parse(source)
      ast_defaults = func_ast.body[0].args.defaults
    except (IOError, TypeError):
      # getsource raises IOError (OSError) when the source cannot be retrieved
      # and TypeError for objects it cannot inspect, such as builtins and
      # functools.partial instances. If we cannot get the source, assume the
      # AST would be equal to the repr of the defaults.
      ast_defaults = [None] * len(defaults)

    # Internal names we know how to map to their public equivalents.
    # TODO(wicke): This should be replaced with a lookup in the index.
    # TODO(wicke): (replace first ident with tf., check if in index)
    internal_names = {
        'ops.GraphKeys': 'tf.GraphKeys',
        '_ops.GraphKeys': 'tf.GraphKeys',
        'init_ops.zeros_initializer': 'tf.zeros_initializer',
        'init_ops.ones_initializer': 'tf.ones_initializer',
        'saver_pb2.SaverDef': 'tf.train.SaverDef',
    }
    # Dotted name such as `ops.GraphKeys.X`; the dot is escaped so that it only
    # matches a literal `.` separator.
    full_name_re = re.compile(r'^%s(\.%s)+' % (IDENTIFIER_RE, IDENTIFIER_RE))

    for arg, default, ast_default in zip(
        args[first_arg_with_default:], defaults, ast_defaults):
      if id(default) in reverse_index:
        default_text = reverse_index[id(default)]
      elif ast_default is not None:
        default_text = codegen.to_source(ast_default)
        if default_text != repr(default):
          # This may be an internal name. If so, handle the ones we know about.
          match = full_name_re.match(default_text)
          if match:
            for internal_name, public_name in six.iteritems(internal_names):
              if match.group(0).startswith(internal_name):
                default_text = public_name + default_text[len(internal_name):]
                break
            else:
              print('WARNING: Using default arg, failed lookup: %s, repr: %r' %
                    (default_text, default))
      else:
        default_text = repr(default)

      args_list.append('%s=%s' % (arg, default_text))

  # Add *args and *kwargs.
  if argspec.varargs:
    args_list.append('*' + argspec.varargs)
  if argspec.keywords:
    args_list.append('**' + argspec.keywords)

  return args_list
Ejemplo n.º 32
0
def parse_object(obj):
    """Parses the source code of `obj` and returns the resulting AST.

    Args:
      obj: Any object whose source `tf_inspect.getsource` can retrieve.

    Returns:
      Whatever `parse_str` produces for that source text.
    """
    source_text = tf_inspect.getsource(obj)
    return parse_str(source_text)
Ejemplo n.º 33
0
def parse_entity(entity):
    """Parses the dedented source code of `entity`.

    Args:
      entity: Any object whose source `tf_inspect.getsource` can retrieve.

    Returns:
      A `(tree, source)` pair, where `source` is the dedented source text of
      `entity` and `tree` is the result of `parse_str(source)`.
    """
    dedented = textwrap.dedent(tf_inspect.getsource(entity))
    return parse_str(dedented), dedented
Ejemplo n.º 34
0
def parse_entity(entity):
    """Returns the AST of given entity, together with the parsed source.

    The source returned by `tf_inspect.getsource` is dedented before parsing.
    For lambdas whose captured source lines contain surrounding tokens (which
    makes them syntactically invalid out of context), a best-effort attempt is
    made to trim the source at the reported syntax error location and reparse.

    Args:
      entity: Any object whose source `tf_inspect.getsource` can retrieve.

    Returns:
      A `(node, source)` tuple: the result of `parse_str` and the exact source
      string that was parsed (possibly trimmed, for lambdas).

    Raises:
      ValueError: If the source cannot be parsed because of an indentation
        problem, or if it is a lambda whose source cannot be recovered. The
        message includes the original source and a hint about the cause.
      SyntaxError: Re-raised unchanged when the entity is not a lambda.
    """
    source = tf_inspect.getsource(entity)

    def fail(comment):
        raise ValueError(
            'Failed to parse source code of {}, which Python reported as:\n{}\n'
            '{}'.format(entity, source, comment))

    # Comments and multiline strings can appear at arbitrary indentation levels,
    # causing textwrap.dedent to not correctly dedent source code.
    # TODO(b/115884650): Automatic handling of comments/multiline strings.
    source = textwrap.dedent(source)

    try:
        return parse_str(source), source

    except IndentationError:
        # IndentationError subclasses SyntaxError, so it must be handled first.
        # The text below lists the causes of this error known to us. There may
        # be more.
        fail(
            'This may be caused by multiline strings or comments not indented at '
            'the same level as the code.')

    except SyntaxError as e:
        if not tf_inspect.isfunction(entity) or entity.__name__ != '<lambda>':
            raise

        # Certain entities, like lambdas, only hold the raw code lines which defined
        # them, which may include surrounding tokens and may be syntactically
        # invalid out of context. For example:
        #
        #     l = (
        #         lambda x: x,)[0]
        #
        # will have the dedented source "lambda x: x,)[0]"
        # Here we make an attempt to strip away the garbage by looking at the
        # information in the syntax error.
        lines = source.split('\n')
        lineno, offset = e.lineno, e.offset  # 1-based

        # Give up if the error carries no location, or if there's nothing we
        # can chip away (the error is already at the very end of the source).
        if (lineno is None or offset is None or
                (len(lines) == lineno and len(lines[-1]) == offset)):
            fail(
                'If this is a lambda function, the error may be avoided by creating'
                ' the lambda in a standalone statement.')

        # Drop all lines following the error location
        # TODO(mdan): What's with the pylint errors?
        lines = lines[:lineno]  # pylint:disable=invalid-slice-index
        # Drop all characters following the error location
        lines[-1] = lines[-1][:offset - 1]  # pylint:disable=invalid-slice-index
        new_source = '\n'.join(lines)

        try:
            return parse_str(new_source), new_source
        except SyntaxError:
            fail(
                'If this is a lambda function, the error may be avoided by creating'
                ' the lambda in a standalone statement. Tried to strip down the'
                ' source to:\n{}\nBut that did not work.'.format(new_source))
Ejemplo n.º 35
0
def _get_default_value_asts(func, num_defaults):
    """Returns the AST nodes of the positional default values of `func`.

    Falls back to a list of `num_defaults` Nones when the source of `func` is
    unavailable, or when its first top-level statement does not carry exactly
    `num_defaults` positional defaults. In that case callers render the
    defaults with `repr`, instead of pairing default values with the wrong
    source expressions.

    Args:
      func: The callable whose source is inspected.
      num_defaults: The number of positional defaults reported by its argspec.

    Returns:
      A list of length `num_defaults` holding AST nodes or Nones.
    """
    fallback = [None] * num_defaults
    try:
        source = _remove_first_line_indent(tf_inspect.getsource(func))
    except (IOError, TypeError):
        # Builtins have no Python source. Python 2 raises IOError here. Python 3
        # raises TypeError for builtins, and OSError (which IOError aliases)
        # when the source file cannot be found.
        return fallback
    func_ast = ast.parse(source)
    first_node = func_ast.body[0] if func_ast.body else None
    # Only a function definition has `.args.defaults`. Anything else (e.g. an
    # assignment of a lambda, or a class) yields None here.
    ast_defaults = getattr(getattr(first_node, 'args', None), 'defaults', None)
    if ast_defaults is None or len(ast_defaults) != num_defaults:
        return fallback
    return ast_defaults


def _generate_signature(func, reverse_index):
    """Given a function, returns a list of strings representing its args.

    This function produces a list of strings representing the arguments to a
    python function. It uses tf_inspect.getargspec, which
    does not generalize well to Python 3.x, which is more flexible in how *args
    and **kwargs are handled. This is not a problem in TF, since we have to
    remain compatible to Python 2.7 anyway.

    This function uses `__name__` for callables if it is available. This can
    lead to poor results for functools.partial and other callable objects.

    The returned string is Python code, so if it is included in a Markdown
    document, it should be typeset as code (using backticks), or escaped.

    Default values are rendered, in order of preference, as:
      1. the canonical name from `reverse_index`;
      2. the default's source text from the function's AST, with a few known
         internal module prefixes rewritten to their public `tf.` names;
      3. `repr(default)` when no source is available.

    Args:
      func: A function, method, or functools.partial to extract the signature
        for.
      reverse_index: A map from object ids to canonical full names to use.

    Returns:
      A list of strings representing the argument signature of `func` as
      python code.
    """
    # Internal names that may appear in default-value source text, mapped to
    # their public equivalents.
    # TODO(wicke): This should be replaced with a lookup in the index.
    # TODO(wicke): (replace first ident with tf., check if in index)
    internal_names = {
        'ops.GraphKeys': 'tf.GraphKeys',
        '_ops.GraphKeys': 'tf.GraphKeys',
        'init_ops.zeros_initializer': 'tf.zeros_initializer',
        'init_ops.ones_initializer': 'tf.ones_initializer',
        'saver_pb2.SaverDef': 'tf.train.SaverDef',
    }
    # Dotted names such as `ops.GraphKeys.VARIABLES`. The dot is escaped so
    # that only genuine attribute chains match, not arbitrary separators.
    full_name_re = re.compile(r'^%s(\.%s)+' % (IDENTIFIER_RE, IDENTIFIER_RE))

    args_list = []

    argspec = _get_arg_spec(func)
    args = argspec.args or []
    defaults = argspec.defaults or ()
    first_arg_with_default = len(args) - len(defaults)

    # Python documentation skips `self` when printing method signatures.
    # Note we cannot test for ismethod here since unbound methods do not register
    # as methods (in Python 3).
    first_arg = 1 if 'self' in args[:1] else 0

    # Add all args without defaults.
    args_list.extend(args[first_arg:first_arg_with_default])

    # Add all args with defaults.
    if defaults:
        ast_defaults = _get_default_value_asts(func, len(defaults))

        for arg, default, ast_default in zip(
                args[first_arg_with_default:], defaults, ast_defaults):
            if id(default) in reverse_index:
                default_text = reverse_index[id(default)]
            elif ast_default is not None:
                default_text = codegen.to_source(ast_default)
                if default_text != repr(default):
                    # This may be an internal name. If so, handle the ones we
                    # know about.
                    match = full_name_re.match(default_text)
                    if match:
                        for internal_name, public_name in six.iteritems(
                                internal_names):
                            if match.group(0).startswith(internal_name):
                                default_text = (
                                    public_name +
                                    default_text[len(internal_name):])
                                break
                        else:
                            print(
                                'WARNING: Using default arg, failed lookup: %s, repr: %r'
                                % (default_text, default))
            else:
                default_text = repr(default)

            args_list.append('%s=%s' % (arg, default_text))

    # Add *args and *kwargs.
    if argspec.varargs:
        args_list.append('*' + argspec.varargs)
    if argspec.keywords:
        args_list.append('**' + argspec.keywords)

    return args_list