Пример #1
0
  def assertAstMatches(self, actual_node, expected_node_src):
    """Asserts that `actual_node` matches the first statement of a snippet.

    Args:
      actual_node: The AST node produced by the code under test.
      expected_node_src: Python source text. It is parsed with `gast.parse`
        and the first statement of the resulting module is the expectation.
    """
    parsed_module = gast.parse(expected_node_src)
    expected_node = parsed_module.body[0]

    # Both trees are pretty-printed so a failure shows them side by side.
    expected_text = pretty_printer.fmt(expected_node)
    actual_text = pretty_printer.fmt(actual_node)
    failure_msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
        expected_text, actual_text)
    self.assertTrue(ast_util.matches(actual_node, expected_node), failure_msg)
Пример #2
0
  def assertAstMatches(self, actual_node, expected_node_src, expr=True):
    """Asserts that `actual_node` matches the AST parsed from a snippet.

    Args:
      actual_node: The AST node produced by the code under test.
      expected_node_src: Python source text describing the expected node.
      expr: If True (the default), the snippet is treated as an expression.
        It is wrapped in parentheses and the `value` of the first parsed
        statement is the expectation. If False, the first parsed statement
        itself is the expectation.
    """
    # Parenthesizing ensures multi-line expressions parse.
    snippet = '({})'.format(expected_node_src) if expr else expected_node_src
    first_statement = gast.parse(snippet).body[0]
    expected_node = first_statement.value if expr else first_statement

    expected_text = pretty_printer.fmt(expected_node)
    actual_text = pretty_printer.fmt(actual_node)
    failure_msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
        expected_text, actual_text)
    self.assertTrue(ast_util.matches(actual_node, expected_node), failure_msg)
Пример #3
0
    def compiled(self, node, namespace, *symbols):
        """Compiles `node` into a module object wired up with fake dependencies.

        This is a generator that yields exactly once. Presumably it is wrapped
        with `contextlib.contextmanager` outside this view and used in a
        `with` statement (TODO confirm).

        Args:
            node: The AST to compile with `compiler.ast_to_object`.
            namespace: Mapping whose items are copied into the compiled
                module's `__dict__`.
            *symbols: Forwarded to `self.make_fake_mod` to build the fake
                `tf` module.

        Yields:
            The compiled module, with `tf`, `ag__` and the `namespace`
            entries installed.

        Raises:
            Any `Exception` raised while compiling or setting up the module
            (or thrown back into the generator at the `yield`). It is
            re-raised after printing the offending AST (if no source was
            produced) or the compiled source.
        """
        source = None
        # Every call made through the mock converted_call is recorded here.
        self.dynamic_calls = []

        def converted_call(*args):
            """Mock version of api.converted_call."""
            self.dynamic_calls.append(args)
            return 7

        try:
            result, source = compiler.ast_to_object(
                node, include_source_map=True)

            # TODO(mdan): Move this into self.prepare()
            result.tf = self.make_fake_mod('fake_tf', *symbols)

            fake_ag = self.make_fake_mod(
                'fake_ag', converted_call, converter.ConversionOptions)
            ag_namespace = fake_ag.__dict__
            ag_namespace.update(operators.__dict__)
            ag_namespace['utils'] = utils
            ag_namespace['rewrite_graph_construction_error'] = (
                errors.rewrite_graph_construction_error)
            ag_namespace['function_scope'] = function_wrapping.function_scope

            module_namespace = result.__dict__
            module_namespace['ag__'] = fake_ag
            module_namespace.update(namespace.items())
            yield result
        except Exception:  # pylint:disable=broad-except
            if source is not None:
                print('Offending compiled code:\n%s' % source)
            else:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            raise
Пример #4
0
def convert_entity_to_ast(o, program_ctx):
    """Compiles a Python entity into an equivalent TensorFlow AST.

    The entity is always handed to `convert_func_to_ast`. This wrapper only
    adds logging around the conversion.

    Args:
        o: A Python entity.
        program_ctx: A ProgramContext object.

    Returns:
        A tuple (nodes, new_name, entity_info), exactly as returned by
        `convert_func_to_ast`:
            * nodes: the converted AST nodes. They are iterated one by one
                for verbose logging, so this is a sequence of nodes.
            * new_name: the symbol name under which the new entity can be
                found.
            * entity_info: the third value returned by
                `convert_func_to_ast` (presumably metadata about the
                converted entity; TODO confirm).

    Raises:
        NotImplementedError: presumably raised by `convert_func_to_ast` for
            entity types it does not yet support (not raised directly here).
    """
    logging.log(1, 'Converting %s', o)

    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)

    if logging.has_verbosity(2):
        compiled_source = parser.unparse(nodes)
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o, compiled_source)

    if logging.has_verbosity(4):
        for converted_node in nodes:
            node_dump = pretty_printer.fmt(converted_node, color=False)
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, node_dump)

    return nodes, name, entity_info
Пример #5
0
 def test_unicode_bytes(self):
     # fmt() must produce output for a module that returns bytes, unicode
     # and plain string literals.
     tree = ast.parse(textwrap.dedent('''
 def f():
   return b'b', u'u', 'depends_py2_py3'
 '''))
     formatted = pretty_printer.fmt(tree)
     self.assertIsNotNone(formatted)
Пример #6
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compiles a Python entity into an equivalent TensorFlow AST.

    Classes are converted with `class_to_graph`. Functions and methods are
    converted with `function_to_graph`.

    Args:
        o: A Python entity.
        program_ctx: A ProgramContext object.
        arg_values: A dict containing value hints for symbols like function
            parameters.
        arg_types: A dict containing type hints for symbols like function
            parameters.

    Returns:
        A tuple (nodes, new_name, entity_info) as returned by the conversion
        helper:
            * nodes: the converted AST nodes (iterated one by one for
                verbose logging).
            * new_name: the symbol name under which the new entity can be
                found.
            * entity_info: the helper's third return value (presumably
                metadata about the converted entity; TODO confirm).

    Raises:
        NotImplementedError: if `o` is not a class, function or method but
            has a `__class__` attribute (object conversion is unsupported).
        ValueError: if `o` has no `__class__` attribute.
            NOTE(review): every Python 3 object has `__class__`, so this
            branch looks unreachable in practice.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        converted = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods share the same conversion path.
        converted = function_to_graph(o, program_ctx, arg_values, arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))
    nodes, name, entity_info = converted

    if logging.has_verbosity(2):
        compiled_source = compiler.ast_to_source(nodes)
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o, compiled_source)

    if logging.has_verbosity(4):
        for converted_node in nodes:
            node_dump = pretty_printer.fmt(converted_node, color=False)
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, node_dump)

    return nodes, name, entity_info
Пример #7
0
    def compiled(self, node, namespace, symbols=()):
        """Loads `node` as a module object wired up with fake dependencies.

        This is a generator that yields exactly once. Presumably it is wrapped
        with `contextlib.contextmanager` outside this view (TODO confirm).

        Args:
            node: The AST to load with `loader.load_ast`.
            namespace: Mapping whose items are copied into the loaded
                module's `__dict__`.
            symbols: Sequence forwarded (unpacked) to `self.make_fake_mod`
                to build the fake `tf` module.

        Yields:
            The loaded module, with `tf`, `ag__`, `ag_source_map__` and the
            `namespace` entries installed.

        Raises:
            Any `Exception` raised while loading or setting up the module
            (or thrown back into the generator at the `yield`). It is
            re-raised after printing the offending AST (if no source was
            produced) or the source code.
        """
        source = None
        # Every call made through the mock converted_call is recorded here.
        self.dynamic_calls = []

        # See api.converted_call
        def converted_call(f,
                           args,
                           kwargs,
                           unused_opts=None,
                           unused_function_ctx=None):
            """Mock version of api.converted_call."""
            # Record the arguments as received, before kwargs is normalized.
            self.dynamic_calls.append((args, kwargs))
            effective_kwargs = kwargs if kwargs is not None else {}
            return f(*args, **effective_kwargs)

        def fake_autograph_artifact(f):
            f.fake_autograph_artifact = True
            return f

        try:
            result, source, source_map = loader.load_ast(
                node, include_source_map=True)
            # TODO(mdan): Move the unparsing from converter into pyct and reuse here.

            # TODO(mdan): Move this into self.prepare()
            result.tf = self.make_fake_mod('fake_tf', *symbols)

            fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                         converter.ConversionOptions)
            for helper_module in (operators, special_functions):
                fake_ag.__dict__.update(helper_module.__dict__)
            fake_ag.ConversionOptions = converter.ConversionOptions
            fake_ag.Feature = converter.Feature
            fake_ag.utils = utils
            fake_ag.FunctionScope = function_wrappers.FunctionScope
            fake_ag.autograph_artifact = fake_autograph_artifact

            result.ag__ = fake_ag
            result.ag_source_map__ = source_map
            result.__dict__.update(namespace.items())
            yield result
        except Exception:  # pylint:disable=broad-except
            if source is not None:
                print('Offending source code:\n%s' % source)
            else:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            raise
Пример #8
0
def convert_entity_to_ast(o, program_ctx):
    """Compiles a Python entity into an equivalent TensorFlow AST.

    Classes are converted with `convert_class_to_ast`. Functions and methods
    are converted with `convert_func_to_ast`.

    Args:
        o: A Python entity.
        program_ctx: A ProgramContext object.

    Returns:
        A tuple (nodes, new_name, entity_info) as returned by the conversion
        helper:
            * nodes: the converted AST nodes (iterated one by one for
                verbose logging).
            * new_name: the symbol name under which the new entity can be
                found.
            * entity_info: the helper's third return value (presumably
                metadata about the converted entity; TODO confirm).

    Raises:
        NotImplementedError: if `o` is not a class, function or method but
            has a `__class__` attribute (object conversion is unsupported).
        ValueError: if `o` has no `__class__` attribute.
            NOTE(review): every Python 3 object has `__class__`, so this
            branch looks unreachable in practice.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        convert_fn = convert_class_to_ast
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods share the same conversion path.
        convert_fn = convert_func_to_ast
    elif hasattr(o, '__class__'):
        # Note: this should only be raised when attempting to convert the object
        # directly. converted_call should still support it.
        raise NotImplementedError(
            'cannot convert entity "{}": object conversion is not yet'
            ' supported.'.format(o))
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    nodes, name, entity_info = convert_fn(o, program_ctx)

    if logging.has_verbosity(2):
        compiled_source = compiler.ast_to_source(nodes)
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o, compiled_source)

    if logging.has_verbosity(4):
        for converted_node in nodes:
            node_dump = pretty_printer.fmt(converted_node, color=False)
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, node_dump)

    return nodes, name, entity_info
Пример #9
0
 def test_format(self):
     # Hand-built AST for `def f(a): return a + 1`. The argument is
     # represented as a `Name` with a `Param` context.
     fn_args = ast.arguments(
         args=[ast.Name(id='a', ctx=ast.Param())],
         vararg=None,
         kwarg=None,
         defaults=[])
     a_plus_one = ast.BinOp(op=ast.Add(),
                            left=ast.Name(id='a', ctx=ast.Load()),
                            right=ast.Num(1))
     node = ast.FunctionDef(name='f',
                            args=fn_args,
                            body=[ast.Return(a_plus_one)],
                            decorator_list=[],
                            returns=None)
     # Just checking for functionality, the color control characters make it
     # difficult to inspect the result.
     formatted = pretty_printer.fmt(node)
     self.assertIsNotNone(formatted)
Пример #10
0
    def compiled(self, node, namespace, *symbols):
        """Compiles `node` into a module object wired up with fake dependencies.

        This is a generator that yields exactly once. Presumably it is wrapped
        with `contextlib.contextmanager` outside this view (TODO confirm).

        Args:
            node: The AST to compile with `compiler.ast_to_object`.
            namespace: Mapping whose items are copied into the compiled
                module's `__dict__`.
            *symbols: Forwarded to `self.make_fake_mod` to build the fake
                `tf` module.

        Yields:
            The compiled module, with `tf`, `ag__` and the `namespace`
            entries installed.

        Raises:
            Any `Exception` raised while compiling or setting up the module
            (or thrown back into the generator at the `yield`). It is
            re-raised after printing the offending AST (if no source was
            produced) or the compiled code.
        """
        source = None

        # Every call made through the mock converted_call is recorded here.
        self.dynamic_calls = []

        def converted_call(*args):
            """Mock version of api.converted_call."""
            self.dynamic_calls.append(args)
            return 7

        class ConversionOptions(object):
            """Mock version of api.ConversionOptions."""

            def __init__(self, recursive):
                self.recursive = recursive

            @classmethod
            def new(cls, recursive):
                # Alternate constructor: it must hand back the instance it
                # builds. Otherwise generated code that calls `new` receives
                # None instead of an options object.
                return cls(recursive)

        try:
            result, source = compiler.ast_to_object(node,
                                                    include_source_map=True)

            result.tf = self.make_fake_mod('fake_tf', *symbols)
            fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                         ConversionOptions)
            fake_ag.__dict__.update(operators.__dict__)
            fake_ag.__dict__['utils'] = utils
            fake_ag.__dict__['rewrite_graph_construction_error'] = (
                errors.rewrite_graph_construction_error)
            result.__dict__['ag__'] = fake_ag
            for k, v in namespace.items():
                result.__dict__[k] = v
            yield result
        except Exception:  # pylint:disable=broad-except
            if source is None:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            else:
                print('Offending compiled code:\n%s' % source)
            raise
Пример #11
0
    def compiled(self, node, namespace, *symbols):
        """Compiles `node` into a module object wired up with fake dependencies.

        This is a generator that yields exactly once. Presumably it is wrapped
        with `contextlib.contextmanager` outside this view (TODO confirm).

        Args:
            node: The AST to compile with `compiler.ast_to_object`.
            namespace: Mapping whose items are copied into the compiled
                module's `__dict__`.
            *symbols: Forwarded to `self.make_fake_mod` to build the fake
                `tf` module.

        Yields:
            The compiled module, with `tf`, `ag__`, `ag_source_map__` and the
            `namespace` entries installed.

        Raises:
            Any `Exception` raised while compiling or setting up the module
            (or thrown back into the generator at the `yield`). It is
            re-raised after printing the offending AST (if no source was
            produced) or the compiled code.
        """
        source = None
        self.dynamic_calls = []

        def converted_call(*args):
            """Mock version of api.converted_call."""
            # Only the call's own arguments are kept (args[3:]); see
            # api.converted_call for the leading positional parameters.
            call_args = args[3:]
            self.dynamic_calls.append(call_args)
            return RESULT_OF_MOCK_CONVERTED_CALL

        try:
            result, source, source_map = compiler.ast_to_object(
                node, include_source_map=True)
            # TODO(mdan): Move the unparsing from converter into pyct and reuse here.

            # TODO(mdan): Move this into self.prepare()
            result.tf = self.make_fake_mod('fake_tf', *symbols)

            fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                         converter.ConversionOptions)
            for helper_module in (operators, special_functions):
                fake_ag.__dict__.update(helper_module.__dict__)
            extra_attributes = (
                ('ConversionOptions', converter.ConversionOptions),
                ('Feature', converter.Feature),
                ('utils', utils),
                ('rewrite_graph_construction_error',
                 errors.rewrite_graph_construction_error),
                ('function_scope', function_wrapping.function_scope),
            )
            for attr_name, attr_value in extra_attributes:
                setattr(fake_ag, attr_name, attr_value)

            result.ag__ = fake_ag
            result.ag_source_map__ = source_map
            result.__dict__.update(namespace.items())
            yield result
        except Exception:  # pylint:disable=broad-except
            if source is not None:
                print('Offending compiled code:\n%s' % source)
            else:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            raise
Пример #12
0
 def debug_print(self, node):
     """Prints the pretty-printed AST of `node` and returns `node` unchanged.

     Helper method useful for debugging. Nothing is printed when Python runs
     with optimizations enabled (`-O`), because `__debug__` is False then.
     """
     if not __debug__:
         return node
     print(pretty_printer.fmt(node))
     return node
def create_source_map(nodes, code, filename, indices_in_code):
    """Creates a source map between an annotated AST and the code it compiles to.

    Args:
        nodes: Iterable[ast.AST, ...], the annotated original nodes.
        code: Text, the source code the nodes were compiled to.
        filename: Optional[Text], used as the filename of every LineLocation
            key in the result.
        indices_in_code: Union[int, Iterable[int, ...]], the positions at
            which nodes appear in code. The parser always returns a module
            when parsing code. This argument gives the position in that
            module's body at which the corresponding node should appear. A
            single int is treated as a one-element sequence.

    Returns:
        Dict[LineLocation, OriginInfo], mapping locations in code to locations
        indicated by origin annotations in node.

    Raises:
        ValueError: re-raised from `ast_util.parallel_walk` (presumably when
            the original and reparsed trees differ in structure). At
            verbosity >= 3 a diff of the two trees is logged first.
    """
    if isinstance(indices_in_code, int):
        # The documented contract allows a single index; wrap it so the
        # comprehension below can iterate over it.
        indices_in_code = (indices_in_code,)

    reparsed_module = parser.parse_str(code)
    reparsed_nodes = [reparsed_module.body[i] for i in indices_in_code]
    for node in reparsed_nodes:
        resolve(node, code)

    result = {}

    try:
        for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
            # Note: generated code might not be mapped back to its origin.
            # TODO(mdan): Generated code should always be mapped to something.
            origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
            final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
            if origin_info is None or final_info is None:
                continue

            line_loc = LineLocation(filename, final_info.loc.lineno)

            existing_origin = result.get(line_loc)
            if existing_origin is not None:
                # Overlaps may exist because of child nodes, but almost never to
                # different line locations. Exceptions are decorated functions,
                # where both lines are mapped to the same line in the AST.

                # Line overlaps: keep bottom node.
                if existing_origin.loc.line_loc == origin_info.loc.line_loc:
                    if existing_origin.loc.lineno >= origin_info.loc.lineno:
                        continue

                # In case of column overlaps, keep the leftmost node.
                if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
                    continue

            result[line_loc] = origin_info
    except ValueError:
        if logging.has_verbosity(3):
            for n, rn in zip(nodes, reparsed_nodes):
                nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
                reparsed_nodes_str = pretty_printer.fmt(rn,
                                                        color=False,
                                                        noanno=True)
                diff = difflib.context_diff(nodes_str.split('\n'),
                                            reparsed_nodes_str.split('\n'),
                                            fromfile='Original nodes',
                                            tofile='Reparsed nodes',
                                            n=7)
                diff = '\n'.join(diff)
                logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff)
        raise

    return result
Пример #14
0
def create_source_map(nodes, code, filepath):
  """Creates a source map between an annotated AST and the code it compiles to.

  Note: this function assumes that `nodes`, `code` and `filepath` all
  correspond to the same code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more AST nodes.
    code: Text, the source code in which nodes are found.
    filepath: Text, forwarded to `resolve` for every reparsed node.

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node. Keys are by line only; the
    column offset is not part of the key.

  Raises:
    ValueError: re-raised from `ast_util.parallel_walk`. At verbosity >= 3 a
      diff between each original and reparsed node is logged first.
  """
  reparsed_nodes = parser.parse_str(code, preamble_len=0, single_node=False)
  for reparsed in reparsed_nodes:
    resolve(reparsed, code, filepath, reparsed.lineno, reparsed.col_offset)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno)

      existing = source_map.get(line_loc)
      if existing is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. Exceptions are decorated functions, where
        # both lines are mapped to the same line in the AST. On a line
        # overlap the bottom node wins; on a column overlap the leftmost
        # node wins.
        keep_existing = (
            (existing.loc.line_loc == origin_info.loc.line_loc and
             existing.loc.lineno >= origin_info.loc.lineno) or
            existing.loc.col_offset <= origin_info.loc.col_offset)
        if keep_existing:
          continue

      source_map[line_loc] = origin_info

  except ValueError:
    if logging.has_verbosity(3):
      for original, reparsed in zip(nodes, reparsed_nodes):
        original_dump = pretty_printer.fmt(original, color=False, noanno=True)
        reparsed_dump = pretty_printer.fmt(reparsed, color=False, noanno=True)
        diff_lines = difflib.context_diff(
            original_dump.split('\n'),
            reparsed_dump.split('\n'),
            fromfile='Original nodes',
            tofile='Reparsed nodes',
            n=7)
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s',
                    '\n'.join(diff_lines))
    raise

  return source_map
Пример #15
0
    def visit(self, node):
        """Visits `node`, tracking context and wrapping errors with source info.

        Args:
            node: The `gast.AST` node to visit. Lists of nodes must be visited
                with `visit_block` instead.

        Returns:
            The value produced by the base class's `visit` for `node`. If
            `node` carries the `anno.Basic.SKIP_PROCESSING` annotation, `node`
            itself is returned unchanged. If `node` is a `gast.Expr` whose
            value was replaced by a list, tuple, `gast.Assign` or
            `gast.AugAssign`, that inner value is returned instead of the
            enclosing Expr.

        Raises:
            ValueError: if `node` is not a `gast.AST`.
            AssertionError: if the local scope stack has a different length
                after visiting than before (unpaired enter_local_scope /
                exit_local_scope calls).
            AutographParseError: raised in place of any ValueError,
                AttributeError, KeyError or NotImplementedError that occurs
                while visiting, with the offending source location attached.
        """
        if not isinstance(node, gast.AST):
            # This is not that uncommon a mistake: various node bodies are lists, for
            # example, posing a land mine for transformers that need to recursively
            # call `visit`.  The error needs to be raised before the exception handler
            # below is installed, because said handler will mess up if `node` is not,
            # in fact, a node.
            msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
                   ' visit lists of nodes, use "visit_block" instead').format(
                       type(node))
            raise ValueError(msg)

        source_code = self.entity_info.source_code
        source_file = self.entity_info.source_file
        did_enter_function = False
        local_scope_size_at_entry = len(self._local_scope_state)
        processing_expr_node = False

        try:
            if isinstance(node,
                          (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
                did_enter_function = True
            elif isinstance(node, gast.Expr):
                processing_expr_node = True

            if did_enter_function:
                self._enclosing_entities.append(node)

            if source_code and hasattr(node, 'lineno'):
                self._lineno = node.lineno
                self._col_offset = node.col_offset

            if processing_expr_node:
                entry_expr_value = node.value

            if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
                # Skipped nodes pass through untouched. `result` must be bound
                # on this path too, because the code below reads it
                # unconditionally.
                result = node
            else:
                result = super(Base, self).visit(node)

            # Adjust for consistency: replacing the value of an Expr with
            # an Assign node removes the need for the Expr node.
            if processing_expr_node:
                if isinstance(result,
                              gast.Expr) and result.value != entry_expr_value:
                    # When the replacement is a list, it is assumed that the list came
                    # from a template that contained a number of statements, which
                    # themselves are standalone and don't require an enclosing Expr.
                    if isinstance(result.value,
                                  (list, tuple, gast.Assign, gast.AugAssign)):
                        result = result.value

            # On exception, the local scope integrity is not guaranteed.
            if did_enter_function:
                self._enclosing_entities.pop()

            if local_scope_size_at_entry != len(self._local_scope_state):
                raise AssertionError(
                    'Inconsistent local scope stack. Before entering node %s, the'
                    ' stack had length %d, after exit it has length %d. This'
                    ' indicates enter_local_scope and exit_local_scope are not'
                    ' well paired.' % (node, local_scope_size_at_entry,
                                       len(self._local_scope_state)))
            return result

        except (ValueError, AttributeError, KeyError,
                NotImplementedError) as e:
            msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
                e.__class__.__name__, str(e), self._get_source(node),
                pretty_printer.fmt(node, color=False))
            if source_code:
                line = source_code.splitlines()[self._lineno - 1]
            else:
                line = '<no source available>'
            # TODO(mdan): Avoid the printing of the original exception.
            # In other words, we need to find how to suppress the "During handling
            # of the above exception, another exception occurred" message.
            six.reraise(
                AutographParseError,
                AutographParseError(
                    msg,
                    (source_file, self._lineno, self._col_offset + 1, line)),
                sys.exc_info()[2])
Пример #16
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compiles a Python entity into an equivalent TensorFlow AST.

    When `program_ctx.options.recursive` is set, the function also converts
    every entity in `program_ctx.name_map` that is not yet in
    `program_ctx.dependency_cache`, by calling itself recursively. Class
    members whose `im_class` is not in `program_ctx.partial_types` are left
    to be converted together with their class.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
        o: A Python entity.
        program_ctx: A ProgramContext object.
        arg_values: A dict containing value hints for symbols like function
            parameters.
        arg_types: A dict containing type hints for symbols like function
            parameters.

    Returns:
        A tuple (nodes, new_name, namespace):
            * nodes: a list of AST nodes representing an entity with an
                interface equivalent to `o`, plus an assignment of
                `autograph_info__` on the new entity.
            * new_name: the symbol name under which the new entity can be
                found.
            * namespace: the third value returned by the conversion helper
                (presumably the symbols visible to the converted entity;
                TODO confirm).

    Raises:
        NotImplementedError: if `o` is not a class, function or method but
            has a `__class__` attribute (object conversion is unsupported).
        ValueError: if `o` has no `__class__` attribute.
            NOTE(review): every Python 3 object has `__class__`, so this
            branch looks unreachable in practice.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods share the same conversion path.
        node, name, ns = function_to_graph(o, program_ctx, arg_values,
                                           arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. it should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    template = '''
      entity.autograph_info__ = {}
  '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(node))
    if logging.has_verbosity(4):
        for n in node:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    if program_ctx.options.recursive:
        # Candidates deliberately left for conversion with their class. They
        # never enter dependency_cache, so without remembering them the
        # search below would pick the same candidate forever.
        skipped = set()
        while True:
            candidate = next(
                (obj for obj in program_ctx.name_map.keys()
                 if obj not in program_ctx.dependency_cache and
                 obj not in skipped),
                None)
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class') and
                    candidate.im_class not in program_ctx.partial_types):
                # Class members are converted with their objects, unless they're
                # only converted partially.
                skipped.add(candidate)
                continue
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns