Example #1
0
    def test_replace_code_block(self):
        """Replaces the `block` placeholder with a list of statements.

        The replacement list contains the same `a = a + 1` statement twice, so
        the loaded function adds 2 to its argument.
        """
        template = """
      def test_fn(a):
        block
        return a
    """

        class ShouldBeReplaced(object):
            pass

        def name_a():
            # Fresh `a` Name whose ctx is a sentinel class; presumably the
            # template machinery overwrites it with a real context.
            return gast.Name('a',
                             ctx=ShouldBeReplaced,
                             annotation=None,
                             type_comment=None)

        increment = gast.BinOp(name_a(), gast.Add(),
                               gast.Constant(1, kind=None))
        statement = gast.Assign([name_a()], increment)
        node = templates.replace(template, block=[statement] * 2)[0]

        loaded, _, _ = loader.load_ast(node)
        self.assertEqual(3, loaded.test_fn(1))
Example #2
0
    def test_load_ast(self):
        """Loads a hand-built `def f(a): return a + 1` AST and checks it.

        Checks the generated source text, the loaded function's behavior, and
        the contents of the file the module was loaded from.
        """
        # NOTE(review): the 3-positional-argument gast.Name and gast.Num forms
        # match older gast releases; newer ones also expect annotation /
        # type_comment fields. Confirm against the pinned gast version.
        returned = gast.BinOp(op=gast.Add(),
                              left=gast.Name('a', gast.Load(), None),
                              right=gast.Num(1))
        signature = gast.arguments(args=[gast.Name('a', gast.Param(), None)],
                                   vararg=None,
                                   kwonlyargs=[],
                                   kwarg=None,
                                   defaults=[],
                                   kw_defaults=[])
        node = gast.FunctionDef(name='f',
                                args=signature,
                                body=[gast.Return(returned)],
                                decorator_list=[],
                                returns=None)

        module, source, _ = loader.load_ast(node)

        expected_source = """
      # coding=utf-8
      def f(a):
        return a + 1
    """
        expected = textwrap.dedent(expected_source).strip()
        self.assertEqual(expected, source.strip())
        self.assertEqual(2, module.f(1))
        with open(module.__file__, 'r') as temp_output:
            self.assertEqual(expected, temp_output.read().strip())
Example #3
0
    def module(self):
        """Builds (once) and returns the batched `instructions.Module`.

        The result is memoized in `self._module`; later calls return the
        stored value without rebuilding.

        Returns:
          module: An `instructions.Module` representing the batched
            computation defined by all the functions decorated with `batch`
            in this `Context` so far.
        """
        if self._module is None:
            program = dsl.ProgramBuilder()

            # Every tagged function is declared before any body is built, and
            # the full list of declarations is handed to each builder.
            declarations = []
            for func, inference in self._tagged_functions:
                declarations.append(
                    program.declare_function(func.__name__, inference))

            for func, _ in self._tagged_functions:
                fn_name = func.__name__
                tree, analysis_ctx = _parse_and_analyze(
                    func, self.function_names())
                batched_names = self.function_names()
                scoped = [
                    scoped_name
                    for scoped_name, _ in _environment(func, [fn_name])
                ]
                tree = _AutoBatchingTransformer(
                    batched_names, scoped, analysis_ctx).visit(tree)
                generated, _, _ = loader.load_ast(tree)
                # Make the function's environment visible to the generated
                # builder code.
                for scoped_name, value in _environment(func, [fn_name]):
                    generated.__dict__[scoped_name] = value
                getattr(generated, fn_name)(program, declarations)

            self._module = program.module()
        return self._module
Example #4
0
  def test_basic(self):
    """ANF-transforming a trivial function must not change its result."""
    def test_function():
      a = 0
      return a

    node, _ = parser.parse_entity(test_function, future_features=())
    transformed = anf.transform(node, self._simple_context())
    loaded, _, _ = loader.load_ast(transformed)
    expected = test_function()
    self.assertEqual(expected, loaded.test_function())
Example #5
0
 def test_keywords_to_dict(self):
     """keywords_to_dict produces a dict node that compiles and evaluates."""
     call = parser.parse_expression('f(a=b, c=1, d=\'e\')')
     dict_node = ast_util.keywords_to_dict(call.keywords)
     # Return the generated dict from a function and load the result; this
     # proves the node is usable in real compiled code.
     func_node = parser.parse('def f(b): pass')
     func_node.body.append(ast.Return(dict_node))
     loaded, _, _ = loader.load_ast(func_node)
     self.assertDictEqual(loaded.f(3), {'a': 3, 'c': 1, 'd': 'e'})
Example #6
0
    def test_replace_name_with_dict(self):
        """A placeholder name can be replaced by a parsed dict literal."""
        template = """
      def test_fn():
        return foo['bar']
    """

        replacement = parser.parse_expression('{\'bar\': 3}')
        nodes = templates.replace(template, foo=replacement)
        loaded, _, _ = loader.load_ast(nodes[0])
        self.assertEqual(3, loaded.test_fn())
Example #7
0
 def _get_source(self, node):
     """Returns the source code generated for `node`, for error reporting.

     Args:
       node: The AST node to render.

     Returns:
       The source text produced by `loader.load_ast`, or the placeholder
       '<could not convert AST to source>' if rendering fails.
     """
     # This function is used for error reporting. If an exception occurs here,
     # it is suppressed in favor of emitting as informative a message about
     # the original error as possible.
     try:
         # load_ast returns (module, source, source_map). Unpacking into two
         # names always raised and fell through to the placeholder.
         _, source, _ = loader.load_ast(node)
     except Exception:  # pylint: disable=broad-except
         return '<could not convert AST to source>'
     return source
Example #8
0
    def test_replace_tuple(self):
        """A tuple of names replaces one placeholder, yielding `(a, c)`."""
        template = """
      def test_fn(a, c):
        return b,
    """

        nodes = templates.replace(template, b=('a', 'c'))
        loaded, _, _ = loader.load_ast(nodes[0])

        self.assertEqual((2, 3), loaded.test_fn(2, 3))
Example #9
0
    def test_replace_variable(self):
        """Renaming `a` to `b` rewrites every occurrence of `a`."""
        template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

        renamed = templates.replace(template, a='b')[0]
        loaded, _, _ = loader.load_ast(renamed)
        # After renaming: b = 2 + 1 = 3, then b = 2 * 3 + 1 = 7.
        self.assertEqual(7, loaded.test_fn(2))
Example #10
0
    def test_replace_function_name(self):
        """The function's own name can be a template placeholder."""
        template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

        renamed = templates.replace(template, fname='test_fn')[0]
        loaded, _, _ = loader.load_ast(renamed)
        # a = 2 + 1 = 3, then a = 2 * 3 + 1 = 7.
        self.assertEqual(7, loaded.test_fn(2))
Example #11
0
    def test_parse_load_identity(self):
        """Parsing then loading a function reproduces its source text."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        node, _ = parser.parse_entity(test_fn, future_features=())
        module, _, _ = loader.load_ast(node)

        original_src = textwrap.dedent(tf_inspect.getsource(test_fn))
        loaded_src = tf_inspect.getsource(module.test_fn)
        self.assertEqual(original_src, loaded_src)
Example #12
0
    def test_replace_attribute(self):
        """An attribute name can be a placeholder; non-name values fail."""
        # `imp` (and imp.new_module) was removed in Python 3.12;
        # types.ModuleType is the supported way to build a bare module.
        import types

        template = """
      def test_fn(a):
        return a.foo
    """

        node = templates.replace(template, foo='b')[0]
        result, _, _ = loader.load_ast(node)
        mod = types.ModuleType('test')
        mod.b = 3
        self.assertEqual(3, result.test_fn(mod))

        # Replacing the attribute with a non-name value (an int) must fail.
        with self.assertRaises(ValueError):
            templates.replace(template, foo=1)
Example #13
0
    def compiled(self, node, namespace, symbols=()):
        """Loads `node` as a module wired to fake `tf` / `ag__` modules.

        This is a generator that yields the loaded module once. Presumably it
        is wrapped with `contextlib.contextmanager` outside this view and used
        in a `with` statement — TODO confirm.

        Args:
          node: AST passed to `loader.load_ast` (a source map is requested).
          namespace: Mapping whose items are copied into the loaded module's
            `__dict__`.
          symbols: Extra arguments forwarded to `self.make_fake_mod` when
            building the fake `tf` module.

        Yields:
          The loaded module, with `tf`, `ag__`, `ag_source_map__` and every
          `namespace` entry attached.

        Raises:
          Any exception raised inside the `try` is re-raised. This includes an
          exception thrown into the generator at the `yield`. Before
          re-raising, it prints the offending source code, or the
          pretty-printed AST when loading failed before any source was
          produced.
        """
        # Stays None if load_ast fails; the error handler uses that to choose
        # between printing the source and printing the AST.
        source = None

        # Records (args, kwargs), exactly as passed, for every call routed
        # through the fake converted_call below.
        self.dynamic_calls = []

        # See api.converted_call
        def converted_call(f,
                           args,
                           kwargs,
                           unused_opts=None,
                           unused_function_ctx=None):
            """Mock version of api.converted_call.

            Records the call in `self.dynamic_calls`, then calls `f` directly,
            treating `kwargs=None` as no keyword arguments.
            """
            self.dynamic_calls.append((args, kwargs))
            if kwargs is None:
                kwargs = {}
            return f(*args, **kwargs)

        def fake_autograph_artifact(f):
            """Stand-in decorator: sets `f.fake_autograph_artifact` and returns f."""
            setattr(f, 'fake_autograph_artifact', True)
            return f

        try:
            result, source, source_map = loader.load_ast(
                node, include_source_map=True)
            # TODO(mdan): Move the unparsing from converter into pyct and reuse here.

            # TODO(mdan): Move this into self.prepare()
            # Attach fake `tf` and `ag__` modules. Presumably the generated code
            # looks these names up as module globals.
            result.tf = self.make_fake_mod('fake_tf', *symbols)
            fake_ag = self.make_fake_mod('fake_ag', converted_call,
                                         converter.ConversionOptions)
            fake_ag.__dict__.update(operators.__dict__)
            fake_ag.__dict__.update(special_functions.__dict__)
            fake_ag.ConversionOptions = converter.ConversionOptions
            fake_ag.Feature = converter.Feature
            fake_ag.utils = utils
            fake_ag.FunctionScope = function_wrappers.FunctionScope
            fake_ag.autograph_artifact = fake_autograph_artifact
            result.ag__ = fake_ag
            result.ag_source_map__ = source_map
            # Caller-supplied globals override anything set above.
            for k, v in namespace.items():
                result.__dict__[k] = v
            yield result
        except Exception:  # pylint:disable=broad-except
            if source is None:
                print('Offending AST:\n%s' %
                      pretty_printer.fmt(node, color=False))
            else:
                print('Offending source code:\n%s' % source)
            raise
Example #14
0
    def test_replace_name_with_call(self):
        """A placeholder name can be replaced by a call expression."""
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        call_expr = parser.parse_expression('f()(b)')
        nodes = templates.replace(template, foo=call_expr)
        loaded, _, _ = loader.load_ast(nodes[0])
        # f() returns g, and g(5) == 15.
        self.assertEqual(15, loaded.test_fn())
Example #15
0
    def test_parse_load_identity(self):
        """The parse/load round trip yields a structurally equal function."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = (x + 1)
            return b

        node, _ = parser.parse_entity(test_fn, future_features=())
        module, _, _ = loader.load_ast(node)

        # Check the parsed AST against both the regenerated source and the
        # original source text.
        candidates = (tf_inspect.getsource(module.test_fn),
                      textwrap.dedent(tf_inspect.getsource(test_fn)))
        for candidate in candidates:
            self.assertAstMatches(node, candidate)
Example #16
0
    def test_replace_call_keyword(self):
        """Keyword placeholders accept keyword nodes and reject other values."""
        template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

        source = parser.parse_expression('f(d=3, f=5)')
        node = templates.replace(template, kws=source.keywords)[0]
        result, _, _ = loader.load_ast(node)
        self.assertEqual(9, result.test_fn())

        # Each invalid replacement needs its own assertRaises block. In a
        # shared block, the first raising call exits it and the second call
        # is never run.
        with self.assertRaises(ValueError):
            templates.replace(template, kws=[])
        with self.assertRaises(ValueError):
            templates.replace(template, kws=1)
  def test_parse_load_identity(self):
    """Round-tripping through parse and load reproduces the source text."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = (x + 1)
      return b

    node, _ = parser.parse_entity(test_fn, future_features=())
    module, _, _ = loader.load_ast(node)

    expected = textwrap.dedent(tf_inspect.getsource(test_fn))
    # astunparse uses fixed 4-space indenting while this file uses 2, so
    # normalize the generated text before comparing.
    actual = tf_inspect.getsource(module.test_fn).replace('    ', '  ')
    self.assertEqual(expected, actual)
  def test_load_ast(self):
    """Loads a hand-built `def f(a): return a + 1` AST and checks it."""

    def name_a(ctx):
      return gast.Name('a', ctx=ctx, annotation=None, type_comment=None)

    signature = gast.arguments(
        args=[name_a(gast.Param())],
        posonlyargs=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[])
    body = [
        gast.Return(
            gast.BinOp(
                op=gast.Add(),
                left=name_a(gast.Load()),
                right=gast.Constant(1, kind=None)))
    ]
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=body,
        decorator_list=[],
        returns=None,
        type_comment=None)

    module, source, _ = loader.load_ast(node)

    expected_node_src = textwrap.dedent("""
      # coding=utf-8
      def f(a):
          return (a + 1)
    """)

    # Compare structurally against the generated source and the expected
    # text, rather than comparing raw strings.
    for text in (source, expected_node_src):
      self.assertAstMatches(node, text)

    self.assertEqual(2, module.f(1))
    with open(module.__file__, 'r') as temp_output:
      self.assertAstMatches(node, temp_output.read())
Example #19
0
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
    """Returns a (possibly cached) factory for the converted result of entity.

    Args:
      entity: The object being converted; the primary cache key.
      program_ctx: Conversion context. Its `options` are part of the cache
        subkey, since the generated code may depend on any of them.
      free_nonglobal_var_names: Names of the entity's free non-global
        variables. They are part of the subkey because the cached factory
        includes the definitions that separate global from non-global free
        variables.

    Returns:
      The `_ConvertedEntityFactoryInfo` stored under `_CACHE[entity][subkey]`.
    """
    cache_subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))

    with _CACHE_LOCK:
        # Cache values are _ConvertedEntityFactoryInfo objects.
        if _CACHE.has(entity, cache_subkey):
            # TODO(mdan): Check whether the module is still loaded.
            info = _CACHE[entity][cache_subkey]
            logging.log(3, 'Cache hit for entity %s subkey %s: %s', entity,
                        cache_subkey, info)
            return info

        logging.log(1, 'Entity %s is not cached for subkey %s', entity,
                    cache_subkey)

        converted_nodes, converted_name, entity_info = convert_entity_to_ast(
            entity, program_ctx)

        # Pick unique names, in this order, for the outer and inner factories.
        symbol_namer = naming.Namer(entity_info.namespace)
        factory_factory_name, factory_name = (
            symbol_namer.new_symbol(base_name, ())
            for base_name in ('create_converted_entity_factory',
                              'create_converted_entity'))

        wrapped_nodes = _wrap_into_dynamic_factory(
            converted_nodes, converted_name, factory_factory_name,
            factory_name, free_nonglobal_var_names,
            entity_info.future_features)

        loaded_module, _, source_map = loader.load_ast(
            wrapped_nodes, include_source_map=True)

        info = _ConvertedEntityFactoryInfo(
            module_name=loaded_module.__name__,
            converted_name=converted_name,
            factory_factory_name=factory_factory_name,
            source_map=source_map)
        _CACHE[entity][cache_subkey] = info
        return info
Example #20
0
  def create(self,
             nodes,
             namer,
             inner_factory_name='inner_factory',
             outer_factory_name='outer_factory',
             future_features=()):
    """Initializes a function by loading `nodes` wrapped in two factories.

    Args:
      nodes: AST nodes for the function, passed to `_wrap_into_factory`.
      namer: Produces unique symbols from the factory name hints.
      inner_factory_name: Base name for the inner factory symbol.
      outer_factory_name: Base name for the outer factory symbol.
      future_features: Forwarded to `_wrap_into_factory`.

    Raises:
      ValueError: If this object has already been initialized.
    """
    if self._unbound_factory is not None:
      raise ValueError('double initialization; create a new object instead')

    inner_name, outer_name = (
        namer.new_symbol(inner_factory_name, ()),
        namer.new_symbol(outer_factory_name, ()),
    )
    wrapped = _wrap_into_factory(nodes, self._name, inner_name, outer_name,
                                 self._freevars, self._extra_locals.keys(),
                                 future_features)

    loaded, _, loaded_source_map = loader.load_ast(
        wrapped, include_source_map=True)
    # Calling the outer factory yields the unbound inner factory. State is
    # only recorded after that call succeeds.
    self._unbound_factory = getattr(loaded, outer_name)()
    self.module = loaded
    self.source_map = loaded_source_map
Example #21
0
  def test_to_ast(self):
    """ConversionOptions survive a round trip through their AST form."""
    opts = converter.ConversionOptions()

    template = '''
    def test_fn():
      return opts_ast
    '''
    packed = templates.replace(template, opts_ast=opts.to_ast())

    reparsed, _, _ = loader.load_ast(packed)
    # Presumably the generated constructor call refers to names under ag__;
    # supply a fake module exposing them.
    reparsed.__dict__['ag__'] = self.make_fake_mod(
        'fake_ag', converter.ConversionOptions, converter.Feature)

    roundtripped = reparsed.test_fn()

    self.assertEqual(opts.recursive, roundtripped.recursive)
    # NOTE(review): this checks the source opts against False rather than the
    # round-tripped value; confirm whether roundtripped.user_requested was
    # intended.
    self.assertEqual(opts.user_requested, False)
    self.assertEqual(opts.internal_convert_user_code,
                     roundtripped.internal_convert_user_code)
    self.assertEqual(opts.optional_features, roundtripped.optional_features)
Example #22
0
    def test_to_ast(self):
        """ConversionOptions survive a round trip through their AST form."""
        # `imp` (and imp.new_module) was removed in Python 3.12;
        # types.ModuleType is the supported replacement.
        import types

        opts = converter.ConversionOptions()
        opts_ast = opts.to_ast()

        template = '''
    def f():
      return opts_ast
    '''
        opts_packed = templates.replace(template, opts_ast=opts_ast)

        reparsed, _, _ = loader.load_ast(opts_packed)
        # Fake `ag__` module exposing the names the generated code uses.
        fake_ag = types.ModuleType('fake_ag')
        fake_ag.ConversionOptions = converter.ConversionOptions
        fake_ag.Feature = converter.Feature
        reparsed.ag__ = fake_ag

        reparsed_opts = reparsed.f()

        self.assertEqual(opts.recursive, reparsed_opts.recursive)
        # NOTE(review): compares the source opts against False, not the
        # round-tripped value; confirm intent.
        self.assertEqual(opts.user_requested, False)
        self.assertEqual(opts.internal_convert_user_code,
                         reparsed_opts.internal_convert_user_code)
        self.assertEqual(opts.optional_features,
                         reparsed_opts.optional_features)
Example #23
0
 def debug_print_src(self, node):
     """Helper method useful for debugging. Prints the AST as code.

     Printing only happens when `__debug__` is true, i.e. Python is not
     running with -O.

     Args:
       node: The AST node to print.

     Returns:
       `node`, unchanged, so the call can be inserted inline.
     """
     if __debug__:
         # load_ast returns (module, source, source_map). Print only the
         # source, not the whole tuple.
         _, source, _ = loader.load_ast(node)
         print(source)
     return node