Beispiel #1
0
def replace(template, **replacements):
  """Replace placeholders in a Python template function.

  The template function is parsed, its arguments are taken as the set of
  placeholder names, and the function's body (with the enclosing `def`
  stripped) is returned after substitution.

  Args:
    template: A function to be used as a template. Every placeholder is
        expected to be one of its positional arguments or its `*args` name.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by.

  Returns:
    body: The list of AST nodes forming the body of the template function,
        with the replacements made.

  Raises:
    ValueError: If the names in `replacements` do not exactly match the
        template's placeholders.
  """
  tree = parser.parse_object(template).body[0]
  placeholders = set(arg.id for arg in tree.args.args)
  tree.args.args = []
  vararg = tree.args.vararg
  if vararg:
    # Depending on the AST flavor, `vararg` is a plain string (Python 2 `ast`),
    # a `Name` node (gast) or an `arg` node (Python 3 `ast`). Register the
    # name itself: adding the node would never match a key of `replacements`.
    vararg_name = getattr(vararg, 'id', None)
    if vararg_name is None:
      vararg_name = getattr(vararg, 'arg', vararg)
    placeholders.add(vararg_name)
    tree.args.vararg = None
  if set(replacements.keys()) != placeholders:
    raise ValueError(
        'too many or few replacements. replacements: %s; placeholders: %s' %
        (replacements.keys(), placeholders))

  # Perform the replacement, stripping the function into which the template was
  # wrapped.
  return ReplaceTransformer(replacements).visit(tree).body
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: The Python function to convert.
    conversion_map: Conversion state. Supplies the namer, the
        `nocompile_decorators` and the name map updated here.
    arg_values: Stored on the `EntityContext` as `arg_values`.
    arg_types: Stored on the `EntityContext` as `arg_types`.
    owner_type: If not None, forwarded to `namer.compiled_function_name`
        (presumably the class owning `f` when it is a method - confirm).

  Returns:
    A tuple `(node, new_name)`: the converted and renamed function AST node,
    and the entry `conversion_map.name_map[f]`.
  """
  node = parser.parse_object(f).body[0]

  # Work on a copy: the globals dict is the defining module's live namespace,
  # and the closure entries added below must not leak into it.
  namespace = dict(six.get_function_globals(f))

  # This is needed for non-global functions: callables captured from enclosing
  # scopes are not globals. Expose each one under the free-variable name the
  # body refers to it by (`fn.__name__` may differ from that name, and not
  # every callable has a `__name__`).
  closure = six.get_function_closure(f)
  if closure:
    freevars = six.get_function_code(f).co_freevars
    for name, cell in zip(freevars, closure):
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the enclosing variable is not bound yet.
        continue
      if callable(value):
        namespace[name] = value

  namer = conversion_map.new_namer(namespace)
  ctx = context.EntityContext(
      namer=namer,
      source_code=tf_inspect.getsource(f),
      source_file=tf_inspect.getfile(f),
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types)
  node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  if owner_type is not None:
    new_name = namer.compiled_function_name(f.__name__, f, owner_type)
  else:
    new_name = namer.compiled_function_name(f.__name__, f)
  node.name = new_name
  conversion_map.update_name_map(namer)
  return node, conversion_map.name_map[f]
Beispiel #3
0
  def test_if(self):
    """Checks the scope annotations that access.resolve places on an If."""

    def test_fn(x):
      if x > 0:
        x = -x
        y = 2 * x
        z = -y
      else:
        x = 2 * x
        y = -x
        u = -y
      return z, u

    node = parser.parse_object(test_fn)
    node = access.resolve(node)

    if_node = node.body[0].body[0]
    self.assertScopeIs(
        anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
        ('y', 'z'))
    # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
    self.assertScopeIs(
        anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
        ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
    self.assertScopeIs(
        anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
        ('y', 'u'))
    # NOTE(review): a verbatim repeat of the 'body_parent_scope' assertion was
    # removed here. The else branch's parent scope is not checked; if the
    # analysis records an 'orelse_parent_scope' annotation, assert it too.
Beispiel #4
0
def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
  """Specialization of `object_to_graph` for callable functions.

  Args:
    f: The Python function to convert.
    conversion_map: Conversion state; supplies the namer and the name map
        updated here.
    param_value_hints: Passed through to `node_to_graph`.
    owner_type: If not None, forwarded to `namer.compiled_function_name`
        (presumably the class owning `f` when it is a method - confirm).

  Returns:
    A tuple `(node, new_name)`: the converted and renamed function AST node,
    and the entry `conversion_map.name_map[f]`.
  """
  node = parser.parse_object(f).body[0]

  # Work on a copy: the globals dict is the defining module's live namespace,
  # and the closure entries added below must not leak into it.
  node_globals = dict(six.get_function_globals(f))

  # This is needed for non-global functions: callables captured from enclosing
  # scopes are not globals. Expose each one under the free-variable name the
  # body refers to it by (`fn.__name__` may differ from that name, and not
  # every callable has a `__name__`).
  closure = six.get_function_closure(f)
  if closure:
    freevars = six.get_function_code(f).co_freevars
    for name, cell in zip(freevars, closure):
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the enclosing variable is not bound yet.
        continue
      if callable(value):
        node_globals[name] = value

  namer = conversion_map.new_namer(node_globals)
  node = node_to_graph(node, namer, node_globals, param_value_hints)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  if owner_type is not None:
    new_name = namer.compiled_function_name(f.__name__, f, owner_type)
  else:
    new_name = namer.compiled_function_name(f.__name__, f)
  node.name = new_name
  conversion_map.update_name_map(namer)
  return node, conversion_map.name_map[f]
Beispiel #5
0
  def test_print_statement(self):
    """Checks that a print's args_scope records only the names it reads."""

    def test_fn(a):
      b = 0
      c = 1
      print(a, b)
      return c

    tree = access.resolve(parser.parse_object(test_fn))

    annotated = tree.body[0].body[2]
    if not isinstance(annotated, gast.Print):
      # Python 3: print is a call wrapped in an expression statement, and the
      # annotation lives on the call node itself.
      assert isinstance(annotated, gast.Expr)
      annotated = annotated.value
    # On Python 2 the Print statement itself carries the annotation.
    scope = anno.getanno(annotated, 'args_scope')

    # Only the variables captured by the print arguments should show up.
    self.assertItemsEqual(['a', 'b'], scope.used)
    self.assertItemsEqual([], scope.modified)
    self.assertItemsEqual([], scope.created)
Beispiel #6
0
def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
    """Specialization of `object_to_graph` for callable functions.

    Args:
        f: The Python function to convert.
        conversion_map: Conversion state; supplies the namer and the name map
            updated here.
        param_value_hints: Passed through to `node_to_graph`.
        owner_type: If not None, forwarded to `namer.compiled_function_name`
            (presumably the class owning `f` when it is a method - confirm).

    Returns:
        A tuple `(node, new_name)`: the converted and renamed function AST
        node, and the entry `conversion_map.name_map[f]`.
    """
    node = parser.parse_object(f).body[0]

    # Work on a copy: the globals dict is the defining module's live
    # namespace, and the closure entries added below must not leak into it.
    node_globals = dict(six.get_function_globals(f))

    # This is needed for non-global functions: callables captured from
    # enclosing scopes are not globals. Expose each one under the
    # free-variable name the body refers to it by (`fn.__name__` may differ
    # from that name, and not every callable has a `__name__`).
    closure = six.get_function_closure(f)
    if closure:
        freevars = six.get_function_code(f).co_freevars
        for name, cell in zip(freevars, closure):
            try:
                value = cell.cell_contents
            except ValueError:
                # Empty cell: the enclosing variable is not bound yet.
                continue
            if callable(value):
                node_globals[name] = value

    namer = conversion_map.new_namer(node_globals)
    node = node_to_graph(node, namer, node_globals, param_value_hints)

    # Simulate a rename to ensure the top level is in the name map. This is needed
    # for top level functions, and it also helps the consistency verification made
    # by update_name_map.
    if owner_type is not None:
        new_name = namer.compiled_function_name(f.__name__, f, owner_type)
    else:
        new_name = namer.compiled_function_name(f.__name__, f)
    node.name = new_name
    conversion_map.update_name_map(namer)
    return node, conversion_map.name_map[f]
Beispiel #7
0
def function_to_graph(f, conversion_map, param_value_hints):
  """Specialization of `object_to_graph` for callable functions.

  Converts `f` and caches the result in `conversion_map`. It then converts
  every object in `conversion_map.name_map` that is not yet in
  `conversion_map.dependency_cache`.

  Args:
    f: The Python function to convert.
    conversion_map: Shared conversion state (namer factory, cache, name map).
    param_value_hints: Passed through to `node_to_graph`.

  Returns:
    A tuple `(node, new_name)`: the converted function AST node and the entry
    `conversion_map.name_map[f]`.
  """
  node = parser.parse_object(f).body[0]

  # Work on a copy: the globals dict is the defining module's live namespace,
  # and the closure entries added below must not leak into it.
  node_globals = dict(six.get_function_globals(f))

  # This is needed for non-global functions: callables captured from enclosing
  # scopes are not globals. Expose each one under the free-variable name the
  # body refers to it by (`fn.__name__` may differ from that name, and not
  # every callable has a `__name__`).
  closure = six.get_function_closure(f)
  if closure:
    freevars = six.get_function_code(f).co_freevars
    for name, cell in zip(freevars, closure):
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the enclosing variable is not bound yet.
        continue
      if callable(value):
        node_globals[name] = value

  namer = conversion_map.new_namer(node_globals)
  node = node_to_graph(node, namer, node_globals, param_value_hints)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  namer.compiled_function_name(f.__name__, f)

  conversion_map.add_to_cache(f, node)
  conversion_map.update_name_map(namer)

  # Recursively convert any remaining dependencies. Iterate over a snapshot:
  # the recursive conversions may add entries to `name_map`, and resizing a
  # dict while iterating over it raises RuntimeError on Python 3.
  for obj in list(conversion_map.name_map.keys()):
    if obj not in conversion_map.dependency_cache:
      object_to_graph(obj, conversion_map, None)
  return node, conversion_map.name_map[f]
Beispiel #8
0
def function_to_graph(f, conversion_map, param_value_hints):
    """Specialization of `object_to_graph` for callable functions.

    Converts `f` and caches the result in `conversion_map`. It then converts
    every object in `conversion_map.name_map` that is not yet in
    `conversion_map.dependency_cache`.

    Args:
        f: The Python function to convert.
        conversion_map: Shared conversion state (namer factory, cache, name
            map).
        param_value_hints: Passed through to `node_to_graph`.

    Returns:
        A tuple `(node, new_name)`: the converted function AST node and the
        entry `conversion_map.name_map[f]`.
    """
    node = parser.parse_object(f).body[0]

    # Work on a copy: the globals dict is the defining module's live
    # namespace, and the closure entries added below must not leak into it.
    node_globals = dict(six.get_function_globals(f))

    # This is needed for non-global functions: callables captured from
    # enclosing scopes are not globals. Expose each one under the
    # free-variable name the body refers to it by (`fn.__name__` may differ
    # from that name, and not every callable has a `__name__`).
    closure = six.get_function_closure(f)
    if closure:
        freevars = six.get_function_code(f).co_freevars
        for name, cell in zip(freevars, closure):
            try:
                value = cell.cell_contents
            except ValueError:
                # Empty cell: the enclosing variable is not bound yet.
                continue
            if callable(value):
                node_globals[name] = value

    namer = conversion_map.new_namer(node_globals)
    node = node_to_graph(node, namer, node_globals, param_value_hints)

    # Simulate a rename to ensure the top level is in the name map. This is needed
    # for top level functions, and it also helps the consistency verification made
    # by update_name_map.
    namer.compiled_function_name(f.__name__, f)

    conversion_map.add_to_cache(f, node)
    conversion_map.update_name_map(namer)

    # Recursively convert any remaining dependencies. Iterate over a snapshot:
    # the recursive conversions may add entries to `name_map`, and resizing a
    # dict while iterating over it raises RuntimeError on Python 3.
    for obj in list(conversion_map.name_map.keys()):
        if obj not in conversion_map.dependency_cache:
            object_to_graph(obj, conversion_map, None)
    return node, conversion_map.name_map[f]
Beispiel #9
0
def replace(template, **replacements):
  """Replace placeholders in a Python template function.

  The template function is parsed, its arguments are taken as the set of
  placeholder names, and the function's body (with the enclosing `def`
  stripped) is returned after substitution.

  Args:
    template: A function to be used as a template. Every placeholder is
        expected to be one of its positional arguments or its `*args` name.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by.

  Returns:
    body: The list of AST nodes forming the body of the template function,
        with the replacements made.

  Raises:
    ValueError: If the names in `replacements` do not exactly match the
        template's placeholders.
  """
  tree = parser.parse_object(template).body[0]
  placeholders = set(arg.id for arg in tree.args.args)
  tree.args.args = []
  vararg = tree.args.vararg
  if vararg:
    # Depending on the AST flavor, `vararg` is a plain string (Python 2 `ast`),
    # a `Name` node (gast) or an `arg` node (Python 3 `ast`). Register the
    # name itself: adding the node would never match a key of `replacements`.
    vararg_name = getattr(vararg, 'id', None)
    if vararg_name is None:
      vararg_name = getattr(vararg, 'arg', vararg)
    placeholders.add(vararg_name)
    tree.args.vararg = None
  if set(replacements.keys()) != placeholders:
    raise ValueError(
        'too many or few replacements. replacements: %s; placeholders: %s' %
        (replacements.keys(), placeholders))

  # Perform the replacement, stripping the function into which the template was
  # wrapped.
  return ReplaceTransformer(replacements).visit(tree).body
Beispiel #10
0
    def test_print_statement(self):
        """Checks that a print's args_scope records only the names it reads."""
        def test_fn(a):
            b = 0
            c = 1
            print(a, b)
            return c

        tree = access.resolve(parser.parse_object(test_fn))

        annotated = tree.body[0].body[2]
        if not isinstance(annotated, gast.Print):
            # Python 3: print is a call wrapped in an expression statement,
            # and the annotation lives on the call node itself.
            assert isinstance(annotated, gast.Expr)
            annotated = annotated.value
        # On Python 2 the Print statement itself carries the annotation.
        scope = anno.getanno(annotated, 'args_scope')

        # Only the variables captured by the print arguments should show up.
        self.assertItemsEqual(['a', 'b'], scope.used)
        self.assertItemsEqual([], scope.modified)
        self.assertItemsEqual([], scope.created)
Beispiel #11
0
    def test_if(self):
        """Checks the scope annotations that access.resolve places on an If."""
        def test_fn(x):
            if x > 0:
                x = -x
                y = 2 * x
                z = -y
            else:
                x = 2 * x
                y = -x
                u = -y
            return z, u

        node = parser.parse_object(test_fn)
        node = access.resolve(node)

        if_node = node.body[0].body[0]
        self.assertScopeIs(anno.getanno(if_node, 'body_scope'), ('x', 'y'),
                           ('x', 'y', 'z'), ('y', 'z'))
        # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
        self.assertScopeIs(anno.getanno(if_node,
                                        'body_parent_scope'), ('x', 'z', 'u'),
                           ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
        self.assertScopeIs(anno.getanno(if_node, 'orelse_scope'), ('x', 'y'),
                           ('x', 'y', 'u'), ('y', 'u'))
        # NOTE(review): a verbatim repeat of the 'body_parent_scope' assertion
        # was removed here. The else branch's parent scope is not checked; if
        # the analysis records an 'orelse_parent_scope' annotation, assert it.
Beispiel #12
0
    def test_parameter_class_members(self):
        """Expects ValueError for a member call on an unhinted parameter."""
        def test_fn(opt):
            opt.minimize(0)

        tree = live_values.resolve(
            access.resolve(parser.parse_object(test_fn)),
            {'training': training}, {})
        with self.assertRaises(ValueError):
            type_info.resolve(tree, None)
Beispiel #13
0
    def test_nested_members(self):
        """Expects ValueError when resolving a chained member call foo.bar.baz()."""
        def test_fn():
            foo = training.GradientDescentOptimizer(0.1)
            foo.bar.baz()

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'training': training}, {})
        with self.assertRaises(ValueError):
            type_info.resolve(tree, None)
Beispiel #14
0
  def test_parameter_class_members(self):
    """Expects ValueError for a member call on an unhinted parameter."""

    def test_fn(opt):
      opt.minimize(0)

    tree = live_values.resolve(
        access.resolve(parser.parse_object(test_fn)),
        {'training': training}, {})
    with self.assertRaises(ValueError):
      type_info.resolve(tree, None)
Beispiel #15
0
    def test_literals(self):
        """`Foo`, given in the third resolve argument, gets live value 'bar'."""
        def test_fn():
            return Foo  # pylint: disable=undefined-variable

        root = access.resolve(parser.parse_object(test_fn))
        root = live_values.resolve(root, {}, {'Foo': 'bar'})

        returned = root.body[0].body[0].value
        self.assertEquals('bar', anno.getanno(returned, 'live_val'))
  def test_literals(self):
    """`Foo`, given in the third resolve argument, gets live value 'bar'."""

    def test_fn():
      return Foo  # pylint: disable=undefined-variable

    root = access.resolve(parser.parse_object(test_fn))
    root = live_values.resolve(root, {}, {'Foo': 'bar'})

    returned = root.body[0].body[0].value
    self.assertEquals('bar', anno.getanno(returned, 'live_val'))
    def test_transform(self):
        """The converted and/or expression evaluates to True in a session."""
        def test_fn(a, b, c):
            return (a or b) and (a or b or c)

        converted = logical_expressions.transform(parser.parse_object(test_fn))
        module = compiler.ast_to_object(converted)
        module.tf = math_ops

        with self.test_session() as sess:
            outcome = module.test_fn(True, False, True)
            self.assertTrue(sess.run(outcome))
Beispiel #18
0
  def test_nested_members(self):
    """Expects ValueError when resolving a chained member call foo.bar.baz()."""

    def test_fn():
      foo = training.GradientDescentOptimizer(0.1)
      foo.bar.baz()

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'training': training}, {})
    with self.assertRaises(ValueError):
      type_info.resolve(tree, None)
  def test_transform(self):
    """The converted and/or expression evaluates to True in a session."""

    def test_fn(a, b, c):
      return (a or b) and (a or b or c)

    converted = logical_expressions.transform(parser.parse_object(test_fn))
    module = compiler.ast_to_object(converted)
    module.tf = math_ops

    with self.test_session() as sess:
      outcome = module.test_fn(True, False, True)
      self.assertTrue(sess.run(outcome))
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
   """Parses test_fn and runs access, live-value and type analyses on it."""
   ctx = context.EntityContext(
       namer=None,
       source_code=None,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types)
   tree = access.resolve(parser.parse_object(test_fn))
   tree = live_values.resolve(tree, namespace, {})
   return type_info.resolve(tree, ctx)
Beispiel #21
0
    def test_attribute_names(self):
        """An attribute call target gets its live value and qualified name."""
        def test_fn():
            return constant_op.constant(0)

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'constant_op': constant_op}, {})

        target = tree.body[0].body[0].value.func
        self.assertEquals(constant_op.constant,
                          anno.getanno(target, 'live_val'))
        expected_fqn = (constant_op.__name__, 'constant')
        self.assertEquals(expected_fqn, anno.getanno(target, 'fqn'))
Beispiel #22
0
    def test_class_members(self):
        """A method call on a constructed object gets the class's type_fqn."""
        def test_fn():
            opt = training.GradientDescentOptimizer(0.1)
            opt.minimize(0)

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'training': training}, {})
        tree = type_info.resolve(tree, None)

        minimize_attr = tree.body[0].body[1].value.func
        expected_fqn = (training.__name__, 'GradientDescentOptimizer')
        self.assertEquals(expected_fqn,
                          anno.getanno(minimize_attr, 'type_fqn'))
Beispiel #23
0
  def test_equals(self):
    """The converted `==` is True for equal and False for unequal inputs."""

    def test_fn(a, b):
      return a == b

    converted = logical_expressions.transform(parser.parse_object(test_fn))
    module = compiler.ast_to_object(converted)
    module.tf = math_ops

    with self.test_session() as sess:
      self.assertTrue(sess.run(module.test_fn(1, 1)))
      self.assertFalse(sess.run(module.test_fn(1, 2)))
  def test_attribute_names(self):
    """An attribute call target gets its live value and qualified name."""

    def test_fn():
      return constant_op.constant(0)

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'constant_op': constant_op}, {})

    target = tree.body[0].body[0].value.func
    self.assertEquals(constant_op.constant, anno.getanno(target, 'live_val'))
    expected_fqn = (constant_op.__name__, 'constant')
    self.assertEquals(expected_fqn, anno.getanno(target, 'fqn'))
Beispiel #25
0
    def test_function_variables(self):
        """Expects ValueError when calling a function through a local alias."""
        def bar():
            pass

        def test_fn():
            foo = bar
            foo()

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'bar': bar}, {})
        with self.assertRaises(ValueError):
            type_info.resolve(tree, None)
Beispiel #26
0
    def test_constructor_deta_dependent(self):
        """Expects ValueError when `opt` is constructed in both if-branches."""
        def test_fn(x):
            if x > 0:
                opt = training.GradientDescentOptimizer(0.1)
            else:
                opt = training.GradientDescentOptimizer(0.01)
            opt.minimize(0)

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'training': training}, {})
        with self.assertRaises(ValueError):
            type_info.resolve(tree, None)
Beispiel #27
0
  def test_constructor_deta_dependent(self):
    """Expects ValueError when `opt` is constructed in both if-branches."""

    def test_fn(x):
      if x > 0:
        opt = training.GradientDescentOptimizer(0.1)
      else:
        opt = training.GradientDescentOptimizer(0.01)
      opt.minimize(0)

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'training': training}, {})
    with self.assertRaises(ValueError):
      type_info.resolve(tree, {})
Beispiel #28
0
  def test_class_members(self):
    """A method call on a constructed object gets the class's type_fqn."""

    def test_fn():
      opt = training.GradientDescentOptimizer(0.1)
      opt.minimize(0)

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'training': training}, {})
    tree = type_info.resolve(tree, None)

    minimize_attr = tree.body[0].body[1].value.func
    expected_fqn = (training.__name__, 'GradientDescentOptimizer')
    self.assertEquals(expected_fqn, anno.getanno(minimize_attr, 'type_fqn'))
Beispiel #29
0
    def test_namespace(self):
        """A name taken from the namespace gets its live value and fqn."""
        def foo():
            return 'bar'

        def test_fn():
            return foo()

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'foo': foo}, {})

        callee = tree.body[0].body[0].value.func
        self.assertEquals(foo, anno.getanno(callee, 'live_val'))
        self.assertEquals(('foo', ), anno.getanno(callee, 'fqn'))
Beispiel #30
0
  def test_function_variables(self):
    """Expects ValueError when calling a function through a local alias."""

    def bar():
      pass

    def test_fn():
      foo = bar
      foo()

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'bar': bar}, {})
    with self.assertRaises(ValueError):
      type_info.resolve(tree, None)
Beispiel #31
0
    def test_constructor_detection(self):
        """A constructor call is annotated with its class and class fqn."""
        def test_fn():
            opt = training.GradientDescentOptimizer(0.1)
            return opt

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'training': training}, {})
        tree = type_info.resolve(tree, None)

        ctor_call = tree.body[0].body[0].value
        self.assertEquals(training.GradientDescentOptimizer,
                          anno.getanno(ctor_call, 'type'))
        expected_fqn = (training.__name__, 'GradientDescentOptimizer')
        self.assertEquals(expected_fqn, anno.getanno(ctor_call, 'type_fqn'))
Beispiel #32
0
    def test_call(self):
        """A call's args_scope uses ('a', 'b') and modifies/creates nothing."""
        def test_fn(a):
            b = 0
            c = 1
            foo(a, b)  # pylint:disable=undefined-variable
            return c

        analyzed = access.resolve(parser.parse_object(test_fn))

        call = analyzed.body[0].body[2].value
        # Only the variables captured by the call arguments should appear.
        self.assertScopeIs(anno.getanno(call, 'args_scope'), ('a', 'b'), (),
                           ())
  def test_namespace(self):
    """A name taken from the namespace gets its live value and fqn."""

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'foo': foo}, {})

    callee = tree.body[0].body[0].value.func
    self.assertEquals(foo, anno.getanno(callee, 'live_val'))
    self.assertEquals(('foo',), anno.getanno(callee, 'fqn'))
Beispiel #34
0
  def test_call(self):
    """A call's args_scope uses ('a', 'b') and modifies/creates nothing."""

    def test_fn(a):
      b = 0
      c = 1
      foo(a, b)  # pylint:disable=undefined-variable
      return c

    analyzed = access.resolve(parser.parse_object(test_fn))

    call = analyzed.body[0].body[2].value
    # Only the variables captured by the call arguments should appear.
    self.assertScopeIs(anno.getanno(call, 'args_scope'), ('a', 'b'), (), ())
Beispiel #35
0
  def test_constructor_detection(self):
    """A constructor call is annotated with its class and class fqn."""

    def test_fn():
      opt = training.GradientDescentOptimizer(0.1)
      return opt

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'training': training}, {})
    tree = type_info.resolve(tree, None)

    ctor_call = tree.body[0].body[0].value
    self.assertEquals(training.GradientDescentOptimizer,
                      anno.getanno(ctor_call, 'type'))
    expected_fqn = (training.__name__, 'GradientDescentOptimizer')
    self.assertEquals(expected_fqn, anno.getanno(ctor_call, 'type_fqn'))
Beispiel #36
0
    def test_while(self):
        """Checks the body and parent scope annotations of a While loop."""
        def test_fn(a):
            b = a
            while b > 0:
                c = b
                b -= 1
            return b, c

        loop = access.resolve(parser.parse_object(test_fn)).body[0].body[1]
        # Scope groups are (used, modified, created).
        self.assertScopeIs(anno.getanno(loop, 'body_scope'), ('b', ),
                           ('b', 'c'), ('c', ))
        all_names = ('a', 'b', 'c')
        self.assertScopeIs(anno.getanno(loop, 'body_parent_scope'),
                           all_names, all_names, all_names)
Beispiel #37
0
    def test_parameter_class_members_with_value_hints(self):
        """A value hint for `opt` lets type_info resolve its method's type."""
        def test_fn(opt):
            opt.minimize(0)

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'training': training}, {})
        hints = {
            'opt': ('%s.GradientDescentOptimizer' % training.__name__,
                    training.GradientDescentOptimizer(0.1)),
        }
        tree = type_info.resolve(tree, hints)

        minimize_attr = tree.body[0].body[0].value.func
        expected_fqn = (training.__name__.split('.') +
                        ['GradientDescentOptimizer'])
        self.assertEquals(expected_fqn,
                          anno.getanno(minimize_attr, 'type_fqn'))
Beispiel #38
0
    def test_local_markers(self):
        """Checks the is_local flag on reads of an outer name vs. a local."""
        def test_fn(a):  # pylint:disable=unused-argument
            b = c  # pylint:disable=undefined-variable
            while b > 0:
                b -= 1
            return b

        fn_body = access.resolve(parser.parse_object(test_fn)).body[0].body

        # c in `b = c` is not local.
        self.assertFalse(anno.getanno(fn_body[0].value, 'is_local'))
        # b in `b > 0` is local.
        self.assertTrue(anno.getanno(fn_body[1].test.left, 'is_local'))
        # b in `return b` is local.
        self.assertTrue(anno.getanno(fn_body[2].value, 'is_local'))
Beispiel #39
0
  def test_local_markers(self):
    """Checks the is_local flag on reads of an outer name vs. a local."""

    def test_fn(a):  # pylint:disable=unused-argument
      b = c  # pylint:disable=undefined-variable
      while b > 0:
        b -= 1
      return b

    fn_body = access.resolve(parser.parse_object(test_fn)).body[0].body

    # c in `b = c` is not local.
    self.assertFalse(anno.getanno(fn_body[0].value, 'is_local'))
    # b in `b > 0` is local.
    self.assertTrue(anno.getanno(fn_body[1].test.left, 'is_local'))
    # b in `return b` is local.
    self.assertTrue(anno.getanno(fn_body[2].value, 'is_local'))
Beispiel #40
0
  def test_for(self):
    """Checks the body and parent scope annotations of a For loop."""

    def test_fn(a):
      b = a
      for _ in a:
        c = b
        b -= 1
      return b, c

    loop = access.resolve(parser.parse_object(test_fn)).body[0].body[1]
    # Scope groups are (used, modified, created).
    self.assertScopeIs(
        anno.getanno(loop, 'body_scope'), ('b',), ('b', 'c'), ('c',))
    with_target = ('a', 'b', 'c', '_')
    self.assertScopeIs(
        anno.getanno(loop, 'body_parent_scope'), ('a', 'b', 'c'),
        with_target, with_target)
Beispiel #41
0
  def test_parameter_class_members_with_value_hints(self):
    """A value hint for `opt` lets type_info resolve its method's type."""

    def test_fn(opt):
      opt.minimize(0)

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'training': training}, {})
    hints = {
        'opt': ('%s.GradientDescentOptimizer' % training.__name__,
                training.GradientDescentOptimizer(0.1)),
    }
    tree = type_info.resolve(tree, hints)

    minimize_attr = tree.body[0].body[0].value.func
    expected_fqn = training.__name__.split('.') + ['GradientDescentOptimizer']
    self.assertEquals(expected_fqn, anno.getanno(minimize_attr, 'type_fqn'))
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       arg_types=None,
                       include_type_analysis=True):
   """Parses test_fn and runs access and live-value analyses on it.

   Type analysis is run last only when `include_type_analysis` is true.
   """
   ctx = context.EntityContext(namer=None, source_code=None, source_file=None,
                               namespace=namespace, arg_values=None,
                               arg_types=arg_types)
   tree = access.resolve(parser.parse_object(test_fn))
   tree = live_values.resolve(tree, namespace, {})
   if not include_type_analysis:
     return tree
   return type_info.resolve(tree, ctx)
Beispiel #43
0
    def test_call(self):
        """A call's args_scope records only the names its arguments read."""
        def test_fn(a):
            b = 0
            c = 1
            foo(a, b)  # pylint:disable=undefined-variable
            return c

        analyzed = access.resolve(parser.parse_object(test_fn))

        scope = anno.getanno(analyzed.body[0].body[2].value, 'args_scope')

        # Only the variables captured by the call arguments should appear.
        self.assertItemsEqual(['a', 'b'], scope.used)
        self.assertItemsEqual([], scope.modified)
        self.assertItemsEqual([], scope.created)
    def test_transform(self):
        """Converted value_and_gradients call yields (12, 3) for x=3, w=4."""
        def loss(x, w):
            return x * w

        def test_fn(x, w):
            l, (dw, ) = tfe.value_and_gradients_function(loss, [1])(x, w)  # pylint:disable=undefined-variable
            return l, dw

        converted = gradients_function.transform(parser.parse_object(test_fn))
        module = compiler.ast_to_object(converted)
        module.tf = gradients_impl
        module.loss = loss

        with self.test_session() as sess:
            outputs = module.test_fn(constant_op.constant(3),
                                     constant_op.constant(4))
            self.assertEqual((12, 3), sess.run(outputs))
Beispiel #45
0
  def test_call(self):
    """A call's args_scope records only the names its arguments read."""

    def test_fn(a):
      b = 0
      c = 1
      foo(a, b)  # pylint:disable=undefined-variable
      return c

    analyzed = access.resolve(parser.parse_object(test_fn))

    scope = anno.getanno(analyzed.body[0].body[2].value, 'args_scope')

    # Only the variables captured by the call arguments should appear.
    self.assertItemsEqual(['a', 'b'], scope.used)
    self.assertItemsEqual([], scope.modified)
    self.assertItemsEqual([], scope.created)
Beispiel #46
0
  def test_class_members_in_with_stmt(self):
    """Type info flows from a `with` constructor to calls on its target."""

    def test_fn(x):
      with session.Session() as sess:
        sess.run(x)

    tree = access.resolve(parser.parse_object(test_fn))
    tree = live_values.resolve(tree, {'session': session}, {})
    tree = type_info.resolve(tree, {})

    with_stmt = tree.body[0].body[0]
    expected_fqn = (session.__name__, 'Session')

    ctor_call = with_stmt.items[0].context_expr
    self.assertEquals(session.Session, anno.getanno(ctor_call, 'type'))
    self.assertEquals(expected_fqn, anno.getanno(ctor_call, 'type_fqn'))

    run_attr = with_stmt.body[0].value.func
    self.assertEquals(expected_fqn, anno.getanno(run_attr, 'type_fqn'))
Beispiel #47
0
    def test_class_members_in_with_stmt(self):
        """Type info flows from a `with` constructor to calls on its target."""
        def test_fn(x):
            with session.Session() as sess:
                sess.run(x)

        tree = access.resolve(parser.parse_object(test_fn))
        tree = live_values.resolve(tree, {'session': session}, {})
        tree = type_info.resolve(tree, None)

        with_stmt = tree.body[0].body[0]
        expected_fqn = (session.__name__, 'Session')

        ctor_call = with_stmt.items[0].context_expr
        self.assertEquals(session.Session, anno.getanno(ctor_call, 'type'))
        self.assertEquals(expected_fqn, anno.getanno(ctor_call, 'type_fqn'))

        run_attr = with_stmt.body[0].value.func
        self.assertEquals(expected_fqn, anno.getanno(run_attr, 'type_fqn'))
Beispiel #48
0
  def _should_compile(self, node, fqn):
    """Decides whether the call in `node` should be converted.

    Args:
      node: The call node under consideration; its `.func` is the target.
      fqn: Tuple holding the fully qualified name of the call target.

    Returns:
      False if a proper prefix of `fqn` is in `self.uncompiled_modules`, if
      `node` is annotated 'graph_ready', or if the resolved target is, or is
      decorated by, a no-compile decorator. True otherwise, including when the
      target's source cannot be retrieved.
    """
    # Only proper prefixes of fqn are checked (presumably the packages and
    # modules enclosing the target).
    for i in range(1, len(fqn)):
      if fqn[:i] in self.uncompiled_modules:
        return False

    # Check for local decorations
    if anno.hasanno(node, 'graph_ready'):
      return False

    # The decorators themselves are not to be converted.
    # If present, the decorators should appear as static functions.
    target_obj = self._try_resolve_target(node.func)
    if target_obj is not None:
      # This attribute is set by the decorator itself.
      # TODO(mdan): This may not play nicely with other wrapping decorators.
      if hasattr(target_obj, '__pyct_is_compile_decorator'):
        return False

      if target_obj in self.nocompile_decorators:
        return False

      # Inspect the target function decorators. If any include a @convert
      # or @graph_ready annotation, then they must be called as they are.
      # TODO(mdan): This may be quite heavy.
      # To parse and re-analyze each function for every call site could be quite
      # wasteful. Maybe we could cache the parsed AST?
      try:
        target_node = parser.parse_object(target_obj).body[0]
      except (TypeError, IOError):
        # Functions whose source we cannot access are compilable (e.g. wrapped
        # to py_func). Assuming the source is fetched inspect-style, TypeError
        # covers built-ins and IOError (OSError on Python 3) covers objects
        # whose source file is unavailable.
        return True

      for dec in target_node.decorator_list:
        decorator_fn = self._resolve_name(dec)
        if (decorator_fn is not None and
            decorator_fn in self.nocompile_decorators):
          return False

    return True
Beispiel #49
0
    def test_while(self):
        """Checks the body and parent scopes recorded for a While loop."""
        def test_fn(a):
            b = a
            while b > 0:
                c = b
                b -= 1
            return b, c

        loop = access.resolve(parser.parse_object(test_fn)).body[0].body[1]
        body = anno.getanno(loop, 'body_scope')
        parent = anno.getanno(loop, 'parent_scope')

        self.assertItemsEqual(['b'], body.used)
        self.assertItemsEqual(['b', 'c'], body.modified)
        self.assertItemsEqual(['c'], body.created)

        self.assertItemsEqual(['a', 'b', 'c'], parent.used)
        self.assertItemsEqual(['a', 'b', 'c'], parent.modified)
        self.assertItemsEqual(['a', 'b', 'c'], parent.created)
Beispiel #50
0
  def test_while(self):
    """Checks the body and parent scopes recorded for a While loop."""

    def test_fn(a):
      b = a
      while b > 0:
        c = b
        b -= 1
      return b, c

    loop = access.resolve(parser.parse_object(test_fn)).body[0].body[1]
    body = anno.getanno(loop, 'body_scope')
    parent = anno.getanno(loop, 'parent_scope')

    self.assertItemsEqual(['b'], body.used)
    self.assertItemsEqual(['b', 'c'], body.modified)
    self.assertItemsEqual(['c'], body.created)

    self.assertItemsEqual(['a', 'b', 'c'], parent.used)
    self.assertItemsEqual(['a', 'b', 'c'], parent.modified)
    self.assertItemsEqual(['a', 'b', 'c'], parent.created)
 def _parse_and_analyze(self, test_fn, namespace):
   """Runs access, live-value and type analyses (empty hints) on test_fn."""
   tree = access.resolve(parser.parse_object(test_fn))
   tree = live_values.resolve(tree, namespace, {})
   return type_info.resolve(tree, {})
 def _parse_and_analyze(self, test_fn, namespace):
   """Parses test_fn and runs only the access analysis on it."""
   # `namespace` is accepted but unused: no live-value analysis is run here.
   return access.resolve(parser.parse_object(test_fn))
Beispiel #53
0
 def test_parse_object(self):
   """The parsed tree's first statement is the definition of `f`."""
   parsed = parser.parse_object(f)
   first_stmt = parsed.body[0]
   self.assertEqual('f', first_stmt.name)
 def _parse_and_analyze(self, test_fn, namespace):
   """Parses test_fn and runs only the access analysis on it."""
   # `namespace` is accepted but unused: no live-value analysis is run here.
   parsed = parser.parse_object(test_fn)
   return access.resolve(parsed)
Beispiel #55
0
 def _parse_and_analyze(self, test_fn, namespace):
   """Runs access, live-value and type analyses (no hints) on test_fn."""
   tree = access.resolve(parser.parse_object(test_fn))
   tree = live_values.resolve(tree, namespace, {})
   return type_info.resolve(tree, None)