def assert_same_ast(self, expected_node, node, msg=None):
    """Asserts that two ASTs render to the same source code.

    Both trees are rendered with two-space indentation, then dedented and
    stripped, so common leading indentation and surrounding whitespace do not
    affect the comparison.

    Args:
      expected_node: the AST the result is expected to match.
      node: the AST under test.
      msg: optional failure message forwarded to `assertEqual`.
    """
    def _normalized(tree):
        rendered = compiler.ast_to_source(tree, indentation='  ')
        return textwrap.dedent(rendered).strip()

    self.assertEqual(_normalized(expected_node), _normalized(node), msg=msg)
Esempio n. 2
0
 def __repr__(self):
     """Returns a short, human-readable description of the wrapped AST node.

     Function and class definitions are summarized by their name. A
     `withitem` is rendered as the source of its context expression. Any other
     node is rendered as its full source, stripped of surrounding whitespace.
     """
     target = self.ast_node
     if isinstance(target, gast.FunctionDef):
         return 'def %s' % target.name
     if isinstance(target, gast.ClassDef):
         return 'class %s' % target.name
     if isinstance(target, gast.withitem):
         # Only the context expression is shown, not the optional `as` target.
         target = target.context_expr
     return compiler.ast_to_source(target).strip()
Esempio n. 3
0
    def get_definition_directive(self, node, directive, arg, default):
        """Looks up the single value a directive argument takes for a symbol.

        See lang/directives.py for details on directives.

        Example:
           # Given a directive in the code:
           ag.foo_directive(bar, baz=1)

           # One can write for an AST node Name(id='bar'):
           get_definition_directive(node, ag.foo_directive, 'baz')

        Args:
          node: ast.AST, the node representing the symbol for which the
            directive argument is needed.
          directive: Callable[..., Any], the directive to search.
          arg: str, the directive argument to return.
          default: Any, returned when no reaching definition carries the
            directive argument.

        Returns:
          The argument value found on the reaching definitions. When several
          definitions provide it, all must match the first one, which is
          returned.

        Raises:
          ValueError: if conflicting annotations have been found
        """
        definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
        if not definitions:
            return default

        candidates = [
            d.directives[directive][arg]
            for d in definitions
            if directive in d.directives and arg in d.directives[directive]
        ]
        if not candidates:
            return default

        # Every reaching annotation must agree with the first one; that one is
        # then returned on behalf of all of them.
        reference = candidates[0]
        for candidate in candidates[1:]:
            if ast_util.matches(reference, candidate):
                continue
            qn = anno.getanno(node, anno.Basic.QN)
            raise ValueError(
                '%s has ambiguous annotations for %s(%s): %s, %s' %
                (qn, directive.__name__, arg,
                 compiler.ast_to_source(candidate).strip(),
                 compiler.ast_to_source(reference).strip()))
        return reference
Esempio n. 4
0
 def assert_body_anfs_as_expected(self, expected_fn, test_fn):
   """Asserts that an ANF pass configured to change nothing is a no-op.

   `test_fn` is parsed and rendered to source. It is then run through
   `anf.transform` with a configuration that maps every node to `anf.LEAVE`,
   and rendered again. The two renderings must be identical.

   NOTE(review): `expected_fn` is accepted but never read here; confirm with
   callers whether that is intentional.
   """
   def _render(tree):
     return textwrap.dedent(
         compiler.ast_to_source(tree, indentation='  ')).strip()

   # The code under test is wrapped in functions so that it is syntax
   # highlighted; Python never executes those function bodies.
   tree, _ = parser.parse_entity(test_fn, future_features=())
   before = _render(tree)
   leave_everything = [(anf.ANY, anf.LEAVE)]  # Configuration to transform nothing.
   tree = anf.transform(
       tree, self._simple_context(),
       config=leave_everything, gensym_source=DummyGensym)
   after = _render(tree)
   self.assertEqual(before, after)
Esempio n. 5
0
  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
       # Given a directive in the code:
       ag.foo_directive(bar, baz=1)

       # One can write for an AST node Name(id='bar'):
       get_definition_directive(node, ag.foo_directive, 'baz')

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when no reaching definition provides `arg`.

    Returns:
      The value of `arg` shared by all reaching definitions; when several
      provide it, the first one is returned.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not definitions:
      return default

    values = []
    for definition in definitions:
      if directive not in definition.directives:
        continue
      if arg not in definition.directives[directive]:
        continue
      values.append(definition.directives[directive][arg])

    if not values:
      return default

    # All values reaching the symbol must match the first one.
    chosen = values[0]
    for candidate in values[1:]:
      if ast_util.matches(chosen, candidate):
        continue
      qn = anno.getanno(node, anno.Basic.QN)
      raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                       (qn, directive.__name__, arg,
                        compiler.ast_to_source(candidate).strip(),
                        compiler.ast_to_source(chosen).strip()))
    return chosen
Esempio n. 6
0
 def assertLambdaNodes(self, matching_nodes, expected_bodies):
   """Checks that every node is a lambda whose body source is expected.

   Args:
     matching_nodes: the nodes to check.
     expected_bodies: collection of acceptable body sources. Its length must
       equal the number of nodes.
   """
   self.assertEqual(len(matching_nodes), len(expected_bodies))
   for lambda_node in matching_nodes:
     self.assertIsInstance(lambda_node, gast.Lambda)
     body_source = compiler.ast_to_source(
         lambda_node.body, include_encoding_marker=False)
     self.assertIn(body_source.strip(), expected_bodies)
Esempio n. 7
0
    def test_rename_symbols_attributes(self):
        """Renaming the qualified name `b.c` rewrites each of its uses."""
        tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
        renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
        tree = ast_util.rename_symbols(tree, renames)

        rendered = compiler.ast_to_source(tree)
        self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
Esempio n. 8
0
  def test_rename_symbols_attributes(self):
    """Renaming the qualified name `b.c` rewrites each of its uses."""
    tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
    renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
    tree = ast_util.rename_symbols(tree, renames)

    rendered = compiler.ast_to_source(tree)
    self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
Esempio n. 9
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
        parameters. Only used when `o` is a function or method.
      arg_types: A dict containing type hints for symbols like function
        parameters. Only used when `o` is a function or method.

    Returns:
      A tuple (nodes, name, entity_info):
          * nodes: An AST representing an entity with interface equivalent to
              `o`, but which when executed creates a TF graph.
          * name: The symbol name under which the new entity can be found.
          * entity_info: Returned unchanged from `class_to_graph` /
              `function_to_graph` (presumably includes the namespace of
              symbols visible to the converted entity -- TODO confirm).

    Raises:
      NotImplementedError: if `o` is an object that is not a class, function
        or method.
      ValueError: if the entity type is not supported.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Functions and methods share the same conversion path.
        nodes, name, entity_info = function_to_graph(o, program_ctx,
                                                     arg_values, arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        # Every Python object has `__class__`, so in practice any entity that
        # is not a class, function or method ends up here.
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(nodes))
    if logging.has_verbosity(4):
        for n in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    return nodes, name, entity_info
Esempio n. 10
0
 def _get_source(self, node):
     """Best-effort rendering of `node` as source code, for error reporting.

     Any exception raised while rendering is suppressed. This lets the caller
     still emit as informative a message as possible about the original error.

     NOTE(review): the result of `compiler.ast_to_source` is unpacked into
     two values, which assumes a version that returns a (source, extra) pair.
     Confirm against the `compiler` version in use.
     """
     try:
         rendered, _ = compiler.ast_to_source(node)
     except Exception:  # pylint: disable=broad-except
         return '<could not convert AST to source>'
     return rendered
Esempio n. 11
0
    def test_rename_symbols_basic(self):
        """Renaming a plain name updates its use and keeps `id` a `str`."""
        tree = qual_names.resolve(parser.parse_str('a + b'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})

        renamed_operand = tree.body[0].value.left
        self.assertIsInstance(renamed_operand.id, str)
        rendered = compiler.ast_to_source(tree)
        self.assertEqual(rendered.strip(), 'renamed_a + b')
Esempio n. 12
0
  def test_rename_symbols_basic(self):
    """Renaming a plain name updates its use and keeps `id` a `str`."""
    tree = qual_names.resolve(parser.parse_str('a + b'))
    tree = ast_util.rename_symbols(
        tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})

    renamed_operand = tree.body[0].value.left
    self.assertIsInstance(renamed_operand.id, str)
    rendered = compiler.ast_to_source(tree)
    self.assertEqual(rendered.strip(), 'renamed_a + b')
Esempio n. 13
0
 def _get_source(self, node):
   """Best-effort rendering of `node` as source code, for error reporting.

   Any exception raised while rendering is suppressed. This lets the caller
   still emit as informative a message as possible about the original error.

   NOTE(review): the result of `compiler.ast_to_source` is unpacked into two
   values, which assumes a version that returns a (source, extra) pair.
   Confirm against the `compiler` version in use.
   """
   try:
     rendered, _ = compiler.ast_to_source(node)
   except Exception:  # pylint: disable=broad-except
     rendered = '<could not convert AST to source>'
   return rendered
Esempio n. 14
0
    def test_source_map_no_origin(self):
        """A tree without origin annotations yields an empty source map."""
        def test_fn(x):
            return x + 1

        tree, _, _ = parser.parse_entity(test_fn)
        source_map = origin_info.create_source_map(
            tree, compiler.ast_to_source(tree), 'test_filename', [0])

        self.assertEqual(len(source_map), 0)
Esempio n. 15
0
    def get_definition_directive(self, node, directive, arg, default):
        """Returns the unique directive for a symbol, or a default if none exist.

        See lang/directives.py for details on directives.

        Args:
          node: ast.AST
          directive: Callable[..., Any]
          arg: str
          default: Any

        Returns:
          The value of `arg` for `directive` carried by the definitions that
          reach `node`, or `default` if none of them carries it. When several
          definitions carry matching values, the first one is returned.

        Raises:
          ValueError: if conflicting annotations have been found
        """
        defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
        if not defs:
            return default

        # TODO(mdan): Simplify this.
        arg_values = []
        for def_ in defs:
            if (directive not in def_.directives
                    or arg not in def_.directives[directive]):
                continue
            arg_value = def_.directives[directive][arg]
            # Each new value must be consistent with every value seen so far.
            for prev_value in arg_values:
                if not ast_util.matches(arg_value, prev_value):
                    qn = anno.getanno(node, anno.Basic.QN)
                    raise ValueError(
                        '%s has ambiguous annotations for %s(%s): %s, %s' %
                        (qn, directive.__name__, arg,
                         compiler.ast_to_source(arg_value).strip(),
                         compiler.ast_to_source(prev_value).strip()))
            arg_values.append(arg_value)

        if not arg_values:
            return default

        # Several definitions may legitimately agree on the value, so more
        # than one entry can be collected here. Return the first instead of
        # requiring exactly one.
        return arg_values[0]
Esempio n. 16
0
  def test_entity_to_graph_function_with_defaults(self):
    """The converted function's default argument renders as `None`."""
    b = 2
    c = 1
    def f(a, d=c + 1):
      return a + b + d

    nodes, name, _ = conversion.entity_to_graph(
        f, self._simple_program_ctx(), None, None)
    fn_node, _ = nodes
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', name)
    default_source = compiler.ast_to_source(fn_node.args.defaults[0])
    self.assertEqual(default_source.strip(), 'None')
  def test_source_map_no_origin(self):
    """A function without origin annotations yields an empty source map."""

    def test_fn(x):
      return x + 1

    module_node, _ = parser.parse_entity(test_fn)
    function_node = module_node.body[0]
    source_map = origin_info.create_source_map(
        function_node, compiler.ast_to_source(function_node),
        'test_filename', [0])

    self.assertEqual(len(source_map), 0)
Esempio n. 18
0
  def test_entity_to_graph_function_with_defaults(self):
    """The converted function's default argument renders as `None`."""
    b = 2
    c = 1
    def f(a, d=c + 1):
      return a + b + d

    ctx = self._simple_program_ctx()
    converted_nodes, new_name, _ = conversion.entity_to_graph(
        f, ctx, None, None)
    fn_node, _ = converted_nodes
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', new_name)
    first_default = fn_node.args.defaults[0]
    self.assertEqual(compiler.ast_to_source(first_default).strip(), 'None')
Esempio n. 19
0
  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive for a symbol, or a default if none exist.

    See lang/directives.py for details on directives.

    Args:
      node: ast.AST
      directive: Callable[..., Any]
      arg: str
      default: Any

    Returns:
      The value of `arg` for `directive` carried by the definitions that reach
      `node`, or `default` if none of them carries it. When several
      definitions carry matching values, the first one is returned.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
      return default

    # TODO(mdan): Simplify this.
    arg_values = []
    for def_ in defs:
      if (directive not in def_.directives or
          arg not in def_.directives[directive]):
        continue
      arg_value = def_.directives[directive][arg]
      # Each new value must be consistent with every value seen so far.
      for prev_value in arg_values:
        if not ast_util.matches(arg_value, prev_value):
          qn = anno.getanno(node, anno.Basic.QN)
          raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                           (qn, directive.__name__, arg,
                            compiler.ast_to_source(arg_value).strip(),
                            compiler.ast_to_source(prev_value).strip()))
      arg_values.append(arg_value)

    if not arg_values:
      return default

    # Several definitions may legitimately agree on the value, so more than
    # one entry can be collected here. Return the first instead of requiring
    # exactly one.
    return arg_values[0]
Esempio n. 20
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
  """Similar to `to_graph`, but returns Python source code as a string.

  Also see: `tf.autograph.to_graph`.

  `to_code` returns the Python source code that can be used to generate a
  TensorFlow graph that is functionally identical to the input Python code.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    indentation: The string to use for indenting. Typically two or four spaces,
      or just the tab character.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    The converted code as a string: the required imports, a blank line, then
    the source of each cached dependency in reverse conversion order.
  """
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          verbose=converter.Verbosity.BRIEF,
          strip_decorators=(convert, do_not_convert, converted_call),
          optional_features=experimental_optional_features),
      partial_types=experimental_partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(entity, program_ctx, arg_values, arg_types)

  # Render each cached dependency, walking `conversion_order` in reverse.
  code = '\n'.join(
      compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
      for dep in reversed(program_ctx.conversion_order))

  return program_ctx.required_imports + '\n\n' + code
Esempio n. 21
0
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (nodes, name, entity_info):
        * nodes: An AST representing an entity with interface equivalent to
            `o`, but which when executed creates a TF graph.
        * name: The symbol name under which the new entity can be found.
        * entity_info: Returned unchanged from `convert_class_to_ast` /
            `convert_func_to_ast` (presumably includes the namespace of
            symbols visible to the converted entity -- TODO confirm).

  Raises:
    NotImplementedError: if `o` is an object that is not a class, function or
      method.
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Functions and methods share the same conversion path.
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    # Every Python object has `__class__`, so in practice any entity that is
    # not a class, function or method ends up here.
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'python/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
Esempio n. 22
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
    """Similar to `to_graph`, but returns Python source code as a string.

    Also see: `tf.autograph.to_graph`.

    `to_code` returns the Python source code that can be used to generate a
    TensorFlow graph that is functionally identical to the input Python code.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional dict of value hints for symbols including
        function arguments mapping string names to actual values. For example,
        `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments. Type hints allow specifying just the type of a variable,
        rather than a specific value.
      indentation: The string to use for indenting. Typically two or four
        spaces, or just the tab character.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.
      experimental_partial_types: A `set` of `type` values, reserved for
        internal use.

    Returns:
      The converted code as a string: the required imports, a blank line, then
      the source of each cached dependency in reverse conversion order.
    """
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=converter.Verbosity.BRIEF,
            strip_decorators=(convert, do_not_convert, converted_call),
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    conversion.entity_to_graph(entity, program_ctx, arg_values, arg_types)

    # Render each cached dependency, walking `conversion_order` in reverse.
    code = '\n'.join(
        compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
        for dep in reversed(program_ctx.conversion_order))

    return program_ctx.required_imports + '\n\n' + code
Esempio n. 23
0
    def test_convert_entity_to_ast_function_with_defaults(self):
        """The converted function's default argument renders as `None`."""
        b = 2
        c = 1

        def f(a, d=c + 1):
            return a + b + d

        nodes, name, _ = conversion.convert_entity_to_ast(
            f, self._simple_program_ctx())
        fn_node, = nodes
        self.assertIsInstance(fn_node, gast.FunctionDef)
        self.assertEqual('tf__f', name)
        default_source = compiler.ast_to_source(
            fn_node.args.defaults[0], include_encoding_marker=False)
        self.assertEqual(default_source.strip(), 'None')
Esempio n. 24
0
def to_code(entity,
            recursive=True,
            arg_values=None,
            arg_types=None,
            indentation='  ',
            experimental_optional_features=converter.Feature.ALL,
            experimental_partial_types=None):
    """Similar to `to_graph`, but returns Python source code as a string.

    Also see: `tf.autograph.to_graph`.

    `to_code` returns the Python source code that can be used to generate a
    TensorFlow graph that is functionally identical to the input Python code.

    Args:
      entity: Python callable or class to convert.
      recursive: Whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional dict of value hints for symbols including
        function arguments mapping string names to actual values. For example,
        `arg_values={'a': 1}` will map the variable `a` to the value `1`.
      arg_types: Optional dict of type hints for symbols including function
        arguments. Type hints allow specifying just the type of a variable,
        rather than a specific value.
      indentation: The string to use for indenting. Typically two or four
        spaces, or just the tab character.
      experimental_optional_features: `None`, a tuple of, or a single
        `tf.autograph.experimental.Feature` value. Controls the use of
        optional features in the conversion process.
      experimental_partial_types: Deprecated, unused.

    Returns:
      The converted code as a string: the required imports, a blank line,
      then the source of the converted nodes.
    """
    # Accepted only so existing callers keep working; the value is ignored.
    del experimental_partial_types

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            optional_features=experimental_optional_features),
        autograph_module=tf_inspect.getmodule(to_graph))
    nodes, _, _ = conversion.entity_to_graph(entity, program_ctx, arg_values,
                                             arg_types)

    code = compiler.ast_to_source(nodes, indentation)

    return program_ctx.required_imports + '\n\n' + code
Esempio n. 25
0
def convert_entity_to_ast(o, program_ctx):
    """Compile a Python entity into equivalent TensorFlow.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.

    Returns:
      A tuple (nodes, name, entity_info):
          * nodes: An AST representing an entity with interface equivalent to
              `o`, but which when executed creates a TF graph.
          * name: The symbol name under which the new entity can be found.
          * entity_info: Returned unchanged from `convert_class_to_ast` /
              `convert_func_to_ast` (presumably includes the namespace of
              symbols visible to the converted entity -- TODO confirm).

    Raises:
      NotImplementedError: if `o` is an object that is not a class, function
        or method.
      ValueError: if the entity type is not supported.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Functions and methods share the same conversion path.
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    elif hasattr(o, '__class__'):
        # Note: this should only be raised when attempting to convert the object
        # directly. converted_call should still support it.
        # Every Python object has `__class__`, so in practice any entity that
        # is not a class, function or method ends up here.
        raise NotImplementedError(
            'cannot convert entity "{}": object conversion is not yet'
            ' supported.'.format(o))
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(nodes))
    if logging.has_verbosity(4):
        for n in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    return nodes, name, entity_info
Esempio n. 26
0
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (nodes, name, entity_info):
        * nodes: An AST representing an entity with interface equivalent to
            `o`, but which when executed creates a TF graph.
        * name: The symbol name under which the new entity can be found.
        * entity_info: Returned unchanged from `convert_class_to_ast` /
            `convert_func_to_ast` (presumably includes the namespace of
            symbols visible to the converted entity -- TODO confirm).

  Raises:
    NotImplementedError: if `o` is an object that is not a class, function or
      method.
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Functions and methods share the same conversion path.
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif hasattr(o, '__class__'):
    # Note: this should only be raised when attempting to convert the object
    # directly. converted_call should still support it.
    # Every Python object has `__class__`, so in practice any entity that is
    # not a class, function or method ends up here.
    raise NotImplementedError(
        'cannot convert entity "{}": object conversion is not yet'
        ' supported.'.format(o))
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
Esempio n. 27
0
    def test_create_source_map(self):
        """A statement annotated with an origin is present in the source map."""
        def test_fn(x):
            return x + 1

        tree, _, _ = parser.parse_entity(test_fn)
        origin_stub = origin_info.OriginInfo(
            loc=origin_info.Location('fake_filename', 3, 7),
            function_name='fake_function_name',
            source_code_line='fake source line',
            comment=None)
        anno.setanno(tree.body[0], anno.Basic.ORIGIN, origin_stub)

        source_map = origin_info.create_source_map(
            tree, compiler.ast_to_source(tree), 'test_filename', [0])

        expected_loc = origin_info.LineLocation('test_filename', 2)
        self.assertIn(expected_loc, source_map)
        self.assertIs(source_map[expected_loc], origin_stub)
    def test_source_map(self):
        """Traced statements map back to their origin; untraced ones do not."""
        def test_fn(x):
            if x > 0:
                x += 1
            return x

        node, source = parser.parse_entity(test_fn)
        fn_node = node.body[0]
        origin_info.resolve(fn_node, source)

        # Prepend a traced statement that borrows the origin of the `if`.
        traced_stmt = parser.parse_str('x = abs(x)').body[0]
        anno.copyanno(fn_node.body[0], traced_stmt, anno.Basic.ORIGIN)
        fn_node.body.insert(0, traced_stmt)

        # Prepend an untraced statement; it ends up before the traced one.
        fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

        source_map = origin_info.source_map(
            fn_node, compiler.ast_to_source(fn_node), 'test_filename', [0])

        def check_origin(line, expected_code, expected_lineno):
            origin = source_map[origin_info.LineLocation('test_filename', line)]
            self.assertEqual(origin.source_code_line, expected_code)
            self.assertEqual(origin.loc.lineno, expected_lineno)

        check_origin(1, 'def test_fn(x):', 1)

        # Line 2 holds the untraced statement.
        untraced_loc = origin_info.LineLocation('test_filename', 2)
        self.assertFalse(untraced_loc in source_map)

        # Line 3 holds the traced statement, which reports the `if`'s origin.
        check_origin(3, '  if x > 0:', 2)
        check_origin(4, '  if x > 0:', 2)
Esempio n. 29
0
    def test_ast_to_source(self):
        """An `If`/`else` node renders with the requested two-space indent."""
        def assign_to_a(value):
            return gast.Assign(targets=[gast.Name('a', gast.Store(), None)],
                               value=value)

        node = gast.If(
            test=gast.Num(1),
            body=[assign_to_a(gast.Name('b', gast.Load(), None))],
            orelse=[assign_to_a(gast.Str('c'))])

        rendered = compiler.ast_to_source(node, indentation='  ')
        expected = textwrap.dedent("""
            if 1:
              a = b
            else:
              a = 'c'
        """).strip()
        self.assertEqual(expected, rendered.strip())
  def test_create_source_map(self):
    """A statement annotated with an origin is present in the source map."""

    def test_fn(x):
      return x + 1

    module_node, _ = parser.parse_entity(test_fn)
    origin_stub = origin_info.OriginInfo(
        loc=origin_info.Location('fake_filename', 3, 7),
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)
    function_node = module_node.body[0]
    anno.setanno(function_node.body[0], anno.Basic.ORIGIN, origin_stub)

    source_map = origin_info.create_source_map(
        function_node, compiler.ast_to_source(function_node),
        'test_filename', [0])

    expected_loc = origin_info.LineLocation('test_filename', 2)
    self.assertIn(expected_loc, source_map)
    self.assertIs(source_map[expected_loc], origin_stub)
Esempio n. 31
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Returns the equivalent code that uses TensorFlow ops.

    Also see: `to_graph`, `convert`

    Args:
      e: Union[Callable, Type], the Python entity to convert.
      recursive: bool, whether to recursively convert any functions that the
        converted function may call.
      arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
      arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
      partial_types: Optional[Set[Type]], reserved for internal use.
      indentation: Text, what to use for each level of indentation.

    Returns:
      Text, the converted code: the program's required imports, followed by
      the source of each converted entity in reverse conversion order.
    """
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=recursive,
                                            strip_decorators=(convert,
                                                              do_not_convert,
                                                              converted_call)),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

    # Entities are emitted in reverse conversion order; presumably this puts
    # dependencies (converted after `e`) ahead of the code that uses them.
    code = '\n'.join(
        compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
        for dep in reversed(program_ctx.conversion_order))

    return program_ctx.required_imports + '\n\n' + code
Esempio n. 32
0
  def test_ast_to_source(self):
    then_branch = [
        gast.Assign(
            targets=[gast.Name('a', gast.Store(), None)],
            value=gast.Name('b', gast.Load(), None))
    ]
    else_branch = [
        gast.Assign(
            targets=[gast.Name('a', gast.Store(), None)],
            value=gast.Str('c'))
    ]
    if_node = gast.If(test=gast.Num(1), body=then_branch, orelse=else_branch)

    rendered = compiler.ast_to_source(if_node, indentation='  ').strip()
    expected = textwrap.dedent("""
        if 1:
          a = b
        else:
          a = 'c'
    """).strip()
    self.assertEqual(expected, rendered)
Esempio n. 33
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Returns the equivalent code that uses TensorFlow ops.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
      function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
      function arguments.
    partial_types: Optional[Set[Type]], reserved for internal use.
    indentation: Text, what to use for each level of indentation.

  Returns:
    Text, the converted code: the program's required imports, followed by the
    source of each converted entity in reverse conversion order.
  """
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(
          recursive=recursive,
          strip_decorators=(convert, do_not_convert, converted_call)),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Entities are emitted in reverse conversion order; presumably this puts
  # dependencies (converted after `e`) ahead of the code that uses them.
  code = '\n'.join(
      compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
      for dep in reversed(program_ctx.conversion_order))

  return program_ctx.required_imports + '\n\n' + code
Esempio n. 34
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is an object that is not a class, function or
        method.
    ValueError: if the entity type is not supported.
  """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods share the same conversion path.
        node, name, ns = function_to_graph(o, program_ctx, arg_values,
                                           arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. It should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    template = '''
      entity.autograph_info__ = {}
  '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(node))

    if program_ctx.options.recursive:
        # Class members are converted together with their class unless the
        # class is only partially converted, so such members are skipped here.
        # A skipped member never enters dependency_cache; remembering it keeps
        # the scan below from selecting the same member forever.
        skipped = set()
        while True:
            candidate = None
            for obj in program_ctx.name_map.keys():
                if (obj not in program_ctx.dependency_cache
                        and obj not in skipped):
                    candidate = obj
                    break
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class') and getattr(
                    candidate, 'im_class') not in program_ctx.partial_types):
                skipped.add(candidate)
                continue
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns
Esempio n. 35
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is an object that is not a class, function or
        method.
    ValueError: if the entity type is not supported.
  """
  verbose = program_ctx.options.verbose == converter.Verbosity.VERBOSE
  if verbose:
    logging.info('Converting %s', o)

  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Plain functions and methods share the same conversion path.
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'contrib/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. It should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
      entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', o,
                 compiler.ast_to_source(node))

  if program_ctx.options.recursive:
    # Class members are converted together with their class unless the class
    # is only partially converted, so such members are skipped here. A skipped
    # member never enters dependency_cache; remembering it keeps the scan
    # below from selecting the same member forever.
    skipped = set()
    while True:
      candidate = None
      for obj in program_ctx.name_map.keys():
        if obj not in program_ctx.dependency_cache and obj not in skipped:
          candidate = obj
          break
      if candidate is None:
        break
      if (hasattr(candidate, 'im_class') and
          getattr(candidate, 'im_class') not in program_ctx.partial_types):
        skipped.add(candidate)
        continue
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns
Esempio n. 36
0
 def assert_same_ast(self, expected_node, node, msg=None):
   """Asserts two ASTs render to the same source, modulo outer whitespace."""

   def _render(ast_node):
     # Render with two-space indentation, then drop common indent and padding.
     source = compiler.ast_to_source(ast_node, indentation='  ')
     return textwrap.dedent(source).strip()

   self.assertEqual(_render(expected_node), _render(node), msg=msg)
Esempio n. 37
0
 def debug_print_src(self, node):
     """Prints `node` rendered as source code and returns it unchanged.

     A debugging aid: nothing is printed when `__debug__` is False
     (i.e. under `python -O`).
     """
     if not __debug__:
         return node
     print(compiler.ast_to_source(node))
     return node
Esempio n. 38
0
 def debug_print_src(self, node):
   """Debugging helper: dumps `node` as source code, then passes it through.

   Printing is skipped when `__debug__` is False (`python -O`).
   """
   if __debug__:
     rendered = compiler.ast_to_source(node)
     print(rendered)
   return node
Esempio n. 39
0
 def _mock_apply_fn(self, target, source):
   """Counts one call, keyed by the rendered (target, source) source pair."""
   key = tuple(
       compiler.ast_to_source(ast_node, include_encoding_marker=False).strip()
       for ast_node in (target, source))
   self._invocation_counts[key] += 1
Esempio n. 40
0
 def assertFunctionDefNodes(self, matching_nodes, expected_bodies):
     """Checks matched nodes are FunctionDefs whose bodies are all expected.

     Also requires one matched node per expected body.
     """
     self.assertEqual(len(matching_nodes), len(expected_bodies))
     for fn_def in matching_nodes:
         self.assertIsInstance(fn_def, gast.FunctionDef)
         body_source = compiler.ast_to_source(fn_def.body).strip()
         self.assertIn(body_source, expected_bodies)
Esempio n. 41
0
 def assertFunctionDefNodes(self, matching_nodes, expected_bodies):
   """Asserts each match is a FunctionDef whose rendered body is expected."""
   expected_count = len(expected_bodies)
   self.assertEqual(len(matching_nodes), expected_count)
   for candidate in matching_nodes:
     self.assertIsInstance(candidate, gast.FunctionDef)
     rendered_body = compiler.ast_to_source(candidate.body).strip()
     self.assertIn(rendered_body, expected_bodies)
Esempio n. 42
0
 def _mock_apply_fn(self, target, source):
   """Increments the call count for the rendered (target, source) pair."""
   target_src = compiler.ast_to_source(target).strip()
   source_src = compiler.ast_to_source(source).strip()
   self._invocation_counts[(target_src, source_src)] += 1
Esempio n. 43
0
 def _mock_apply_fn(self, target, source):
     """Tallies one invocation keyed by the rendered target and source."""
     key = tuple(compiler.ast_to_source(n).strip() for n in (target, source))
     self._invocation_counts[key] += 1
Esempio n. 44
0
 def __repr__(self):
   """Returns a short, human-readable description of the wrapped AST node.

   Function and class definitions are summarized by name only, a `with` item
   is shown as its context expression, and any other node is rendered as its
   full source.
   """
   node = self.ast_node
   if isinstance(node, gast.FunctionDef):
     return 'def %s' % node.name
   if isinstance(node, gast.ClassDef):
     # Summarize like FunctionDef instead of dumping the whole class body.
     return 'class %s' % node.name
   if isinstance(node, gast.withitem):
     return compiler.ast_to_source(node.context_expr).strip()
   return compiler.ast_to_source(node).strip()
 def assertTransformedFirstLineIs(self, node, expected):
     """Asserts the first line of `node`'s rendered source equals `expected`."""
     rendered = compiler.ast_to_source(node, include_encoding_marker=False)
     first_line = rendered.partition('\n')[0]
     self.assertEqual(first_line, expected)
Esempio n. 46
0
 def assertTransformedFirstLineIs(self, node, expected):
   """Checks that `node` renders to source whose first line is `expected`."""
   first_line, _, _ = compiler.ast_to_source(node).partition('\n')
   self.assertEqual(first_line, expected)