示例#1
0
 def test_new_symbol_avoids_conflicts(self):
     """new_symbol skips names taken in the namespace or reserved locally."""
     global_names = {'temp': 1}
     symbol_namer = naming.Namer(global_names, True, None, ())

     # 'temp' already exists in the namespace handed to the namer.
     first = symbol_namer.new_symbol('temp', set())
     self.assertEqual('temp_1', first)

     # 'temp_2' is passed in as reserved by the local scope.
     second = symbol_namer.new_symbol('temp', {'temp_2'})
     self.assertEqual('temp_3', second)

     self.assertItemsEqual(('temp_1', 'temp_3'),
                           symbol_namer.generated_names)
示例#2
0
    def test_compiled_function_name_avoids_global_conflicts(self):
        """A compiled name that collides with a global gets a suffix."""
        def foo():
            pass

        # 'tf__foo' is already present in the namespace given to the namer.
        taken_names = {'tf__foo': 1}
        fn_namer = naming.Namer(taken_names, True, None, ())

        result = fn_namer.compiled_function_name('foo', foo)
        self.assertEqual(('tf__foo_1', True), result)
示例#3
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Converts a function or lambda to its transformed AST.

    This is the `convert_entity_to_ast` specialization for callables.

    Args:
      f: The function or lambda to convert.
      program_ctx: Program-level context. Its `autograph_module` is passed to
        `_add_self_references` together with the namespace of `f`, and the
        context itself is forwarded to the `EntityContext`.
      do_rename: If true (default), a regular function is renamed to the
        value of `namer.function_name(f.__name__)`; otherwise its original
        name is kept. Lambdas always receive a freshly generated symbol.

    Returns:
      A tuple `(nodes, name, entity_info)`: a one-element tuple with the
      converted node, the name it is defined under, and the `EntityInfo`
      built for `f`.

    Raises:
      ValueError: If `f` is a lambda and the parsed source does not contain
        exactly one matching definition.
    """
    features = inspect_utils.getfutureimports(f)
    tree, src = parser.parse_entity(f, future_features=features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, src)
    # The parsed AST should contain future imports and one function def node.

    # inspect.getsource is inexact for lambdas: it regex-matches around the
    # line number recorded by CPython and returns the entire containing line,
    # which may hold several lambdas that are hard to tell apart, e.g.:
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(tree, f)
        if len(candidates) != 1:
            message = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.')
            raise ValueError(message.format(f, src))
        [tree] = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(tree, src, f)

    ns = inspect_utils.getnamespace(f)
    _add_self_references(ns, program_ctx.autograph_module)
    namer = naming.Namer(ns)

    # Pick the name the converted entity will be defined under.
    if isinstance(tree, gast.Lambda):
        target_name = namer.new_symbol('tf__lambda', ())
    elif not do_rename:
        target_name = f.__name__
    else:
        target_name = namer.function_name(f.__name__)

    info = transformer.EntityInfo(source_code=src,
                                  source_file='<fragment>',
                                  future_features=features,
                                  namespace=ns)
    entity_ctx = converter.EntityContext(namer, info, program_ctx,
                                         target_name)
    converted = node_to_graph(tree, entity_ctx)

    if isinstance(converted, gast.Lambda):
        # A lambda has no name of its own, so bind it with an assignment.
        store_target = gast.Name(target_name,
                                 ctx=gast.Store(),
                                 annotation=None,
                                 type_comment=None)
        converted = gast.Assign(targets=[store_target], value=converted)
    elif not do_rename:
        assert converted.name == target_name
    else:
        converted.name = target_name

    return (converted, ), target_name, info
示例#4
0
    def test_compiled_function_name_consistent(self):
        """Repeated requests for the same function give the same result."""
        def foo():
            pass

        fn_namer = naming.Namer({}, True, None, ())
        expected = ('tf__foo', True)
        # Ask twice; both calls must agree.
        for _ in range(2):
            self.assertEqual(expected,
                             fn_namer.compiled_function_name('foo', foo))
示例#5
0
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
    """Converts a function or lambda to its transformed AST.

    This is the `entity_to_graph` specialization for callables.

    Args:
      f: The function or lambda to convert.
      program_ctx: Program-level context. Its `autograph_module` is passed to
        `_add_self_references` together with the namespace of `f`, and the
        context itself is forwarded to the `EntityContext`.
      arg_values: Stored on the `EntityInfo` as `arg_values`.
      arg_types: Stored on the `EntityInfo` as `arg_types`.
      do_rename: If true (default), a regular function is renamed to the
        value of `namer.function_name(f.__name__)`; otherwise its original
        name is kept. Lambdas always receive a freshly generated symbol.

    Returns:
      A tuple `(nodes, name, namespace)`: a one-element list with the
      converted node, the name it is defined under, and the namespace of `f`.

    Raises:
      ValueError: If `f` is a lambda and the parsed source does not contain
        exactly one matching definition.
      errors.InternalError: If the conversion itself raises ValueError,
        AttributeError, KeyError or NotImplementedError.
    """
    parsed, src = parser.parse_entity(f)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, src)
    tree = parsed.body[0]

    # inspect.getsource is inexact for lambdas: it regex-matches around the
    # line number recorded by CPython and returns the entire containing line,
    # which may hold several lambdas that are hard to tell apart, e.g.:
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(tree, f)
        if len(candidates) != 1:
            message = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.')
            raise ValueError(message.format(f, src))
        [tree] = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(tree, src, f)
    ns = inspect_utils.getnamespace(f)
    _add_self_references(ns, program_ctx.autograph_module)
    namer = naming.Namer(ns)

    info = transformer.EntityInfo(source_code=src,
                                  source_file='<fragment>',
                                  namespace=ns,
                                  arg_values=arg_values,
                                  arg_types=arg_types)
    entity_ctx = converter.EntityContext(namer, info, program_ctx)
    try:
        converted = node_to_graph(tree, entity_ctx)
    except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
        # Log with traceback, then surface as an internal conversion error.
        logging.error(1, 'Error converting %s', f, exc_info=True)
        raise errors.InternalError('conversion', e)
        # TODO(mdan): Catch and rethrow syntax errors.

    if isinstance(converted, gast.Lambda):
        # A lambda has no name of its own, so bind it with an assignment.
        result_name = namer.new_symbol('tf__lambda', ())
        store_target = gast.Name(result_name, gast.Store(), None)
        converted = gast.Assign(targets=[store_target], value=converted)
    elif not do_rename:
        result_name = f.__name__
        assert converted.name == result_name
    else:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        result_name = namer.function_name(f.__name__)
        converted.name = result_name

    return [converted], result_name, ns
示例#6
0
    def test_compiled_function_name_tracks_names(self):
        """Generated names and renamed live functions are both recorded."""
        def bar():
            pass

        fn_namer = naming.Namer({}, True, None, ())

        # No live object is supplied for 'foo'.
        foo_result = fn_namer.compiled_function_name('foo')
        self.assertEqual(('tf__foo', True), foo_result)

        # 'bar' comes with its function object, which shows up in renamed_calls.
        bar_result = fn_namer.compiled_function_name('bar', bar)
        self.assertEqual(('tf__bar', True), bar_result)

        self.assertEqual({bar: 'tf__bar'}, fn_namer.renamed_calls)
        self.assertItemsEqual(('tf__bar', 'tf__foo'),
                              fn_namer.generated_names)
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
  """Looks up, or builds and caches, the converted factory info for `entity`.

  Args:
    entity: The object to convert.
    program_ctx: Program-level context; its `options` are part of the cache
      subkey and it is forwarded to `convert_entity_to_ast`.
    free_nonglobal_var_names: Names of the free non-global variables; they
      are part of the cache subkey and are passed to the factory wrapper.

  Returns:
    A `_ConvertedEntityFactoryInfo`, either retrieved from `_CACHE` or newly
    created and stored there.
  """
  # Entities that define a code object are keyed by it, everything else by
  # the entity itself. Keying by code object lets dynamically created
  # functions (e.g. defined in a loop) share one cache entry.
  key = entity.__code__ if hasattr(entity, '__code__') else entity

  # The subkey covers every conversion option the generated code may depend
  # on. The cached factory also contains the definitions that distinguish
  # global from non-global free variables, so the names of the free
  # non-globals are part of the subkey as well.
  subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))

  with _CACHE_LOCK:
    # Cache values are _ConvertedEntityFactoryInfo objects.
    if _CACHE.has(key, subkey):
      # TODO(mdan): Check whether the module is still loaded.
      cached_info = _CACHE[key][subkey]
      logging.log(3, 'Cache hit for entity %s key %s subkey %s: %s', entity,
                  key, subkey, cached_info)
      return cached_info

    logging.log(1, 'Entity %s is not cached for key %s subkey %s', entity, key,
                subkey)

    ast_nodes, converted_name, entity_info = convert_entity_to_ast(
        entity, program_ctx)

    # Generate wrapper names that do not clash with the entity's namespace.
    namer = naming.Namer(entity_info.namespace)
    outer_factory_name = namer.new_symbol('create_converted_entity_factory',
                                          ())
    inner_factory_name = namer.new_symbol('create_converted_entity', ())
    ast_nodes = _wrap_into_dynamic_factory(ast_nodes, converted_name,
                                           outer_factory_name,
                                           inner_factory_name,
                                           free_nonglobal_var_names,
                                           entity_info.future_features)

    module, _, source_map = compiler.ast_to_object(
        ast_nodes, include_source_map=True)

    fresh_info = _ConvertedEntityFactoryInfo(
        module_name=module.__name__,
        converted_name=converted_name,
        factory_factory_name=outer_factory_name,
        source_map=source_map)
    _CACHE[key][subkey] = fresh_info
    return fresh_info
示例#8
0
def _parse_and_analyze(f):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze.

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`.
      ctx: The AutoGraph `EntityContext` used for the transformations. It
        wraps the `EntityInfo` and `ProgramContext` built here.
    """
    # Get the AST of the function.
    features = inspect_utils.getfutureimports(f)
    tree, _ = parser.parse_entity(f, future_features=features)

    # Boilerplate for AutoGraph transforms; the namespace starts out empty
    # and is shared by the entity info and the namer.
    empty_ns = {}
    info = transformer.EntityInfo(source_code='',
                                  source_file=None,
                                  future_features=features,
                                  namespace=empty_ns)
    options = converter.ConversionOptions(recursive=True)
    program_ctx = converter.ProgramContext(options=options,
                                           autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(empty_ns),
                                  entity_info=info,
                                  program_ctx=program_ctx)

    # Canonicalize away break statements.
    tree = break_statements.transform(
        converter.standard_analysis(tree, ctx, is_initial=True), ctx)

    # Canonicalize away continue statements.
    tree = continue_statements.transform(
        converter.standard_analysis(tree, ctx, is_initial=False), ctx)

    # Force single returns.
    tree = return_statements.transform(
        converter.standard_analysis(tree, ctx, is_initial=False),
        ctx,
        default_to_null_return=False)

    # Transform into ANF, then refresh the analysis one last time.
    tree = anf.transform(tree, ctx)
    tree = converter.standard_analysis(tree, ctx, is_initial=False)

    return tree, ctx
示例#9
0
    def prepare(self, test_fn, namespace, recursive=True):
        """Parses `test_fn` and runs the initial standard analysis on it.

        Args:
          test_fn: The function under test.
          namespace: Dict used as the entity namespace. It is mutated:
            `ConversionOptions` is added to it.
          recursive: Forwarded to `converter.ConversionOptions`.

        Returns:
          A tuple `(node, ctx)` with the analyzed AST and the `EntityContext`
          that was used.
        """
        namespace['ConversionOptions'] = converter.ConversionOptions

        features = ('print_function', 'division')
        tree, src = parser.parse_entity(test_fn, future_features=features)

        namer = naming.Namer(namespace)
        options = converter.ConversionOptions(recursive=recursive)
        program_ctx = converter.ProgramContext(options=options,
                                               autograph_module=None)
        info = transformer.EntityInfo(source_code=src,
                                      source_file='<fragment>',
                                      future_features=features,
                                      namespace=namespace)
        ctx = converter.EntityContext(namer, info, program_ctx)

        return converter.standard_analysis(tree, ctx, is_initial=True), ctx
    def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
        """Parses `test_fn`, resolves origin info and runs initial analysis.

        Args:
          test_fn: The function under test.
          namespace: Dict used as the entity namespace. It is mutated:
            `ConversionOptions` is added to it.
          arg_types: Stored on the `EntityInfo` as `arg_types`.
          recursive: Forwarded to `converter.ConversionOptions`.

        Returns:
          A tuple `(node, ctx)` with the analyzed AST and the `EntityContext`
          that was used.
        """
        namespace['ConversionOptions'] = converter.ConversionOptions

        tree, src, _ = parser.parse_entity(test_fn)

        namer = naming.Namer(namespace)
        options = converter.ConversionOptions(recursive=recursive)
        program_ctx = converter.ProgramContext(options=options,
                                               autograph_module=None)
        info = transformer.EntityInfo(source_code=src,
                                      source_file='<fragment>',
                                      namespace=namespace,
                                      arg_values=None,
                                      arg_types=arg_types)
        ctx = converter.EntityContext(namer, info, program_ctx)

        origin_info.resolve(tree, src, test_fn)
        return converter.standard_analysis(tree, ctx, is_initial=True), ctx
示例#11
0
 def test_function_name_unsanitized_fqn(self):
     """Dotted and tuple-form names are turned into underscore-joined names."""
     fn_namer = naming.Namer({})

     dotted = fn_namer.function_name('foo.bar')
     self.assertEqual('tf__foo_bar', dotted)

     nested = fn_namer.function_name(('foo.bar', 'baz'))
     self.assertEqual('tf__foo_bar_baz', nested)
示例#12
0
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze.
      autobatch_functions: List of Python `str` names of autobatched
        functions. Arguments to these functions will be canonicalized to
        variable references, but others will not.

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`.
      ctx: The AutoGraph `EntityContext` used for the transformations. It
        wraps the `EntityInfo` and `ProgramContext` built here.
    """
    # Get the AST of the function.
    features = inspect_utils.getfutureimports(f)
    tree, _ = parser.parse_entity(f, future_features=features)

    # Boilerplate for AutoGraph transforms; the namespace starts out empty
    # and is shared by the entity info and the namer.
    empty_ns = {}
    info = transformer.EntityInfo(source_code='',
                                  source_file=None,
                                  future_features=features,
                                  namespace=empty_ns)
    options = converter.ConversionOptions(recursive=True)
    program_ctx = converter.ProgramContext(options=options,
                                           autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(empty_ns),
                                  entity_info=info,
                                  program_ctx=program_ctx)

    # Canonicalize away break statements.
    tree = break_statements.transform(
        converter.standard_analysis(tree, ctx, is_initial=True), ctx)

    # Canonicalize away continue statements.
    tree = continue_statements.transform(
        converter.standard_analysis(tree, ctx, is_initial=False), ctx)

    # Force single returns.
    tree = return_statements.transform(
        converter.standard_analysis(tree, ctx, is_initial=False),
        ctx,
        default_to_null_return=False)

    # Transform into ANF.
    # If tests and autobatched function call arguments are replaced because
    # that is where divergence can happen. All function calls are replaced
    # because the downstream transformation expects calls to lead directly
    # to assignments.
    def _is_autobatched_call_argument(parent, field_name, child):
        """Returns True when `parent` calls one of `autobatch_functions`."""
        del field_name, child  # Only the call node itself is inspected.
        callee = parent.func
        if not anno.hasanno(callee, anno.Basic.QN):
            return False
        qualified_name = anno.getanno(callee, anno.Basic.QN)
        return str(qualified_name) in autobatch_functions

    if_test_edge = anf.ASTEdgePattern(gast.If, 'test', anf.ANY)
    any_call_edge = anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call)
    call_arg_edge = anf.ASTEdgePattern(gast.Call, 'args', anf.ANY)
    anf_config = [
        (if_test_edge, anf.REPLACE),
        (any_call_edge, anf.REPLACE),
        (call_arg_edge, _is_autobatched_call_argument),
    ]
    tree = anf.transform(tree, ctx, config=anf_config)
    tree = converter.standard_analysis(tree, ctx, is_initial=False)

    return tree, ctx
示例#13
0
 def test_function_name_consistent(self):
     """Requesting the same function name twice gives the same answer."""
     fn_namer = naming.Namer({})
     # Ask twice; both calls must agree.
     for _ in range(2):
         self.assertEqual('tf__foo', fn_namer.function_name('foo'))
示例#14
0
 def test_new_symbol_avoids_duplicates(self):
     """A second request for the same base name gets a numeric suffix."""
     symbol_namer = naming.Namer({}, True, None, ())

     first = symbol_namer.new_symbol('temp', set())
     self.assertEqual('temp', first)

     second = symbol_namer.new_symbol('temp', set())
     self.assertEqual('temp_1', second)

     self.assertItemsEqual(('temp', 'temp_1'),
                           symbol_namer.generated_names)
示例#15
0
 def test_compiled_class_name_unsanitized_fqn(self):
     """A dotted component in the qualified name is folded into CamelCase."""
     class_namer = naming.Namer({}, True, None, ())
     qualified_name = ('foo.bar', 'Baz')
     compiled = class_namer.compiled_class_name(qualified_name)
     self.assertEqual('TfFooBarBaz', compiled)
示例#16
0
 def new_namer(self, namespace):
     """Creates a `naming.Namer` for `namespace` using this object's settings.

     Args:
       namespace: Passed through as the namer's first argument.

     Returns:
       A new `naming.Namer` constructed from `namespace`,
       `self.options.recursive`, `self.name_map` and `self.partial_types`.
     """
     make_namer = naming.Namer
     return make_namer(namespace,
                       self.options.recursive,
                       self.name_map,
                       self.partial_types)
示例#17
0
 def test_compiled_class_name_basic(self):
     """A two-part qualified name becomes a `Tf`-prefixed CamelCase name."""
     class_namer = naming.Namer({}, True, None, ())
     compiled = class_namer.compiled_class_name(('foo', 'Bar'))
     self.assertEqual('TfFooBar', compiled)
示例#18
0
 def test_compiled_function_name_unsanitized_fqn(self):
     """Dotted and tuple-form names are turned into underscore-joined names."""
     fn_namer = naming.Namer({}, True, None, ())

     dotted = fn_namer.compiled_function_name('foo.bar')
     self.assertEqual(('tf__foo_bar', True), dotted)

     nested = fn_namer.compiled_function_name(('foo.bar', 'baz'))
     self.assertEqual(('tf__foo_bar_baz', True), nested)
示例#19
0
 def test_class_name_unsanitized_fqn(self):
     """A dotted component in the qualified name is folded into CamelCase."""
     class_namer = naming.Namer({})
     generated = class_namer.class_name(('foo.bar', 'Baz'))
     self.assertEqual('TfFooBarBaz', generated)
示例#20
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method that `c` defines directly and assembles the results
    into a new class definition whose bases are imported under fresh aliases.

    Args:
      c: The class to convert.
      program_ctx: Program-level context, forwarded to `function_to_graph`.

    Returns:
      A tuple `(output_nodes, class_name, class_namespace)`: the import nodes
      for the base classes followed by the class definition, the generated
      class name, and the merged namespace of all converted methods.

    Raises:
      ValueError: If `c` has no member functions or methods.
      NotImplementedError: If a base class other than `object` is not
        whitelisted for graph conversion.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m
                                                                              )
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    # The namespaces of all converted methods are merged into this dict.
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        class_namespace.update(namespace)
        converted_members[m] = nodes[0]
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # `object` needs no import. This must be an identity check:
        # `isinstance(object, base)` is also true for `type` and for ABCs
        # that `object` happens to satisfy, which would silently replace
        # those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
示例#21
0
 def test_class_name_basic(self):
     """A two-part qualified name becomes a `Tf`-prefixed CamelCase name."""
     class_namer = naming.Namer({})
     generated = class_namer.class_name(('foo', 'Bar'))
     self.assertEqual('TfFooBar', generated)
示例#22
0
 def test_new_symbol_tracks_names(self):
     """Every symbol handed out is recorded in `generated_names`."""
     symbol_namer = naming.Namer({}, True, None, ())
     created = symbol_namer.new_symbol('temp', set())
     self.assertEqual('temp', created)
     self.assertItemsEqual(('temp', ), symbol_namer.generated_names)
示例#23
0
 def test_function_name_avoids_global_conflicts(self):
     """A function name that collides with a global gets a numeric suffix."""
     # 'tf__foo' is already present in the namespace given to the namer.
     taken_names = {'tf__foo': 1}
     fn_namer = naming.Namer(taken_names)
     generated = fn_namer.function_name('foo')
     self.assertEqual('tf__foo_1', generated)
示例#24
0
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes.

    Converts every method that `c` defines directly and assembles the results
    into a new class definition whose bases are imported under fresh aliases.

    Args:
      c: The class to convert.
      program_ctx: Program-level context, forwarded to `convert_func_to_ast`.

    Returns:
      A tuple `(output_nodes, class_name, entity_info)`: the import nodes for
      the base classes followed by the class definition, the generated class
      name, and an `EntityInfo` carrying the merged method namespace and the
      methods' common future features.

    Raises:
      ValueError: If `c` has no member functions or methods, or if its
        methods were built with differing future features.
      NotImplementedError: If a base class other than `object` is not
        whitelisted for graph conversion.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m
                                                                              )
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
    # The assumption that one namespace suffices for all methods only holds if
    # all methods were defined in the same module.
    # If, instead, functions are imported from multiple modules and then spliced
    # into the class, then each function has its own globals and __future__
    # imports that need to stay separate.

    # For example, C's methods could both have `global x` statements referring to
    # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
    # from mod1 import f1
    # from mod2 import f2
    # class C(object):
    #   method1 = f1
    #   method2 = f2

    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        (node, ), _, entity_info = convert_func_to_ast(m,
                                                       program_ctx=program_ctx,
                                                       do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        if future_features is None:
            # The first converted method sets the reference feature set.
            future_features = entity_info.future_features
        elif frozenset(future_features) ^ frozenset(
                entity_info.future_features):
            # A non-empty symmetric difference means the sets differ.
            # Note: we can support this case if ever needed.
            raise ValueError(
                'cannot convert {}: if has methods built with mismatched future'
                ' features: {} and {}'.format(c, future_features,
                                              entity_info.future_features))
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # `object` needs no import. This must be an identity check:
        # `isinstance(object, base)` is also true for `type` and for ABCs
        # that `object` happens to satisfy, which would silently replace
        # those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(source_code=None,
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=class_namespace)

    return output_nodes, class_name, entity_info
示例#25
0
 def test_function_name_tracks_names(self):
     """Every generated function name is recorded in `generated_names`."""
     fn_namer = naming.Namer({})

     foo_name = fn_namer.function_name('foo')
     self.assertEqual('tf__foo', foo_name)

     bar_name = fn_namer.function_name('bar')
     self.assertEqual('tf__bar', bar_name)

     self.assertItemsEqual(('tf__bar', 'tf__foo'), fn_namer.generated_names)