Beispiel #1
0
    def test_simple_decorator(self):
        """Checks that decorators.transform keeps or strips a plain decorator.

        With an empty `remove_decorators` the compiled `test_fn` is still
        wrapped (1 -> 2); once `simple_decorator` is listed for removal the
        compiled `test_fn` returns its argument unchanged (1 -> 1).
        """
        def simple_decorator(f):
            return lambda a: f(a) + 1

        # The Python parser does capture decorators into the AST.
        # However, the interpreter desugars them upon load, and referring to the
        # decorated function at runtime usually loses any trace of the decorator.
        # Below is an example when that doesn't happen.
        def static_wrapper():
            @simple_decorator
            def test_fn(a):  # pylint:disable=unused-variable
                return a

        node = self.parse_and_analyze(static_wrapper,
                                      {'simple_decorator': simple_decorator})
        # Presumably selects the decorated `test_fn` definition nested inside
        # `static_wrapper` (the compiled result exposes `test_fn` below).
        node = node.body[0].body[0]

        node = decorators.transform(node, remove_decorators=())
        # The decorator is kept here, so its source is passed as a prefix for
        # the generated code.
        result = compiler.ast_to_object(
            node,
            source_prefix=textwrap.dedent(
                tf_inspect.getsource(simple_decorator)))
        self.assertEqual(2, result.test_fn(1))

        node = decorators.transform(node,
                                    remove_decorators=(simple_decorator, ))
        result = compiler.ast_to_object(node)
        self.assertEqual(1, result.test_fn(1))
  def test_simple_decorator(self):
    """Checks that decorators.transform keeps or strips a plain decorator.

    With an empty `remove_decorators` the compiled `test_fn` is still wrapped
    (1 -> 2); once `simple_decorator` is listed for removal the compiled
    `test_fn` returns its argument unchanged (1 -> 1).
    """

    def simple_decorator(f):
      return lambda a: f(a) + 1

    # The Python parser does capture decorators into the AST.
    # However, the interpreter desugars them upon load, and referring to the
    # decorated function at runtime usually loses any trace of the decorator.
    # Below is an example when that doesn't happen.
    def static_wrapper():

      @simple_decorator
      def test_fn(a):  # pylint:disable=unused-variable
        return a

    node = self.parse_and_analyze(static_wrapper,
                                  {'simple_decorator': simple_decorator})
    # Presumably selects the decorated `test_fn` definition nested inside
    # `static_wrapper` (the compiled result exposes `test_fn` below).
    node = node.body[0].body[0]

    node = decorators.transform(node, remove_decorators=())
    # The decorator is kept here, so its source is passed as a prefix for the
    # generated code.
    result = compiler.ast_to_object(
        node,
        source_prefix=textwrap.dedent(tf_inspect.getsource(simple_decorator)))
    self.assertEqual(2, result.test_fn(1))

    node = decorators.transform(node, remove_decorators=(simple_decorator,))
    result = compiler.ast_to_object(node)
    self.assertEqual(1, result.test_fn(1))
Beispiel #3
0
def to_graph(f, arg_value_hints=None):
    """Convert a Python function into a function that builds a TF graph.

    The function is converted through `conversion.object_to_graph`; every AST
    in the resulting dependency cache is assembled, after the configured
    import statements, into one module that is then compiled and loaded.

    Args:
      f: A Python function with arbitrary arguments and return values.
      arg_value_hints: A dict mapping parameter names to objects that can
          hint at the type of those parameters.

    Returns:
      A function with a signature identical to `f`, but which when executed
      creates a TF graph that has the same functionality as the original
      function.
    """
    conv_map = conversion.ConversionMap()
    _, compiled_name = conversion.object_to_graph(f, conv_map,
                                                  arg_value_hints)

    module = gast.Module([])
    module.body.extend(
        parser.parse_str(stmt) for stmt in config.COMPILED_IMPORT_STATEMENTS)
    module.body.extend(conv_map.dependency_cache.values())
    generated = compiler.ast_to_object(module)

    # Give the generated module access to every global the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    generated.__dict__.update(six.get_function_globals(f))

    return getattr(generated, compiled_name)
Beispiel #4
0
def to_graph(o, arg_value_hints=None):
  """Convert a Python function or class into its graph-building counterpart.

  Currently supported entities:
    * functions
    * classes (handled by converting all their methods into a new class)

  Args:
    o: A Python function or class.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    A function with a signature identical to `o`, but which when executed
    creates a TF graph that has the same functionality as the original entity.
  """
  conv_map = conversion.ConversionMap()
  _, compiled_name = conversion.object_to_graph(o, conv_map, arg_value_hints)

  module = gast.Module([])
  module.body.extend(
      parser.parse_str(stmt) for stmt in config.COMPILED_IMPORT_STATEMENTS)
  module.body.extend(conv_map.dependency_cache.values())
  generated = compiler.ast_to_object(module)

  # Give the generated module access to every global the entry function saw.
  # Only functions are handled here; other entities get no extra globals.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(o):
    generated.__dict__.update(six.get_function_globals(o))

  return getattr(generated, compiled_name)
    def test_continue_deeply_nested(self):
        """Checks a `continue` nested two `if` levels deep inside a `while`.

        The function converted by continue_canonicalization.transform must
        return the same three lists as the original for inputs 0 through 4.
        """
        def test_fn(x):
            v = []
            u = []
            w = []
            while x > 0:
                x -= 1
                if x % 2 == 0:
                    if x % 3 != 0:
                        u.append(x)
                    else:
                        w.append(x)
                        continue
                v.append(x)
            return v, u, w

        analyzed = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
        converted = continue_canonicalization.transform(analyzed, self.ctx)
        compiled = compiler.ast_to_object(converted)

        for start in range(5):
            self.assertEqual(test_fn(start), compiled.test_fn(start))
Beispiel #6
0
    def test_function_decorator(self):
        """Checks that decorators.transform handles a decorator factory call.

        `test_fn` is decorated with `@function_decorator()`. With an empty
        `remove_decorators` the compiled function is still wrapped (1 -> 2);
        with `function_decorator` listed it returns its argument (1 -> 1).
        """
        def function_decorator():
            def decorator(f):
                return lambda a: f(a) + 1

            return decorator

        # The Python parser does capture decorators into the AST.
        # However, the interpreter desugars them on load, and referring to the
        # decorated function at runtime usually loses any trace of the decorator.
        # Below is an example when that doesn't happen.
        def static_wrapper():
            @function_decorator()
            def test_fn(a):  # pylint:disable=unused-variable
                return a

        node = self.parse_and_analyze(
            static_wrapper, {'function_decorator': function_decorator})
        # Presumably selects the decorated `test_fn` definition nested inside
        # `static_wrapper` (the compiled result exposes `test_fn` below).
        node = node.body[0].body[0]

        node = decorators.transform(node, remove_decorators=())
        # Since the decorator is not removed, we need to include its source
        # code. We cannot do it after the fact because decorators are executed
        # on load.
        result, _ = compiler.ast_to_object(
            node,
            source_prefix=textwrap.dedent(
                tf_inspect.getsource(function_decorator)))
        self.assertEqual(2, result.test_fn(1))

        node = decorators.transform(node,
                                    remove_decorators=(function_decorator, ))
        with self.compiled(node) as result:
            self.assertEqual(1, result.test_fn(1))
  def test_basic_break_for_loop(self):
    """Pins the current output of break canonicalization on a `for` loop.

    The converted function is compared against `test_equiv_fn`, which uses
    `continue` where the original uses `break`, because the loop does not
    actually stop yet after the transformation.
    """

    def test_fn(a):
      v = []
      for x in a:
        x -= 1
        if x % 2 == 0:
          break
        v.append(x)
      return v

    # The break is incompletely canonicalized for for loops. Everything is
    # in place except for the condition verification.
    def test_equiv_fn(a):
      v = []
      for x in a:
        x -= 1
        if x % 2 == 0:
          continue
        v.append(x)
      return v

    analyzed = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
    converted = break_canonicalization.transform(analyzed, TestNamer())
    compiled = compiler.ast_to_object(converted)

    # The break is incompletely canonicalized. Everything is in place, but
    # the loop does not break.
    for values in ((), (1,), (2,), (1, 2, 3, 4)):
      # Each call receives its own fresh list, as with separate literals.
      self.assertEqual(test_equiv_fn(list(values)),
                       compiled.test_fn(list(values)))
    def test_basic_break_for_loop(self):
        """Pins the current output of break canonicalization on a `for` loop.

        The converted function is compared against `test_equiv_fn`, which
        uses `continue` where the original uses `break`, because the loop
        does not actually stop yet after the transformation.
        """
        def test_fn(a):
            v = []
            for x in a:
                x -= 1
                if x % 2 == 0:
                    break
                v.append(x)
            return v

        # The break is incompletely canonicalized for for loops. Everything is
        # in place except for the condition verification.
        def test_equiv_fn(a):
            v = []
            for x in a:
                x -= 1
                if x % 2 == 0:
                    continue
                v.append(x)
            return v

        analyzed = self.parse_and_analyze(test_fn, {},
                                          include_type_analysis=False)
        converted = break_canonicalization.transform(analyzed, TestNamer())
        compiled = compiler.ast_to_object(converted)

        # The break is incompletely canonicalized. Everything is in place, but
        # the loop does not break.
        for values in ((), (1,), (2,), (1, 2, 3, 4)):
            # Each call receives its own fresh list, as with separate literals.
            self.assertEqual(test_equiv_fn(list(values)),
                             compiled.test_fn(list(values)))
  def test_continue_deeply_nested(self):
    """Checks a `continue` nested two `if` levels deep inside a `while`.

    The converted function must return the same three lists as the original
    for inputs 0 through 4.

    NOTE(review): this continue-focused test runs
    `break_canonicalization.transform`; confirm that is the intended
    converter and not a copy-paste leftover.
    """

    def test_fn(x):
      v = []
      u = []
      w = []
      while x > 0:
        x -= 1
        if x % 2 == 0:
          if x % 3 != 0:
            u.append(x)
          else:
            w.append(x)
            continue
        v.append(x)
      return v, u, w

    analyzed = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
    converted = break_canonicalization.transform(analyzed, TestNamer())
    compiled = compiler.ast_to_object(converted)

    for start in range(5):
      self.assertEqual(test_fn(start), compiled.test_fn(start))
  def test_uncompiled_modules(self):
    """Checks that calls into listed modules still work after call_trees.

    The names of `math_ops` and `constant_op` are handed to
    call_trees.transform (presumably as the modules to leave uncompiled);
    the converted `test_fn` must still compute 1 * 2 + 1 == 3.
    """

    def test_fn(a):
      a = math_ops.multiply(a, constant_op.constant(2))
      a = math_ops.add(a, constant_op.constant(1))
      return a

    node = self.parse_and_analyze(
        test_fn, {
            'math_ops': math_ops,
            'constant_op': constant_op
        },
        namer=TestNamer())
    node = call_trees.transform(node, self.ctx,
                                set(((math_ops.__name__,),
                                     (constant_op.__name__,))), ())
    result = compiler.ast_to_object(node)
    # Make the modules referenced by `test_fn` visible to the compiled code.
    result.math_ops = math_ops
    result.constant_op = constant_op

    with self.test_session() as sess:
      # Not renamed, because the converter doesn't rename the definition itself.
      # (the caller is responsible for that).
      result_tensor = result.test_fn(constant_op.constant(1))
      result_val = sess.run(result_tensor)

    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(3, result_val)
  def test_function_decorator(self):
    """Checks that decorators.transform handles a decorator factory call.

    `test_fn` is decorated with `@function_decorator()`. With an empty
    `remove_decorators` the compiled function is still wrapped (1 -> 2); with
    `function_decorator` listed it returns its argument (1 -> 1).
    """

    def function_decorator():

      def decorator(f):
        return lambda a: f(a) + 1

      return decorator

    # The Python parser does capture decorators into the AST.
    # However, the interpreter desugars them on load, and referring to the
    # decorated function at runtime usually loses any trace of the decorator.
    # Below is an example when that doesn't happen.
    def static_wrapper():

      @function_decorator()
      def test_fn(a):  # pylint:disable=unused-variable
        return a

    node = self.parse_and_analyze(static_wrapper,
                                  {'function_decorator': function_decorator})
    # Presumably selects the decorated `test_fn` definition nested inside
    # `static_wrapper` (the compiled result exposes `test_fn` below).
    node = node.body[0].body[0]

    node = decorators.transform(node, remove_decorators=())
    # Since the decorator is not removed, we need to include its source
    # code. We cannot do it after the fact because decorators are executed
    # on load.
    result, _ = compiler.ast_to_object(
        node,
        source_prefix=textwrap.dedent(tf_inspect.getsource(function_decorator)))
    self.assertEqual(2, result.test_fn(1))

    node = decorators.transform(node, remove_decorators=(function_decorator,))
    with self.compiled(node) as result:
      self.assertEqual(1, result.test_fn(1))
Beispiel #12
0
  def test_ast_to_object(self):
    """Checks that a hand-built `def f(a): return a + 1` AST compiles.

    The returned source, the loaded module's behavior and the contents of
    the module's backing file are all verified.
    """
    param = gast.Name('a', gast.Param(), None)
    signature = gast.arguments(
        args=[param],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None)

    module, source = compiler.ast_to_object(node)

    expected_source = """
      def f(a):
        return a + 1
    """
    expected = textwrap.dedent(expected_source).strip()

    self.assertEqual(expected, source.strip())
    self.assertEqual(2, module.f(1))
    # The generated code is also expected in the file backing the module.
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
Beispiel #13
0
  def test_uncompiled_modules(self):
    """Checks that calls into listed modules still work after call_trees.

    The names of `math_ops` and `constant_op` are handed to
    call_trees.transform (presumably as the modules to leave uncompiled);
    the converted `test_fn` must still compute 1 * 2 + 1 == 3.
    """

    def test_fn(a):
      a = math_ops.multiply(a, constant_op.constant(2))
      a = math_ops.add(a, constant_op.constant(1))
      return a

    node = self._parse_and_analyze(test_fn, {
        'math_ops': math_ops,
        'constant_op': constant_op
    })
    node = call_trees.transform(node, TestNamer(), {},
                                set(((math_ops.__name__,),
                                     (constant_op.__name__,))), ())
    result = compiler.ast_to_object(node)
    # Make the modules referenced by `test_fn` visible to the compiled code.
    result.math_ops = math_ops
    result.constant_op = constant_op

    with self.test_session() as sess:
      # Not renamed, because the converter doesn't rename the definition itself.
      # (the caller is responsible for that).
      result_tensor = result.test_fn(constant_op.constant(1))
      result_val = sess.run(result_tensor)

    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(3, result_val)
  def test_ast_to_object(self):
    """Checks that a hand-built `def f(a): return a + 1` AST compiles.

    The loaded module must compute f(1) == 2 and the file backing it must
    contain the expected source.
    """
    param = gast.Name('a', gast.Param(), None)
    signature = gast.arguments(
        args=[param],
        vararg=None,
        kwonlyargs=[],
        kwarg=None,
        defaults=[],
        kw_defaults=[])
    a_plus_one = gast.BinOp(
        op=gast.Add(),
        left=gast.Name('a', gast.Load(), None),
        right=gast.Num(1))
    node = gast.FunctionDef(
        name='f',
        args=signature,
        body=[gast.Return(a_plus_one)],
        decorator_list=[],
        returns=None)

    loaded = compiler.ast_to_object(node)

    self.assertEqual(2, loaded.f(1))
    expected = textwrap.dedent("""
              def f(a):
                return a + 1
          """).strip()
    # The generated code is also expected in the file backing the module.
    with open(loaded.__file__, 'r') as temp_output:
      self.assertEqual(expected, temp_output.read().strip())
Beispiel #15
0
def to_graph(f, arg_value_hints=None):
  """Convert a Python function into a function that builds a TF graph.

  The function is converted through `conversion.object_to_graph`; every AST in
  the resulting dependency cache is assembled, after the configured import
  statements, into one module that is then compiled and loaded.

  Args:
    f: A Python function with arbitrary arguments and return values.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    A function with a signature identical to `f`, but which when executed
    creates a TF graph that has the same functionality as the original
    function.
  """
  conv_map = conversion.ConversionMap()
  _, compiled_name = conversion.object_to_graph(f, conv_map, arg_value_hints)

  module = gast.Module([])
  module.body.extend(
      parser.parse_str(stmt) for stmt in config.COMPILED_IMPORT_STATEMENTS)
  module.body.extend(conv_map.dependency_cache.values())
  generated = compiler.ast_to_object(module)

  # Give the generated module access to every global the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  generated.__dict__.update(six.get_function_globals(f))

  return getattr(generated, compiled_name)
 def _remover_wrapper(self, f, remove_decorators):
   """Returns `f` recompiled after decorators.transform has processed it.

   Args:
     f: The function to parse, transform and compile.
     remove_decorators: Forwarded unchanged to decorators.transform.

   Returns:
     The attribute of the compiled module that carries `f`'s name.
   """
   known_decorators = dict(
       self_removing_decorator=self_removing_decorator,
       simple_decorator=simple_decorator)
   analyzed = self.parse_and_analyze(f, known_decorators)
   stripped, _ = decorators.transform(
       analyzed, remove_decorators=remove_decorators)
   module, _ = compiler.ast_to_object(stripped)
   return getattr(module, f.__name__)
    def test_replace_tuple(self):
        """Checks that one template symbol can be replaced by a tuple of names.

        `b` is replaced with ('a', 'c'), so `test_fn(2, 3)` returns (2, 3).
        """
        template = """
      def test_fn(a, c):
        return b,
    """

        node = templates.replace(template, b=('a', 'c'))[0]
        result = compiler.ast_to_object(node)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual((2, 3), result.test_fn(2, 3))
  def test_replace_tuple(self):
    """Checks that one template symbol can be replaced by a tuple of names.

    `b` is replaced with ('a', 'c'), so `test_fn(2, 3)` returns (2, 3).
    """
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual((2, 3), result.test_fn(2, 3))
Beispiel #19
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Convert a Python entity into its graph-building counterpart.

    Currently supported entities:
      * functions
      * classes (handled by converting all their methods into a new class)

    Args:
      e: A Python entity.
      recursive: Whether to recursively convert any functions that the
          decorator function may call.
      verbose: Whether to output the compiled code in the logs.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.
      partial_types: A set of types (e.g. classes) that will not be converted
          entirely. Calls to member functions for these types will be renamed
          independently.

    Returns:
      A function with a signature identical to `e`, but which when executed
      creates a TF graph that has the same functionality as the original
      entity.
    """
    conv_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, compiled_name = conversion.entity_to_graph(e, conv_map, arg_values,
                                                  arg_types)

    module = gast.Module([])
    module.body.extend(
        parser.parse_str(stmt) for stmt in config.COMPILED_IMPORT_STATEMENTS)
    module.body.extend(conv_map.dependency_cache.values())
    generated, generated_src = compiler.ast_to_object(module)

    # Give the generated module access to every global the entry function saw.
    # Only functions are handled here; other entities get no extra globals.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        generated.__dict__.update(six.get_function_globals(e))
    compiled_entity = getattr(generated, compiled_name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, generated_src)

    return compiled_entity
Beispiel #20
0
    def test_replace_name_with_dict(self):
        """Checks that a template name can be replaced by a parsed dict literal.

        `foo` is replaced with the expression `{'bar': 3}`, so the compiled
        `test_fn` returns 3.
        """
        template = """
      def test_fn():
        return foo['bar']
    """

        source = parser.parse_expression('{\'bar\': 3}')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(3, result.test_fn())
Beispiel #21
0
 def test_keywords_to_dict(self):
     """Checks that ast_util.keywords_to_dict yields a compilable dict node.

     The keywords of `f(a=b, c=1, d='e')` are converted, assigned to `d` in a
     module that also defines `b = 3`, and the compiled value is compared.
     """
     keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
     d = ast_util.keywords_to_dict(keywords)
     # Make sure we generate a usable dict node by attaching it to a variable and
     # compiling everything.
     output = parser.parse_str('b = 3')
     output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d), )
     result, _ = compiler.ast_to_object(output)
     # The leftover debugging `print(d)` was dropped; the assertion is the check.
     self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'})
 def compiled(self, node, *symbols):
   """Yields the module compiled from `node`, with test attributes attached.

   The compiled module gets a `tf` attribute built by
   `self.make_fake_tf(*symbols)` and a `py2tf_utils` attribute set to `utils`.

   NOTE(review): the body is a generator; presumably it is wrapped with
   contextlib.contextmanager on a line outside this view -- confirm.

   Args:
     node: The AST node to compile.
     *symbols: Forwarded to `self.make_fake_tf`.

   Yields:
     The compiled module object.

   Raises:
     Exception: Any exception is re-raised after the generated source (or a
       placeholder, if compilation did not finish) has been printed.
   """
   # Placeholder shown in the error report if compilation itself fails.
   source = '<compile failed>'
   try:
     result, source = compiler.ast_to_object(node)
     result.tf = self.make_fake_tf(*symbols)
     result.py2tf_utils = utils
     yield result
   except Exception:  # pylint:disable=broad-except
     print('Offending compiled code:\n%s' % source)
     raise
Beispiel #23
0
 def _remover_wrapper(self, f, remove_decorators):
     """Returns `f` recompiled after decorators.transform has processed it.

     Args:
       f: The function to parse, transform and compile.
       remove_decorators: Forwarded unchanged to decorators.transform.

     Returns:
       The attribute of the compiled module that carries `f`'s name.
     """
     known_decorators = dict(
         self_removing_decorator=self_removing_decorator,
         simple_decorator=simple_decorator)
     analyzed = self.parse_and_analyze(f, known_decorators)
     stripped, _ = decorators.transform(
         analyzed, remove_decorators=remove_decorators)
     module, _ = compiler.ast_to_object(stripped)
     return getattr(module, f.__name__)
Beispiel #24
0
    def test_noop(self):
        """Checks that an undecorated function passes through unchanged.

        decorators.transform must report no dependencies and the compiled
        function must still return its argument.
        """
        def test_fn(a):
            return a

        analyzed = self.parse_and_analyze(test_fn, {})
        transformed, deps = decorators.transform(analyzed,
                                                 remove_decorators=())
        module, _ = compiler.ast_to_object(transformed)

        self.assertFalse(deps)
        self.assertEqual(1, module.test_fn(1))
  def test_replace_name_with_dict(self):
    """Checks that a template name can be replaced by a parsed dict literal.

    `foo` is replaced with the expression `{'bar': 3}`, so the compiled
    `test_fn` returns 3.
    """
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(3, result.test_fn())
Beispiel #26
0
    def test_transform(self):
        """Checks that print_functions.transform leaves a callable print.

        The converted function must run, and the first statement of its body
        must be an expression whose value is a `gast.Call`.
        """
        def test_fn(a):
            print(a)

        node = self._parse_and_analyze(test_fn, {'print': print})
        node = print_functions.transform(node)
        result = compiler.ast_to_object(node)

        result.test_fn('a')
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(node.body[0].body[0].value, gast.Call)
Beispiel #27
0
 def test_keywords_to_dict(self):
   """Checks that ast_util.keywords_to_dict yields a compilable dict node.

   The keywords of `f(a=b, c=1, d='e')` are converted, assigned to `d` in a
   module that also defines `b = 3`, and the compiled value is compared.
   """
   keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
   d = ast_util.keywords_to_dict(keywords)
   # Make sure we generate a usable dict node by attaching it to a variable and
   # compiling everything.
   output = parser.parse_str('b = 3')
   output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),)
   result, _ = compiler.ast_to_object(output)
   # The leftover debugging `print(d)` was dropped; the assertion is the check.
   self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'})
Beispiel #28
0
 def compiled(self, node, *symbols):
     """Yields the module compiled from `node`, with test attributes attached.

     The compiled module gets a `tf` attribute built by
     `self.make_fake_tf(*symbols)` and a `py2tf_utils` attribute set to
     `utils`.

     NOTE(review): the body is a generator; presumably it is wrapped with
     contextlib.contextmanager on a line outside this view -- confirm.

     Args:
       node: The AST node to compile.
       *symbols: Forwarded to `self.make_fake_tf`.

     Yields:
       The compiled module object.

     Raises:
       Exception: Any exception is re-raised after the generated source (or a
         placeholder, if compilation did not finish) has been printed.
     """
     # Placeholder shown in the error report if compilation itself fails.
     source = '<compile failed>'
     try:
         result, source = compiler.ast_to_object(node)
         result.tf = self.make_fake_tf(*symbols)
         result.py2tf_utils = utils
         yield result
     except Exception:  # pylint:disable=broad-except
         print('Offending compiled code:\n%s' % source)
         raise
Beispiel #29
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Convert a Python entity into its graph-building counterpart.

  Currently supported entities:
    * functions
    * classes (handled by converting all their methods into a new class)

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which when executed
    creates a TF graph that has the same functionality as the original entity.
  """
  conv_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, compiled_name = conversion.entity_to_graph(e, conv_map, arg_values,
                                                arg_types)

  module = gast.Module([])
  # Each import line parses to a module; only its statements are spliced in.
  for stmt in config.COMPILED_IMPORT_STATEMENTS:
    module.body += parser.parse_str(stmt).body
  module.body.extend(conv_map.dependency_cache.values())
  generated, generated_src = compiler.ast_to_object(module)

  # Give the generated module access to the namespace of the entry function.
  # Only functions are handled here; other entities get no extra names.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    generated.__dict__.update(inspect_utils.getnamespace(e))
  compiled_entity = getattr(generated, compiled_name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, generated_src)

  return compiled_entity
    def test_bool_ops(self):
        """Checks `and`/`or` chains after logical_expressions.transform.

        `math_ops` is installed as `tf` on the compiled module, and
        `(a or b) and (a or b or c)` must evaluate truthy for
        (True, False, True).
        """
        def test_fn(a, b, c):
            return (a or b) and (a or b or c)

        analyzed = self.parse_and_analyze(test_fn, {})
        converted = logical_expressions.transform(analyzed)
        module = compiler.ast_to_object(converted)
        module.tf = math_ops

        with self.test_session() as sess:
            outcome = sess.run(module.test_fn(True, False, True))
            self.assertTrue(outcome)
Beispiel #31
0
  def test_replace_variable(self):
    """Checks that replacing `a` with the name `b` renames every occurrence.

    After the replacement `test_fn(2)` must return 2 * (2 + 1) + 1 == 7,
    which only holds if the parameter was renamed as well.
    """
    def template(a):  # pylint:disable=unused-argument
      def test_fn(a):  # pylint:disable=unused-variable
        a += 1
        a = 2 * a + 1
        return b  # pylint:disable=undefined-variable

    node = templates.replace(
        template, a=gast.Name('b', gast.Load(), None))[0]
    result = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(7, result.test_fn(2))
  def test_noop(self):
    """Checks that an undecorated function passes through unchanged.

    decorators.transform must report no dependencies and the compiled function
    must still return its argument.
    """

    def test_fn(a):
      return a

    analyzed = self.parse_and_analyze(test_fn, {})
    transformed, deps = decorators.transform(analyzed, remove_decorators=())
    module, _ = compiler.ast_to_object(transformed)

    self.assertFalse(deps)
    self.assertEqual(1, module.test_fn(1))
Beispiel #33
0
    def test_replace_variable(self):
        """Checks that replacing `a` with the name `b` renames every occurrence.

        After the replacement `test_fn(2)` must return 2 * (2 + 1) + 1 == 7,
        which only holds if the parameter was renamed as well.
        """
        def template(a):  # pylint:disable=unused-argument
            def test_fn(a):  # pylint:disable=unused-variable
                a += 1
                a = 2 * a + 1
                return b  # pylint:disable=undefined-variable

        node = templates.replace(template, a=gast.Name('b', gast.Load(),
                                                       None))[0]
        result = compiler.ast_to_object(node)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(7, result.test_fn(2))
Beispiel #34
0
  def test_replace_function_name(self):
    """Checks that the name of a templated function can be replaced.

    `fname` is replaced with the name `test_fn`, so the compiled module must
    expose `test_fn`, and `test_fn(2)` must return 2 * (2 + 1) + 1 == 7.
    """
    def template(fname):  # pylint:disable=unused-argument
      def fname(a):  # pylint:disable=function-redefined
        a += 1
        a = 2 * a + 1
        return a

    node = templates.replace(
        template, fname=gast.Name('test_fn', gast.Load(), None))[0]
    result = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(7, result.test_fn(2))
  def test_transform(self):
    """Checks that print_functions.transform leaves a callable print.

    The converted function must run, and the first statement of its body must
    be an expression whose value is a `gast.Call`.
    """

    def test_fn(a):
      print(a)

    node = self.parse_and_analyze(test_fn, {'print': print})
    node = print_functions.transform(node)
    result = compiler.ast_to_object(node)

    result.test_fn('a')
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(node.body[0].body[0].value, gast.Call)
Beispiel #36
0
    def test_replace_variable(self):
        """Checks that replacing `a` with 'b' renames every occurrence.

        After the replacement `test_fn(2)` must return 2 * (2 + 1) + 1 == 7,
        which only holds if the parameter was renamed as well.
        """
        template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

        node = templates.replace(template, a='b')[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(7, result.test_fn(2))
  def test_replace_variable(self):
    """Checks that replacing `a` with 'b' renames every occurrence.

    After the replacement `test_fn(2)` must return 2 * (2 + 1) + 1 == 7,
    which only holds if the parameter was renamed as well.
    """
    template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

    node = templates.replace(template, a='b')[0]
    result = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(7, result.test_fn(2))
Beispiel #38
0
    def test_replace_function_name(self):
        """Checks that the name of a templated function can be replaced.

        `fname` is replaced with 'test_fn', so the compiled module must expose
        `test_fn`, and `test_fn(2)` must return 2 * (2 + 1) + 1 == 7.
        """
        template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

        node = templates.replace(template, fname='test_fn')[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(7, result.test_fn(2))
  def test_replace_function_name(self):
    """Checks that the name of a templated function can be replaced.

    `fname` is replaced with 'test_fn', so the compiled module must expose
    `test_fn`, and `test_fn(2)` must return 2 * (2 + 1) + 1 == 7.
    """
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result = compiler.ast_to_object(node)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(7, result.test_fn(2))
  def test_bool_ops(self):
    """Checks `and`/`or` chains after logical_expressions.transform.

    `math_ops` is installed as `tf` on the compiled module, and
    `(a or b) and (a or b or c)` must evaluate truthy for (True, False, True).
    """

    def test_fn(a, b, c):
      return (a or b) and (a or b or c)

    analyzed = self.parse_and_analyze(test_fn, {})
    converted = logical_expressions.transform(analyzed)
    module = compiler.ast_to_object(converted)
    module.tf = math_ops

    with self.test_session() as sess:
      outcome = sess.run(module.test_fn(True, False, True))
      self.assertTrue(outcome)
    def test_equals(self):
        """Checks `==` after logical_expressions.transform.

        `math_ops` is installed as `tf` on the compiled module; equal inputs
        must evaluate truthy and different inputs falsy.
        """
        def test_fn(a, b):
            return a == b

        analyzed = self.parse_and_analyze(test_fn, {})
        converted = logical_expressions.transform(analyzed)
        module = compiler.ast_to_object(converted)
        module.tf = math_ops

        with self.test_session() as sess:
            same = sess.run(module.test_fn(1, 1))
            self.assertTrue(same)
            different = sess.run(module.test_fn(1, 2))
            self.assertFalse(different)
Beispiel #42
0
    def test_len(self):
        """Checks `len` on a tensor after builtin_functions.transform.

        `array_ops` is installed as `tf` on the compiled module; a constant
        with three elements must yield 3.
        """
        def test_fn(a):
            return len(a)

        analyzed = self.parse_and_analyze(test_fn, {'len': len})
        converted = builtin_functions.transform(analyzed)
        module = compiler.ast_to_object(converted)
        module.tf = array_ops

        with self.test_session() as sess:
            length = sess.run(module.test_fn(constant_op.constant([0, 0, 0])))
            self.assertEqual(3, length)
Beispiel #43
0
    def test_replace_function_name(self):
        """Checks that the name of a templated function can be replaced.

        `fname` is replaced with the name `test_fn`, so the compiled module
        must expose `test_fn`, and `test_fn(2)` must return
        2 * (2 + 1) + 1 == 7.
        """
        def template(fname):  # pylint:disable=unused-argument
            def fname(a):  # pylint:disable=function-redefined
                a += 1
                a = 2 * a + 1
                return a

        node = templates.replace(template,
                                 fname=gast.Name('test_fn', gast.Load(),
                                                 None))[0]
        result = compiler.ast_to_object(node)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(7, result.test_fn(2))
  def test_equals(self):
    """Checks `==` after logical_expressions.transform.

    `math_ops` is installed as `tf` on the compiled module; equal inputs must
    evaluate truthy and different inputs falsy.
    """

    def test_fn(a, b):
      return a == b

    analyzed = self.parse_and_analyze(test_fn, {})
    converted = logical_expressions.transform(analyzed)
    module = compiler.ast_to_object(converted)
    module.tf = math_ops

    with self.test_session() as sess:
      same = sess.run(module.test_fn(1, 1))
      self.assertTrue(same)
      different = sess.run(module.test_fn(1, 2))
      self.assertFalse(different)
Beispiel #45
0
    def test_replace_attribute(self):
        """Checks that an attribute name in a template can be replaced.

        `foo` is replaced with 'b', so `test_fn` reads `a.b`; replacing it
        with the integer 1 must raise ValueError.
        """
        template = """
      def test_fn(a):
        return a.foo
    """

        node = templates.replace(template, foo='b')[0]
        result, _ = compiler.ast_to_object(node)
        # NOTE(review): `imp` is deprecated in the standard library; consider
        # types.ModuleType when the file's imports are next touched.
        mod = imp.new_module('test')
        mod.b = 3
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(3, result.test_fn(mod))

        with self.assertRaises(ValueError):
            templates.replace(template, foo=1)
  def test_while_single_var(self):
    """Checks a single-variable `while` loop after control_flow.transform.

    `control_flow_ops` is installed as `tf` on the compiled module; counting
    down from a constant 5 must end at 0.
    """

    def test_fn(n):
      while n > 0:
        n -= 1
      return n

    analyzed = self._parse_and_analyze(test_fn, {})
    converted = control_flow.transform(analyzed, TestNamer())
    module = compiler.ast_to_object(converted)
    module.tf = control_flow_ops

    with self.test_session() as sess:
      final_n = sess.run(module.test_fn(constant_op.constant(5)))
      self.assertEqual(0, final_n)
Beispiel #47
0
    def test_while_single_var(self):
        """Checks a single-variable `while` loop after control_flow.transform.

        `control_flow_ops` is installed as `tf` on the compiled module;
        counting down from a constant 5 must end at 0.
        """
        def test_fn(n):
            while n > 0:
                n -= 1
            return n

        analyzed = self._parse_and_analyze(test_fn, {})
        converted = control_flow.transform(analyzed, TestNamer())
        module = compiler.ast_to_object(converted)
        module.tf = control_flow_ops

        with self.test_session() as sess:
            final_n = sess.run(module.test_fn(constant_op.constant(5)))
            self.assertEqual(0, final_n)
  def test_len(self):
    """Checks `len` on a tensor after builtin_functions.transform.

    `array_ops` is installed as `tf` on the compiled module; a constant with
    three elements must yield 3.
    """

    def test_fn(a):
      return len(a)

    analyzed = self.parse_and_analyze(test_fn, {'len': len})
    converted = builtin_functions.transform(analyzed, self.ctx)
    module = compiler.ast_to_object(converted)
    module.tf = array_ops

    with self.test_session() as sess:
      length = sess.run(module.test_fn(constant_op.constant([0, 0, 0])))
      self.assertEqual(3, length)
  def test_replace_attribute(self):
    """Checks replacing an attribute placeholder in a template.

    The placeholder `foo` in `a.foo` is replaced with the name `b`, so the
    compiled `test_fn` is expected to read attribute `b` of its argument.
    Passing a non-string replacement (`1`) must raise `ValueError`.
    """
    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    # A throwaway module object serves as the attribute holder.
    mod = imp.new_module('test')
    mod.b = 3
    # `assertEqual` replaces the deprecated `assertEquals` alias.
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)
Beispiel #50
0
  def test_parser_compile_idempotent(self):
    """Checks that parsing and then compiling a function reproduces its source.

    The source of `test_fn` is compared textually against the source of the
    compiled function, so the body of `test_fn` must stay exactly as written.
    """

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    expected_source = textwrap.dedent(tf_inspect.getsource(test_fn))
    parsed = parser.parse_entity(test_fn)[0]
    compiled = compiler.ast_to_object(parsed.body[0])[0]
    actual_source = tf_inspect.getsource(compiled.test_fn)
    self.assertEqual(expected_source, actual_source)
  def test_if_single_var(self):
    """Checks that a transformed single-variable `if` negates a positive input."""

    def test_fn(n):
      if n > 0:
        n = -n
      return n

    tree = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
    tree = control_flow.transform(tree, self.ctx)
    compiled = compiler.ast_to_object(tree)
    # Expose `control_flow_ops` under the name `tf` on the compiled object.
    compiled.tf = control_flow_ops

    with self.test_session() as session:
      positive_one = constant_op.constant(1)
      negated = session.run(compiled.test_fn(positive_one))
      self.assertEqual(-1, negated)
  def test_replace_name_with_call(self):
    """Checks replacing a name placeholder with a parsed call expression.

    The placeholder `foo` is replaced with the expression `f()(b)`; with
    `b = 5` and `g` tripling its argument, the compiled `test_fn` is expected
    to return 15.
    """
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEqual` replaces the deprecated `assertEquals` alias.
    self.assertEqual(15, result.test_fn())
    def test_basic_for(self):
        """Compares a canonicalized `for` loop against the original function.

        Both a non-empty and an empty list are summed; the compiled function
        must agree with the plain Python `test_fn` on each.
        """
        def test_fn(l):
            s = 0
            for e in l:
                s += e
            return s

        tree = self._parse_and_analyze(test_fn, {})
        tree = for_canonicalization.transform(tree, TestNamer())
        compiled = compiler.ast_to_object(tree)

        for sample in ([1, 2, 3], []):
            self.assertEqual(test_fn(sample), compiled.test_fn(sample))
Beispiel #54
0
    def test_code_block(self):
        """Checks replacing a placeholder statement with a list of statements.

        The `block` placeholder is replaced by two `a = a + 1` assignments, so
        the compiled `test_fn(1)` is expected to return 3.
        """
        def template(block):  # pylint:disable=unused-argument
            def test_fn(a):  # pylint:disable=unused-variable
                block  # pylint:disable=pointless-statement
                return a

        node = templates.replace(
            template,
            block=[
                gast.Assign([gast.Name('a', gast.Store(), None)],
                            gast.BinOp(gast.Name('a', gast.Load(), None),
                                       gast.Add(), gast.Num(1))),
            ] * 2)[0]
        result = compiler.ast_to_object(node)
        # `assertEqual` replaces the deprecated `assertEquals` alias.
        self.assertEqual(3, result.test_fn(1))
Beispiel #55
0
    def test_replace_name_with_call(self):
        """Checks replacing a name placeholder with a parsed call expression.

        The placeholder `foo` is replaced with the expression `f()(b)`; with
        `b = 5` and `g` tripling its argument, the compiled `test_fn` is
        expected to return 15.
        """
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        source = parser.parse_expression('f()(b)')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        # `assertEqual` replaces the deprecated `assertEquals` alias.
        self.assertEqual(15, result.test_fn())
  def test_basic_for(self):
    """Compares a canonicalized `for` loop against the original function.

    Both a non-empty and an empty list are summed; the compiled function must
    agree with the plain Python `test_fn` on each.
    """

    def test_fn(l):
      s = 0
      for e in l:
        s += e
      return s

    tree = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
    tree = for_canonicalization.transform(tree, self.ctx)
    compiled = compiler.ast_to_object(tree)

    for sample in ([1, 2, 3], []):
      self.assertEqual(test_fn(sample), compiled.test_fn(sample))
  def test_code_block(self):
    """Checks replacing a placeholder statement with a list of statements.

    The `block` placeholder is replaced by two `a = a + 1` assignments, so the
    compiled `test_fn(1)` is expected to return 3.
    """
    template = """
      def test_fn(a):
        block
        return a
    """

    node = templates.replace(
        template,
        block=[
            gast.Assign([
                gast.Name('a', None, None)
            ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
        ] * 2)[0]
    result = compiler.ast_to_object(node)
    # `assertEqual` replaces the deprecated `assertEquals` alias.
    self.assertEqual(3, result.test_fn(1))
  def test_replace_call_keyword(self):
    """Checks replacing a keyword-argument placeholder in a call.

    The `kws` placeholder in `f(1, kws=None)` is replaced with the keywords of
    the parsed call `f(d=3, f=5)`, so the compiled `test_fn` is expected to
    return 1 + 3 + 5 = 9. Replacing it with an empty list or with a
    non-keyword value must raise `ValueError`.
    """
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEqual` replaces the deprecated `assertEquals` alias.
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement gets its own `assertRaises` block: inside a
    # shared block the second call would never run once the first one raised.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)