Beispiel #1
0
    def test_parse_lambda_in_expression(self):
        """Each lambda of a tuple literal, and each inner lambda, parses alone."""

        l = (
            lambda x: lambda y: x + y + 1,
            lambda x: lambda y: x + y + 2,
        )

        # (entity, expected unparsed form, expected source modulo garbage)
        cases = (
            (l[0], '(lambda x: (lambda y: ((x + y) + 1)))',
             'lambda x: lambda y: x + y + 1'),
            (l[0](0), '(lambda y: ((x + y) + 1))', 'lambda y: x + y + 1'),
            (l[1], '(lambda x: (lambda y: ((x + y) + 2)))',
             'lambda x: lambda y: x + y + 2'),
            (l[1](0), '(lambda y: ((x + y) + 2))', 'lambda y: x + y + 2'),
        )
        for entity, expected_unparsed, expected_source in cases:
            parsed_node, parsed_source = parser.parse_entity(
                entity, future_features=())
            self.assertEqual(
                parser.unparse(parsed_node, include_encoding_marker=False),
                expected_unparsed)
            # The extracted source may carry trailing tuple elements after ','.
            self.assertMatchesWithPotentialGarbage(parsed_source,
                                                   expected_source, ',')
Beispiel #2
0
 def __repr__(self):
   """Returns a short human-readable description of the wrapped AST node."""
   node = self.ast_node
   if isinstance(node, gast.FunctionDef):
     return 'def %s' % node.name
   if isinstance(node, gast.ClassDef):
     return 'class %s' % node.name
   # For `with` items, only the context expression is rendered.
   shown = node.context_expr if isinstance(node, gast.withitem) else node
   return parser.unparse(shown, include_encoding_marker=False).strip()
Beispiel #3
0
 def assert_body_anfs_as_expected(self, expected_fn, test_fn):
   """Checks that an ANF config that transforms nothing leaves `test_fn` as is.

   `expected_fn` is not used by this implementation: with every node mapped
   to `anf.LEAVE`, the transformed code must equal the original code.
   """
   # Testing the code bodies only. They are wrapped in functions so the
   # syntax highlights nicely, but Python doesn't try to execute the
   # statements.
   def normalized_source(ast_node):
     # Unparse, then dedent and strip so that only the code itself is compared.
     return textwrap.dedent(parser.unparse(ast_node, indentation='  ')).strip()

   node, _ = parser.parse_entity(test_fn, future_features=())
   before = normalized_source(node)
   leave_everything = [(anf.ANY, anf.LEAVE)]
   node = anf.transform(node, self._simple_context(), config=leave_everything)
   self.assertEqual(before, normalized_source(node))
Beispiel #4
0
    def test_parse_lambda_resolution_by_signature(self):
        """Two lambdas on the same line are told apart by their arguments."""

        l = lambda x: lambda x, y: x + y

        expectations = (
            (l, '(lambda x: (lambda x, y: (x + y)))',
             'lambda x: lambda x, y: x + y'),
            (l(0), '(lambda x, y: (x + y))', 'lambda x, y: x + y'),
        )
        for entity, expected_unparsed, expected_source in expectations:
            parsed_node, parsed_source = parser.parse_entity(
                entity, future_features=())
            self.assertEqual(
                parser.unparse(parsed_node, include_encoding_marker=False),
                expected_unparsed)
            self.assertEqual(parsed_source, expected_source)
Beispiel #5
0
    def get_definition_directive(self, node, directive, arg, default):
        """Returns the unique directive argument for a symbol.

        See lang/directives.py for details on directives.

        Example:
           # Given a directive in the code:
           ag.foo_directive(bar, baz=1)

           # One can write for an AST node Name(id='bar'):
           get_definition_directive(node, ag.foo_directive, 'baz', default)

        Args:
          node: ast.AST, the node representing the symbol for which the
            directive argument is needed.
          directive: Callable[..., Any], the directive to search.
          arg: str, the directive argument to return.
          default: Any, returned when the node has no original definitions or
            none of them carries `arg` for `directive`.

        Returns:
          The argument value found on the symbol's definitions (all of which
          must agree), or `default`.

        Raises:
          ValueError: if conflicting annotations have been found
        """
        definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
        if not definitions:
            return default

        candidates = [
            definition.directives[directive][arg]
            for definition in definitions
            if directive in definition.directives
            and arg in definition.directives[directive]
        ]
        if not candidates:
            return default

        # Several annotations may reach the symbol. They must all match, in
        # which case any of them can be returned.
        reference = candidates[0]
        for candidate in candidates[1:]:
            if ast_util.matches(reference, candidate):
                continue
            qn = anno.getanno(node, anno.Basic.QN)
            raise ValueError(
                '%s has ambiguous annotations for %s(%s): %s, %s' %
                (qn, directive.__name__, arg,
                 parser.unparse(candidate).strip(),
                 parser.unparse(reference).strip()))
        return reference
Beispiel #6
0
    def _transformed_factory(self, fn, cache_subkey, user_context,
                             extra_locals):
        """Returns the transformed function factory for a given input.

        Performs a lock-free cache lookup first. On a miss, it repeats the
        lookup while holding `self._cache_lock`. If the entry is still absent,
        it transforms `fn`, builds a `_TransformedFnFactory`, stores it in the
        cache and returns it.
        """
        if self._cache.has(fn, cache_subkey):
            return self._cached_factory(fn, cache_subkey)

        with self._cache_lock:
            # Another thread may have filled the entry while we waited.
            if self._cache.has(fn, cache_subkey):
                return self._cached_factory(fn, cache_subkey)

            logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
            nodes, ctx = self._transform_function(fn, user_context)
            if logging.has_verbosity(2):
                logging.log(2, 'Transformed %s:\n\n%s\n', fn,
                            parser.unparse(nodes))

            new_factory = _TransformedFnFactory(
                ctx.info.name, fn.__code__.co_freevars, extra_locals)
            new_factory.create(
                nodes, ctx.namer, future_features=ctx.info.future_features)
            self._cache[fn][cache_subkey] = new_factory
            return new_factory
Beispiel #7
0
        def do_parse_and_test(lam, **unused_kwargs):
            """Parses `lam`; checks its unparsed form and its raw source."""
            parsed_node, parsed_source = parser.parse_entity(
                lam, future_features=())
            unparsed = parser.unparse(parsed_node,
                                      include_encoding_marker=False)
            self.assertEqual(unparsed, '(lambda x: x)')
            self.assertEqual(parsed_source, 'lambda x: x, named_arg=1)')
Beispiel #8
0
    def test_parse_lambda_complex_body(self):
        """The source of a multi-line lambda is extracted verbatim.

        Runtimes that correctly track `end_lineno` for ASTs also capture the
        trailing closing parenthesis, so both forms of the source are accepted.
        """

        l = lambda x: (  # pylint:disable=g-long-lambda
            x.y(
                [],
                x.z,
                (),
                x[0:2],
            ),
            x.u,
            'abc',
            1,
        )

        node, source = parser.parse_entity(l, future_features=())
        self.assertEqual(
            parser.unparse(node, include_encoding_marker=False),
            "(lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1))")
        base_source = ('lambda x: (  # pylint:disable=g-long-lambda\n'
                       '        x.y(\n'
                       '            [],\n'
                       '            x.z,\n'
                       '            (),\n'
                       '            x[0:2],\n'
                       '        ),\n'
                       '        x.u,\n'
                       '        \'abc\',\n'
                       '        1,')
        # A strict equality check fails wherever end_lineno is tracked,
        # because there the closing parenthesis is part of the source too.
        self.assertIn(source, (base_source, base_source + '\n    )'))
Beispiel #9
0
  def test_ext_slice_roundtrip(self):
    """Extended slices survive a parse -> unparse round trip."""
    def ext_slice(n):
      return n[:, :], n[0, :], n[:, 0]

    parsed, _ = parser.parse_entity(ext_slice, future_features=())
    self.assertAstMatches(parsed, parser.unparse(parsed), expr=False)
Beispiel #10
0
def convert_entity_to_ast(o, program_ctx):
    """Compile a Python entity into equivalent TensorFlow.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.

    Returns:
      A tuple (nodes, name, entity_info), as produced by `convert_func_to_ast`:
          * nodes: AST nodes representing an entity with an interface
              equivalent to `o`, but which when executed creates a TF graph.
          * name: The symbol name under which the new entity can be found.
          * entity_info: the entity information returned by
              `convert_func_to_ast`.

    Raises:
      NotImplementedError: presumably propagated from `convert_func_to_ast`
        for unsupported entities; this function raises nothing itself.
    """
    logging.log(1, 'Converting %s', o)

    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)

    # Increasingly verbose debug dumps: source first, then the raw ASTs.
    if logging.has_verbosity(2):
        compiled_source = parser.unparse(nodes)
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o, compiled_source)
    if logging.has_verbosity(4):
        for converted_node in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(converted_node, color=False))

    return nodes, name, entity_info
Beispiel #11
0
 def assertLambdaNodes(self, matching_nodes, expected_bodies):
   """Asserts every node is a lambda whose unparsed body is expected."""
   self.assertEqual(len(matching_nodes), len(expected_bodies))
   for lambda_node in matching_nodes:
     self.assertIsInstance(lambda_node, gast.Lambda)
     body_source = parser.unparse(
         lambda_node.body, include_encoding_marker=False).strip()
     self.assertIn(body_source, expected_bodies)
Beispiel #12
0
    def test_unparse(self):
        """An if/else AST built by hand unparses to the expected source."""

        def store_a():
            # A fresh `a` assignment target for each branch.
            return gast.Name('a',
                             ctx=gast.Store(),
                             annotation=None,
                             type_comment=None)

        then_assign = gast.Assign(targets=[store_a()],
                                  value=gast.Name('b',
                                                  ctx=gast.Load(),
                                                  annotation=None,
                                                  type_comment=None))
        else_assign = gast.Assign(targets=[store_a()],
                                  value=gast.Constant('c', kind=None))
        node = gast.If(test=gast.Constant(1, kind=None),
                       body=[then_assign],
                       orelse=[else_assign])

        source = parser.unparse(node, indentation='  ')
        self.assertEqual(
            textwrap.dedent("""
            # coding=utf-8
            if 1:
                a = b
            else:
                a = 'c'
        """).strip(), source.strip())
Beispiel #13
0
    def test_parse_lambda_complex_body(self):
        """The source of a multi-line lambda is extracted verbatim."""

        l = lambda x: (  # pylint:disable=g-long-lambda
            x.y(
                [],
                x.z,
                (),
                x[0:2],
            ),
            x.u,
            'abc',
            1,
        )

        expected_unparsed = (
            "(lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1))")
        base_source = ('lambda x: (  # pylint:disable=g-long-lambda\n'
                       '        x.y(\n'
                       '            [],\n'
                       '            x.z,\n'
                       '            (),\n'
                       '            x[0:2],\n'
                       '        ),\n'
                       '        x.u,\n'
                       '        \'abc\',\n'
                       '        1,')
        # Only runtimes that correctly track end_lineno for ASTs also report
        # the trailing parenthesis, so both variants are acceptable.
        acceptable_sources = (base_source, base_source + '\n    )')

        parsed_node, parsed_source = parser.parse_entity(l, future_features=())
        self.assertEqual(
            parser.unparse(parsed_node, include_encoding_marker=False),
            expected_unparsed)
        self.assertIn(parsed_source, acceptable_sources)
Beispiel #14
0
    def test_rename_symbols_function(self):
        """Renaming the QN `f` rewrites the function's definition name."""
        tree = parser.parse('def f():\n  pass')
        renames = {qual_names.QN('f'): qual_names.QN('f1')}
        tree = ast_util.rename_symbols(tree, renames)

        unparsed = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(unparsed.strip(), 'def f1():\n    pass')
Beispiel #15
0
    def test_parse_lambda_prefix_cleanup(self):
        """The `lambda_lam =` prefix is not part of the extracted source."""

        lambda_lam = lambda x: x + 1

        parsed_node, parsed_source = parser.parse_entity(
            lambda_lam, future_features=())
        unparsed = parser.unparse(parsed_node, include_encoding_marker=False)
        self.assertEqual(unparsed, '(lambda x: (x + 1))')
        self.assertEqual(parsed_source, 'lambda x: x + 1')
Beispiel #16
0
    def test_rename_symbols_attributes(self):
        """Renaming the attribute QN `b.c` rewrites both of its occurrences."""
        tree = qual_names.resolve(parser.parse('b.c = b.c.d'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})

        self.assertEqual(
            parser.unparse(tree, include_encoding_marker=False).strip(),
            'renamed_b_c = renamed_b_c.d')
Beispiel #17
0
    def test_rename_symbols_global(self):
        """Renaming `b` updates its entry in a `global` statement."""
        tree = qual_names.resolve(parser.parse('global a, b, c'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.from_str('b'): qual_names.QN('renamed_b')})

        self.assertEqual(
            parser.unparse(tree, include_encoding_marker=False).strip(),
            'global a, renamed_b, c')
Beispiel #18
0
    def transform_function(self, fn, user_context):
        """Transforms a function. See GenericTranspiler.transform_function.

        This overload wraps the parent's `transform_function`, adding caching
        and facilities to instantiate the output as a Python object. It also
        adds facilities to make new symbols available to the generated Python
        code, visible as local variables - see `get_extra_locals`.

        Args:
          fn: A function or lambda.
          user_context: An opaque object (may be None) that is forwarded to
            transform_ast, through the ctx.user_context argument.
        Returns:
          A tuple:
            * A function or lambda with the same signature and closure as `fn`
            * The temporary module into which the transformed function was
              loaded
            * The source map as a
                Dict[origin_info.LineLocation, origin_info.OriginInfo]
        """
        cache_subkey = self.get_caching_key(user_context)

        def build_and_cache_factory():
            # Only called while holding `self._cache_lock`, after a confirmed
            # cache miss.
            logging.log(1, '%s is not cached for subkey %s', fn,
                        cache_subkey)
            # TODO(mdan): Confusing overloading pattern. Fix.
            nodes, ctx = super(PyToPy, self).transform_function(
                fn, user_context)

            if logging.has_verbosity(2):
                logging.log(2, 'Transformed %s:\n\n%s\n', fn,
                            parser.unparse(nodes))

            new_factory = _PythonFnFactory(ctx.info.name,
                                           fn.__code__.co_freevars,
                                           self.get_extra_locals())
            new_factory.create(nodes,
                               ctx.namer,
                               future_features=ctx.info.future_features)
            self._cache[fn][cache_subkey] = new_factory
            return new_factory

        if self._cache.has(fn, cache_subkey):
            # Fast path: lock-free lookup.
            factory = self._cached_factory(fn, cache_subkey)
        else:
            with self._cache_lock:
                # Look again under the lock: another thread may have filled it.
                if self._cache.has(fn, cache_subkey):
                    factory = self._cached_factory(fn, cache_subkey)
                else:
                    factory = build_and_cache_factory()

        kwdefaults = getattr(fn, '__kwdefaults__', None)
        transformed_fn = factory.instantiate(globals_=fn.__globals__,
                                             closure=fn.__closure__ or (),
                                             defaults=fn.__defaults__,
                                             kwdefaults=kwdefaults)
        return transformed_fn, factory.module, factory.source_map
Beispiel #19
0
    def test_parse_lambda_resolution_by_location(self):
        """`l` is parsed correctly even with identical lambdas next to it."""

        _ = lambda x: x + 1
        l = lambda x: x + 1
        _ = lambda x: x + 1

        parsed_node, parsed_source = parser.parse_entity(
            l, future_features=())
        self.assertEqual(
            parser.unparse(parsed_node, include_encoding_marker=False),
            '(lambda x: (x + 1))')
        self.assertEqual(parsed_source, 'lambda x: x + 1')
    def test_rename_symbols_basic(self):
        """Renaming `a` rewrites it and leaves the name id a plain `str`."""
        tree = qual_names.resolve(parser.parse('a + b'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})

        self.assertIsInstance(tree.value.left.id, str)
        unparsed = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(unparsed.strip(), 'renamed_a + b')
Beispiel #21
0
 def visit_If(self, node):
   """Emits a brace-delimited if/else for `node`."""
   self.emit('if ')
   # For simplicity the condition is emitted as unparsed Python. A real
   # generator would walk the tree and emit proper code.
   self.emit(parser.unparse(node.test, include_encoding_marker=False))
   self.emit(' {\n')
   for block, closing in ((node.body, '} else {\n'), (node.orelse, '}\n')):
     self.visit_block(block)
     self.emit(closing)
Beispiel #22
0
    def test_parse_lambda_multiline(self):
        """A lambda spanning two lines is extracted with its continuation."""

        l = (
            lambda x: lambda y: x + y  # pylint:disable=g-long-lambda
            - 1)

        # (entity, expected unparsed form, expected source modulo garbage)
        cases = (
            (l, '(lambda x: (lambda y: ((x + y) - 1)))',
             ('lambda x: lambda y: x + y  # pylint:disable=g-long-lambda\n'
              '        - 1')),
            (l(0), '(lambda y: ((x + y) - 1))',
             ('lambda y: x + y  # pylint:disable=g-long-lambda\n'
              '        - 1')),
        )
        for entity, expected_unparsed, expected_source in cases:
            parsed_node, parsed_source = parser.parse_entity(
                entity, future_features=())
            self.assertEqual(
                parser.unparse(parsed_node, include_encoding_marker=False),
                expected_unparsed)
            # The extracted source may carry the closing ')' as garbage.
            self.assertMatchesWithPotentialGarbage(parsed_source,
                                                   expected_source, ')')
Beispiel #23
0
    def test_rename_symbols_basic(self):
        """Renaming `a` yields an AST matching `renamed_a + b`."""
        tree = qual_names.resolve(parser.parse('a + b'))
        renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
        tree = ast_util.rename_symbols(tree, renames)
        unparsed = parser.unparse(tree, include_encoding_marker=False)

        self.assertIsInstance(tree.value.left.id, str)
        # The AST must match both its own unparsed form and the expectation.
        for source_to_match in (unparsed, 'renamed_a + b'):
            self.assertAstMatches(tree, source_to_match)
Beispiel #24
0
    def test_convert_entity_to_ast_function_with_defaults(self):
        """The converted function's default value unparses as `None`."""
        b = 2
        c = 1

        def f(a, d=c + 1):
            return a + b + d

        nodes, name, _ = conversion.convert_entity_to_ast(
            f, self._simple_program_ctx())
        (fn_node,) = nodes
        self.assertIsInstance(fn_node, gast.FunctionDef)
        self.assertEqual('tf__f', name)

        first_default = fn_node.args.defaults[0]
        self.assertEqual(
            parser.unparse(first_default,
                           include_encoding_marker=False).strip(), 'None')
 def visit_IfExp(self, node):
     """Rewrites a conditional expression into an `ag__.if_exp` call.

     Both branches are wrapped in lambdas by the template. The unparsed test
     expression is passed as a string constant in `expr_repr`.
     """
     test_source = parser.unparse(node.test, include_encoding_marker=False)
     template = '''
     ag__.if_exp(
         test,
         lambda: true_expr,
         lambda: false_expr,
         expr_repr)
 '''
     return templates.replace_as_expression(
         template,
         test=node.test,
         true_expr=node.body,
         false_expr=node.orelse,
         expr_repr=gast.Constant(test_source.strip(), kind=None))
Beispiel #26
0
def convert_entity_to_ast(o, program_ctx):
    """Compile a Python entity into equivalent TensorFlow.

    Args:
      o: A Python entity: a class, a function or a method.
      program_ctx: A ProgramContext object.

    Returns:
      A tuple (nodes, name, entity_info):
          * nodes: AST nodes representing an entity with an interface
              equivalent to `o`, but which when executed creates a TF graph.
          * name: The symbol name under which the new entity can be found.
          * entity_info: the entity information returned by
              `convert_class_to_ast` / `convert_func_to_ast`.

    Raises:
      NotImplementedError: if entity is of a type that is not yet supported.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    else:
        # Anything else is an object instance. (Every Python 3 object has
        # `__class__`, so no narrower check can reject anything here.)
        # Note: this should only be raised when attempting to convert the
        # object directly. converted_call should still support it.
        raise NotImplementedError(
            'cannot convert entity "{}": object conversion is not yet'
            ' supported.'.format(o))

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    parser.unparse(nodes))
    if logging.has_verbosity(4):
        for n in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    return nodes, name, entity_info
Beispiel #27
0
  def test_unparse(self):
    """An if/else AST unparses using the requested two-space indentation."""
    then_assign = gast.Assign(
        targets=[gast.Name('a', gast.Store(), None)],
        value=gast.Name('b', gast.Load(), None))
    else_assign = gast.Assign(
        targets=[gast.Name('a', gast.Store(), None)],
        value=gast.Str('c'))
    node = gast.If(test=gast.Num(1), body=[then_assign], orelse=[else_assign])

    unparsed = parser.unparse(node, indentation='  ')
    expected = textwrap.dedent("""
            # coding=utf-8
            if 1:
              a = b
            else:
              a = 'c'
        """).strip()
    self.assertEqual(expected, unparsed.strip())
Beispiel #28
0
def load_ast(nodes,
             indentation='  ',
             include_source_map=False,
             delete_on_exit=True):
    """Loads the given AST as a Python module.

    Compiling the AST code this way ensures that the source code is readable by
    e.g. `pdb` or `inspect`.

    Args:
      nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as a single
        AST node or an iterable of them. Iterables other than lists and tuples
        (e.g. generators) are materialized into a tuple first.
      indentation: Text, the string to use for indentation.
      include_source_map: bool, whether return a source map.
      delete_on_exit: bool, whether to delete the temporary file used for
        compilation on exit.

    Returns:
      Tuple[module, Text, Dict[LineLocation, OriginInfo]], containing:
      the module containing the unparsed nodes, the source code corresponding to
      nodes, and the source map. If include_source_map is False, the source map
      will be None.
    """
    import ast

    if isinstance(nodes, ast.AST):
        nodes = (nodes, )
    elif not isinstance(nodes, (list, tuple)):
        # `nodes` is traversed twice below (unparse, then the source map), so
        # a one-shot iterable such as a generator must be materialized. It
        # must not be wrapped as if it were a single node.
        nodes = tuple(nodes)

    source = parser.unparse(nodes, indentation=indentation)
    module, _ = load_source(source, delete_on_exit)

    if include_source_map:
        source_map = origin_info.create_source_map(nodes, source,
                                                   module.__file__)
    else:
        source_map = None

    # TODO(mdan): Return a structured object.
    return module, source, source_map
Beispiel #29
0
 def visit_Return(self, node):
   """Emits the return statement as unparsed Python, then a newline."""
   for chunk in (parser.unparse(node, include_encoding_marker=False), '\n'):
     self.emit(chunk)
Beispiel #30
0
 def debug_print_src(self, node):
     """Debugging helper: prints `node` as source code and returns it.

     Nothing is printed when `__debug__` is false (e.g. under `python -O`).
     """
     if not __debug__:
         return node
     print(parser.unparse(node))
     return node