예제 #1
0
  def test_find_matching_definitions_lambda_uses_arg_names(self):
    """Two lambdas on one line are told apart by their argument names."""
    source = textwrap.dedent("""
      f = lambda x: 1, lambda y: 2
    """)
    tree = parser.parse_str(source)

    # Each probe lambda should select only the parsed lambda whose parameter
    # name it shares.
    cases = (
        (lambda x: x, ('(1)',)),
        (lambda y: y, ('(2)',)),
    )
    for probe, expected_bodies in cases:
      matches = ast_util.find_matching_definitions(tree, probe)
      self.assertLambdaNodes(matches, expected_bodies)
예제 #2
0
  def test_find_matching_definitions_lambda_uses_arg_names(self):
    """Two lambdas on one line are told apart by their argument names."""
    source = textwrap.dedent("""
      f = lambda x: 1, lambda y: 2
    """)
    tree = parser.parse_str(source)

    # Each probe lambda should select only the parsed lambda whose parameter
    # name it shares.
    cases = (
        (lambda x: x, ('(1)',)),
        (lambda y: y, ('(2)',)),
    )
    for probe, expected_bodies in cases:
      matches = ast_util.find_matching_definitions(tree, probe)
      self.assertLambdaNodes(matches, expected_bodies)
예제 #3
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Parses the source of `f` and selects the single AST definition that
    matches `f`. It converts that definition with `node_to_graph` and names
    the result using the namer obtained from `program_ctx`.

    Args:
      f: the function or lambda to convert.
      program_ctx: the program context. It provides the namer (`new_namer`)
        and `autograph_module`, and its name map is updated after conversion.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: optional; forwarded to `transformer.EntityInfo` and to
        `namer.compiled_function_name`.

    Returns:
      A tuple `(nodes, new_name, namespace)`, where:
        - `nodes` is a one-element list with the converted node. A converted
          lambda is wrapped in an assignment to `new_name`.
        - `new_name` is the name of the converted entity.
        - `namespace` is the namespace of `f`, extended by
          `_add_self_references`.

    Raises:
      ValueError: if the source reported for `f` does not contain exactly one
        definition matching `f`.
    """

    node, source = parser.parse_entity(f)
    # Only the first statement in the parsed body is considered.
    node = node.body[0]

    # In general, the output of inspect.getsource is inexact because it uses
    # regex matching to adjust the exact location around the line number that
    # CPython records. This is particularly problematic for lambda functions,
    # where the entire containing lines are returned.
    nodes = ast_util.find_matching_definitions(node, f)
    if len(nodes) != 1:
        if f.__name__ == '<lambda>':
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        else:
            raise ValueError(
                'Unable to identify source code of function {}. The source code'
                ' reported by Python did not include exactly one matching signature:'
                '\n{}\nThis is an extremely rare occurrence. Please report it to'
                ' the TensorFlow team.'.format(f, source))
    node, = nodes

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    entity_info = transformer.EntityInfo(source_code=source,
                                         source_file='<fragment>',
                                         namespace=namespace,
                                         arg_values=arg_values,
                                         arg_types=arg_types,
                                         owner_type=owner_type)
    context = converter.EntityContext(namer, entity_info, program_ctx)
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        # A lambda node carries no name, so bind the converted lambda to a
        # freshly generated symbol.
        new_name = namer.new_symbol('tf__lambda', ())
        node = gast.Assign(targets=[gast.Name(new_name, gast.Store(), None)],
                           value=node)

    else:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name, did_rename = namer.compiled_function_name(
            f.__name__, f, owner_type)
        if did_rename:
            node.name = new_name
        else:
            new_name = f.__name__
            assert node.name == new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [node], new_name, namespace
예제 #4
0
 def test_find_matching_definitions_lambda(self):
     """A lone lambda is matched by a probe lambda with the same signature."""
     source = textwrap.dedent("""
   f = lambda x: 1
 """)
     tree = parser.parse(source)
     probe = lambda x: x
     self.assertLambdaNodes(
         ast_util.find_matching_definitions(tree, probe), ('(1)', ))
예제 #5
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Converts the callable `f` into a transformed AST node.

    This is the function-specific counterpart of `convert_entity_to_ast`. The
    source of `f` is parsed with its `__future__` imports taken into account.
    It is then converted with `node_to_graph` and given a name.

    Args:
      f: the function or lambda to convert.
      program_ctx: the program context. Its `autograph_module` is added to the
        namespace, and it is passed to `converter.EntityContext`.
      do_rename: when true, a non-lambda function receives a name from
        `namer.function_name`. Otherwise it keeps `f.__name__`.

    Returns:
      A tuple `(nodes, new_name, entity_info)`, where:
        - `nodes` is a one-element tuple holding the converted node. A
          converted lambda is wrapped in an assignment to `new_name`.
        - `new_name` is the entity's name.
        - `entity_info` is the `transformer.EntityInfo` used for conversion.

    Raises:
      ValueError: if `f` is a lambda and its reported source does not contain
        exactly one matching lambda definition.
    """
    future_features = inspect_utils.getfutureimports(f)
    parsed, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)

    # inspect.getsource locates lambdas only approximately: it returns every
    # line that contains the lambda. Such a line may hold several lambdas, e.g.
    #   x, y = lambda: 1, lambda: 2
    # so the right one must be picked out explicitly.
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(parsed, f)
        if len(candidates) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        (parsed,) = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(parsed, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    if isinstance(parsed, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
    else:
        new_name = namer.function_name(f.__name__) if do_rename else f.__name__

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
    converted = node_to_graph(parsed, context)

    if isinstance(converted, gast.Lambda):
        # Bind the converted lambda to the generated name.
        target = gast.Name(new_name,
                           ctx=gast.Store(),
                           annotation=None,
                           type_comment=None)
        converted = gast.Assign(targets=[target], value=converted)
    elif do_rename:
        converted.name = new_name
    else:
        assert converted.name == new_name

    return (converted,), new_name, entity_info
예제 #6
0
 def test_find_matching_definitions_lambda_multiple_matches(self):
   """Every lambda on the line matches when all share the probe's signature."""
   source = textwrap.dedent("""
     f = lambda x: 1, lambda x: 2
   """)
   tree = parser.parse_str(source)
   probe = lambda x: x
   matches = ast_util.find_matching_definitions(tree, probe)
   self.assertLambdaNodes(matches, ('(1)', '(2)'))
예제 #7
0
 def test_find_matching_definitions_lambda_multiple_matches(self):
   """Every lambda on the line matches when all share the probe's signature."""
   source = textwrap.dedent("""
     f = lambda x: 1, lambda x: 2
   """)
   tree = parser.parse_str(source)
   probe = lambda x: x
   matches = ast_util.find_matching_definitions(tree, probe)
   self.assertLambdaNodes(matches, ('(1)', '(2)'))
예제 #8
0
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
    """Converts the callable `f` into a transformed AST node.

    This is the function-specific counterpart of `entity_to_graph`.

    Args:
      f: the function or lambda to convert.
      program_ctx: the program context. Its `autograph_module` is added to the
        namespace, and it is passed to `converter.EntityContext`.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      do_rename: when true, a converted non-lambda definition receives a name
        from `namer.function_name`. Otherwise it must already be named
        `f.__name__`.

    Returns:
      A tuple `(nodes, new_name, namespace)`, where:
        - `nodes` is a one-element list with the converted node. A converted
          lambda is wrapped in an assignment to `new_name`.
        - `new_name` is the entity's name.
        - `namespace` is the namespace of `f`.

    Raises:
      ValueError: if `f` is a lambda and its source does not contain exactly
        one matching lambda.
      errors.InternalError: if `node_to_graph` raises ValueError,
        AttributeError, KeyError or NotImplementedError.
    """
    parsed, source = parser.parse_entity(f)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    target = parsed.body[0]

    # inspect.getsource locates lambdas only approximately: it returns every
    # line that contains the lambda. Such a line may hold several lambdas, e.g.
    #   x, y = lambda: 1, lambda: 2
    # so the right one must be picked out explicitly.
    if f.__name__ == '<lambda>':
        matches = ast_util.find_matching_definitions(target, f)
        if len(matches) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        target, = matches

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(target, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types)
    context = converter.EntityContext(namer, info, program_ctx)

    try:
        converted = node_to_graph(target, context)
    except (ValueError, AttributeError, KeyError, NotImplementedError) as err:
        logging.error(1, 'Error converting %s', f, exc_info=True)
        # TODO(mdan): Catch and rethrow syntax errors.
        raise errors.InternalError('conversion', err)

    if isinstance(converted, gast.Lambda):
        # Bind the converted lambda to a freshly generated symbol.
        new_name = namer.new_symbol('tf__lambda', ())
        assignment = gast.Assign(
            targets=[gast.Name(new_name, gast.Store(), None)], value=converted)
        return [assignment], new_name, namespace

    if do_rename:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name = namer.function_name(f.__name__)
        converted.name = new_name
    else:
        new_name = f.__name__
        assert converted.name == new_name
    return [converted], new_name, namespace
예제 #9
0
    def test_find_matching_definitions_function(self):
        """A parsed function definition is matched to a live function `f`."""
        tree = parser.parse_str(textwrap.dedent("""
      def f(x):
        return 1
    """))

        # The probe must be named `f` to correspond to the parsed definition.
        def f(x):
            return x

        self.assertFunctionDefNodes(
            ast_util.find_matching_definitions(tree, f), ('return 1', ))
예제 #10
0
    def _transform_function(self, fn, user_context):
        """Parses, transforms and renames the source code of `fn`.

        Args:
          fn: the function or lambda whose source is transformed.
          user_context: forwarded to `transformer.Context`.

        Returns:
          A `(node, context)` pair. `node` is the transformed node: for a
          lambda, an assignment of the transformed lambda to the new name;
          otherwise the transformed definition, renamed in place. `context`
          is the `transformer.Context` used for the transformation.

        Raises:
          ValueError: if `fn` is a lambda and its source does not contain a
            single distinguishable lambda.
        """
        future_features = inspect_utils.getfutureimports(fn)
        tree, source = parser.parse_entity(fn, future_features=future_features)
        logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

        # inspect.getsource locates lambdas only approximately: it returns
        # every line that contains the lambda. Such a line may hold several
        # lambdas, e.g.
        #   x, y = lambda: 1, lambda: 2
        # so the right one must be picked out explicitly.
        is_lambda = fn.__name__ == '<lambda>'
        if is_lambda:
            candidates = ast_util.find_matching_definitions(tree, fn)
            if len(candidates) != 1:
                raise ValueError(
                    'Unable to identify source code of lambda function {}.'
                    ' It was defined in this code:\n'
                    '{}\n'
                    'This code must contain a single distinguishable lambda.'
                    ' To avoid this problem, define each lambda in a separate'
                    ' expression.'.format(fn, source))
            tree, = candidates

        origin_info.resolve_entity(tree, source, fn)

        namespace = inspect_utils.getnamespace(fn)
        namer = naming.Namer(namespace)
        new_name = namer.new_symbol(self.get_transformed_name(tree), ())
        context = transformer.Context(
            transformer.EntityInfo(name=new_name,
                                   source_code=source,
                                   source_file='<fragment>',
                                   future_features=future_features,
                                   namespace=namespace),
            namer,
            user_context)

        transformed = self.transform_ast(self._erase_arg_defaults(tree),
                                         context)

        if not is_lambda:
            transformed.name = new_name
            return transformed, context

        # Bind the transformed lambda to its newly generated name.
        name_target = gast.Name(new_name,
                                ctx=gast.Store(),
                                annotation=None,
                                type_comment=None)
        binding = gast.Assign(targets=[name_target], value=transformed)
        return binding, context
예제 #11
0
    def test_find_matching_definitions_decorated_compatible(self):
        """A decorated definition with a compatible varargs signature matches."""
        source = textwrap.dedent("""
      @sneaky_decorator
      def f(x, *args, **kwargs):
        return 1
    """)
        tree = parser.parse_str(source)

        # The probe must be named `f` to correspond to the parsed definition.
        def f(a, b, c, d=1):
            return a + b + c + d

        matches = ast_util.find_matching_definitions(tree, f)
        self.assertFunctionDefNodes(matches, ('return 1', ))
예제 #12
0
    def test_find_matching_definitions_nested_functions_same_name(self):
        """Of two nested functions named `f`, only the inner one should match."""
        source = textwrap.dedent("""
      def f(x, *args, **kwargs):
        def f(x, y):
          return 1
        return 2
    """)
        tree = parser.parse_str(source)

        # The probe must be named `f` to correspond to the parsed definitions.
        def f(x, y):
            return x + y

        matches = ast_util.find_matching_definitions(tree, f)
        self.assertFunctionDefNodes(matches, ('return 1', ))
예제 #13
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Parses the source of `f` and selects the single AST definition that
    matches `f`. It converts that definition with `node_to_graph` and names
    the result using the namer obtained from `program_ctx`.

    Args:
      f: the function or lambda to convert.
      program_ctx: the program context. It provides the namer (`new_namer`)
        and `autograph_module`, and its name map is updated after conversion.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: optional; forwarded to `transformer.EntityInfo` and to
        `namer.compiled_function_name`.

    Returns:
      A tuple `(nodes, new_name, namespace)`, where:
        - `nodes` is a one-element list with the converted node. A converted
          lambda is wrapped in an assignment to `new_name`.
        - `new_name` is the name of the converted entity.
        - `namespace` is the namespace of `f`, extended by
          `_add_self_references`.

    Raises:
      ValueError: if the source reported for `f` does not contain exactly one
        definition matching `f`.
    """

    node, source = parser.parse_entity(f)
    # Only the first statement in the parsed body is considered.
    node = node.body[0]

    # In general, the output of inspect.getsource is inexact because it uses crude
    # regex matching methods to search the source file. This is particularly
    # problematic for lambda functions, where the entire containing lines are
    # returned. Certain distributions of CPython may also return the enclosing
    # function for local functions.
    nodes = ast_util.find_matching_definitions(node, f)
    if len(nodes) != 1:
        if f.__name__ == '<lambda>':
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        else:
            # The inspect.getsource bug is currently known to occur in the Windows
            # integration tests which run Python 3.6.
            # TODO(mdan): Find out exactly which distribution of Python that is.
            raise ValueError(
                'Unable to identify source code of function {}. The source code'
                ' reported by Python did not include exactly one matching signature:'
                '\n{}\nTo avoid ambiguity, use a unique name for each'
                ' function.\nNote that some distributions of Python may report source'
                ' code incorrectly. It may be possible to avoid that bug by'
                ' organizing the code into smaller units (smaller files, functions or'
                ' classes), or by turning AutoGraph off.'.format(f, source))
    node, = nodes

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    entity_info = transformer.EntityInfo(source_code=source,
                                         source_file='<fragment>',
                                         namespace=namespace,
                                         arg_values=arg_values,
                                         arg_types=arg_types,
                                         owner_type=owner_type)
    context = converter.EntityContext(namer, entity_info, program_ctx)
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        # A lambda node carries no name, so bind the converted lambda to a
        # freshly generated symbol.
        new_name = namer.new_symbol('tf__lambda', ())
        node = gast.Assign(targets=[gast.Name(new_name, gast.Store(), None)],
                           value=node)

    else:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name, did_rename = namer.compiled_function_name(
            f.__name__, f, owner_type)
        if did_rename:
            node.name = new_name
        else:
            new_name = f.__name__
            assert node.name == new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [node], new_name, namespace