예제 #1
0
  def _wrap_to_py_func_no_return(self, node):
    """Builds a `tf.py_func` wrapper for a call node.

    A wrapper function is generated from a template: it forwards its
    arguments to the original callee and returns the constant 1. The original
    call arguments are handed to the wrapper through `tf.py_func`.

    Args:
      node: the call node to wrap. Its `func` is read for a QN annotation and
        the node itself for an ARGS_SCOPE annotation.

    Returns:
      A tuple (wrapper_def, call_expr) with the two nodes produced by the
      template expansion. The wrapper definition is marked SKIP_PROCESSING.
    """
    callee_qn = anno.getanno(node.func, anno.Basic.QN)
    args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    new_symbol = self.context.namer.new_symbol
    wrapper_name = new_symbol(callee_qn.ssf(), args_scope.referenced)

    def _qn_of(arg):
      # Arguments without a qualified name fall back to the generic 'arg'.
      if anno.hasanno(arg, anno.Basic.QN):
        return anno.getanno(arg, anno.Basic.QN)
      return qual_names.QN('arg')

    wrapper_args = [
        new_symbol(_qn_of(arg).ssf(), args_scope.referenced)
        for arg in node.args
    ]
    # TODO(mdan): Properly handle varargs, kwargs, etc.
    # TODO(mdan): This is best handled as a dynamic dispatch.
    # That way we can separate tensors from non-tensor args.
    template = """
      def wrapper(wrapper_args):
        call(wrapper_args)
        return 1
      tf.py_func(wrapper, original_args, [tf.int64])
    """
    original_args = gast.List(elts=node.args, ctx=None)
    wrapper_def, call_expr = templates.replace(
        template,
        call=node.func,
        wrapper=wrapper_name,
        original_args=original_args,
        wrapper_args=wrapper_args)
    anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)

    return (wrapper_def, call_expr)
예제 #2
0
  def _rename_member_function_of_known_type(self, node):
    """Renames a method call made on an object whose type is annotated.

    Args:
      node: a call node whose `func` is a `gast.Attribute` carrying 'type'
        and 'type_fqn' annotations.

    Returns:
      The same node. If the namer produced a different name, the call
      `foo.bar(...)` has been rewritten in place to `new_name(foo, ...)`.
    """
    callee = node.func
    assert isinstance(callee, gast.Attribute)

    type_fqn = anno.getanno(callee, 'type_fqn')
    assert anno.hasanno(callee, 'type')
    target_type = anno.getanno(callee, 'type')

    if not self._should_compile(node, type_fqn):
      return node

    # TODO(mdan): We should not assume that the namer only needs the
    # member function name.
    method_name = callee.attr
    live_method = getattr(target_type, method_name)
    new_name = self.namer.compiled_function_name(
        method_name, live_object=live_method, owner_type=target_type)
    if new_name == callee.attr:
      return node

    # A renamed member function is no longer bound to its object, so the
    # object is passed explicitly as the first argument.
    # TODO(mdan): This risks causing duplication, if target_type is renamed.
    node.args = [callee.value] + node.args
    node.func = gast.Name(new_name, gast.Load(), None)
    return node
예제 #3
0
    def test_if(self):
        def test_fn(x):
            if x > 0:
                x = -x
                y = 2 * x
                z = -y
            else:
                x = 2 * x
                y = -x
                u = -y
            return z, u

        node = self._parse_and_analyze(test_fn)
        if_node = node.body[0].body[0]

        # Expected symbol tuples shared by both enclosing-scope checks.
        enclosing_first = ('x', 'z', 'u')
        all_symbols = ('x', 'y', 'z', 'u')

        # Scope of the `if` body, then the scope that encloses it.
        body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
        self.assertScopeIs(body_scope, ('x', 'y'), ('x', 'y', 'z'),
                           ('y', 'z'))
        # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
        self.assertScopeIs(body_scope.parent, enclosing_first, all_symbols,
                           all_symbols)

        # Scope of the `else` body, then the scope that encloses it.
        orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
        self.assertScopeIs(orelse_scope, ('x', 'y'), ('x', 'y', 'u'),
                           ('y', 'u'))
        self.assertScopeIs(orelse_scope.parent, enclosing_first, all_symbols,
                           all_symbols)
예제 #4
0
  def test_if(self):

    def test_fn(x):
      if x > 0:
        x = -x
        y = 2 * x
        z = -y
      else:
        x = 2 * x
        y = -x
        u = -y
      return z, u

    node = access.resolve(parser.parse_object(test_fn))
    if_node = node.body[0].body[0]

    parent_expectation = (('x', 'z', 'u'), ('x', 'y', 'z', 'u'),
                          ('x', 'y', 'z', 'u'))
    # Each entry pairs an annotation key on the `if` node with the three
    # symbol tuples handed to assertScopeIs, in checking order.
    # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
    # NOTE(review): the last entry checks 'body_parent_scope' a second time;
    # an annotation for the orelse branch may have been intended -- confirm.
    expectations = (
        ('body_scope', (('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))),
        ('body_parent_scope', parent_expectation),
        ('orelse_scope', (('x', 'y'), ('x', 'y', 'u'), ('y', 'u'))),
        ('body_parent_scope', parent_expectation),
    )
    for anno_key, expected in expectations:
      self.assertScopeIs(anno.getanno(if_node, anno_key), *expected)
예제 #5
0
    def test_if(self):
        def test_fn(x):
            if x > 0:
                x = -x
                y = 2 * x
                z = -y
            else:
                x = 2 * x
                y = -x
                u = -y
            return z, u

        resolved = access.resolve(parser.parse_object(test_fn))
        if_node = resolved.body[0].body[0]

        def check_scope(anno_key, *expected):
            # Compares the scope stored under `anno_key` on the `if` node
            # against the expected symbol tuples.
            self.assertScopeIs(anno.getanno(if_node, anno_key), *expected)

        every_symbol = ('x', 'y', 'z', 'u')

        check_scope('body_scope', ('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))
        # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
        check_scope('body_parent_scope', ('x', 'z', 'u'), every_symbol,
                    every_symbol)
        check_scope('orelse_scope', ('x', 'y'), ('x', 'y', 'u'), ('y', 'u'))
        # NOTE(review): this re-checks 'body_parent_scope'; an annotation for
        # the orelse branch may have been intended -- confirm.
        check_scope('body_parent_scope', ('x', 'z', 'u'), every_symbol,
                    every_symbol)
예제 #6
0
  def visit_Call(self, node):
    """Processes a call node, renaming its target where possible.

    Before the children are visited, a call whose resolved target is one of
    `self.nocompile_decorators` gets its first argument annotated as
    'graph_ready'. After the children are visited, the call is renamed
    according to the annotations found on `node.func`.

    Args:
      node: the call node being visited.

    Returns:
      The node, possibly rewritten by one of the renaming helpers.

    Raises:
      ValueError: if a marker decorator is called without arguments.
      NotImplementedError: if the resolved target is not compilable, or if
        the target carries neither a 'live_val' nor a 'type_fqn' annotation.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_obj = anno.getanno(node.func, 'live_val')
      if target_obj in self.nocompile_decorators:
        if not node.args:
          # The message placeholder was previously never filled in.
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least an argument.' % (target_obj,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_obj = anno.getanno(node.func, 'live_val')
      if self._function_is_compilable(target_obj):
        node = self._rename_compilable_function(node)
      else:
        raise NotImplementedError('py_func with return values')
    elif anno.hasanno(node.func, 'type_fqn'):
      node = self._rename_member_function_of_known_type(node)
    else:
      # The callee is not necessarily a Name: attribute nodes have `attr`
      # rather than `id`. Fall back gracefully so that the intended
      # NotImplementedError is raised instead of an AttributeError.
      func_name = getattr(node.func, 'id', None)
      if func_name is None:
        func_name = getattr(node.func, 'attr', node.func)
      raise NotImplementedError(
          'Member function call (of unknown type): %s.' % (func_name,))
    return node
예제 #7
0
    def test_print_statement(self):
        def test_fn(a):
            b = 0
            c = 1
            print(a, b)
            return c

        node = access.resolve(parser.parse_object(test_fn))

        print_node = node.body[0].body[2]
        if not isinstance(print_node, gast.Print):
            # Python 3: print is an ordinary call wrapped in an expression
            # statement, and the call node is the one being annotated.
            assert isinstance(print_node, gast.Expr)
            print_node = print_node.value
        # In Python 2 the print statement node itself carries the annotation.
        print_args_scope = anno.getanno(print_node, 'args_scope')

        # Only the variables captured by the call arguments should show up,
        # and none of them is modified or created.
        self.assertItemsEqual(['a', 'b'], print_args_scope.used)
        self.assertItemsEqual([], print_args_scope.modified)
        self.assertItemsEqual([], print_args_scope.created)
예제 #8
0
 def visit_Call(self, node):
   """Annotates unresolved method calls with the type of their object.

   When the callee has no 'live_val' annotation, it must be an attribute of
   a plain name (e.g. `opt.foo()`). The value recorded in `self.scope` for
   that name is looked up and its 'type' and 'type_fqn' annotations are
   copied onto the callee.

   Args:
     node: the call node being visited.

   Returns:
     The same node, after its children have been visited.

   Raises:
     ValueError: if the callee is not an attribute of a name, if nothing is
       recorded for the name, or if the recorded value has no 'type'.
   """
   callee = node.func
   if not anno.hasanno(callee, 'live_val'):
     if not isinstance(callee, gast.Attribute):
       # Suspecting this pattern would reach here:
       #   foo = bar
       #   foo()
       raise ValueError('Dont know how to handle dynamic functions.')
     owner = callee.value
     if not isinstance(owner, gast.Name):
       # Possible example of this kind:
       #   foo = module.Foo()
       #   foo.bar.baz()
       # TODO(mdan): This should be doable by using the FQN.
       raise ValueError('Dont know how to handle object properties yet.')
     if not self.scope.hasval(owner.id):
       # TODO(mdan): Figure out what could the user do to get past this.
       raise ValueError('No info on "%s". Is it dynamically built?' %
                        (owner.id))
     # In the example below, object_source is 'tf.train.Optimizer()':
     #   opt = tf.train.Optimizer()
     #   opt.foo()
     object_source = self.scope.getval(owner.id)
     if not anno.hasanno(object_source, 'type'):
       raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                        (owner.id))
     for key in ('type', 'type_fqn'):
       anno.setanno(callee, key, anno.getanno(object_source, key))
   self.generic_visit(node)
   return node
예제 #9
0
    def visit_Assign(self, node):
        """Records assigned values in the scope, tagging constructor calls.

        After visiting the children, a right-hand side that is a call to a
        resolved class is annotated with 'type' and 'type_fqn'. Every target
        is then registered in `self.scope`: plain targets map to the assigned
        value, tuple elements map to a subscript of it.

        Args:
            node: the assignment node being visited.

        Returns:
            The same node.
        """
        self.generic_visit(node)

        assigned = node.value
        if isinstance(assigned, gast.Call):
            callee = assigned.func
            if anno.hasanno(callee, 'live_val'):
                live_callee = anno.getanno(callee, 'live_val')
                if tf_inspect.isclass(live_callee):
                    # Calling a class means this is a constructor call.
                    anno.setanno(assigned, 'type', live_callee)
                    callee_fqn = anno.getanno(callee, 'fqn')
                    anno.setanno(assigned, 'type_fqn', callee_fqn)
                    # TODO (mdan): Raise an error if constructor has side effects. id:2153 gh:2154
                    # We can have a whitelist of no-side-effects constructors.
                    # We can also step inside the constructor and further analyze.

        for target in node.targets:
            if not isinstance(target, gast.Tuple):
                self.scope.setval(target.id, node.value)
                continue
            for position, element in enumerate(target.elts):
                element_value = gast.Subscript(node.value,
                                               gast.Index(position),
                                               ctx=gast.Store())
                self.scope.setval(element.id, element_value)

        return node
예제 #10
0
    def _process_variable_assignment(self, source, targets):
        """Registers the value assigned to each target in the scope.

        A source that is a call to a resolved class is first annotated with
        'is_constructor', 'type' and 'type_fqn'.

        Args:
            source: the node of the value being assigned.
            targets: the target nodes receiving the value.

        Raises:
            ValueError: if a target is not a tuple, name or attribute.
        """
        if isinstance(source, gast.Call):
            callee = source.func
            if anno.hasanno(callee, 'live_val'):
                live_callee = anno.getanno(callee, 'live_val')
                if tf_inspect.isclass(live_callee):
                    anno.setanno(source, 'is_constructor', True)
                    anno.setanno(source, 'type', live_callee)
                    callee_fqn = anno.getanno(callee, 'fqn')
                    anno.setanno(source, 'type_fqn', callee_fqn)
                    # TODO(mdan): Raise an error if constructor has side effects.
                    # We can have a whitelist of no-side-effects constructors.
                    # We can also step inside the constructor and further analyze.

        for target in targets:
            if isinstance(target, gast.Tuple):
                # Each tuple element maps to a subscript of the source.
                for position, element in enumerate(target.elts):
                    element_qn = anno.getanno(element, anno.Basic.QN)
                    element_source = gast.Subscript(source,
                                                    gast.Index(position),
                                                    ctx=gast.Store())
                    self.scope.setval(element_qn, element_source)
                continue
            if not isinstance(target, (gast.Name, gast.Attribute)):
                raise ValueError('Dont know how to handle assignment to %s' %
                                 target)
            self.scope.setval(anno.getanno(target, anno.Basic.QN), source)
예제 #11
0
  def test_print_statement(self):

    def test_fn(a):
      b = 0
      c = 1
      print(a, b)
      return c

    parsed = parser.parse_object(test_fn)
    resolved = access.resolve(parsed)

    annotated_node = resolved.body[0].body[2]
    if not isinstance(annotated_node, gast.Print):
      # Python 3: the statement is an expression wrapping the call, and the
      # call node should be the one being annotated.
      assert isinstance(annotated_node, gast.Expr)
      annotated_node = annotated_node.value
    # Python 2: the print statement node itself carries the annotation.
    print_args_scope = anno.getanno(annotated_node, 'args_scope')

    # We basically need to detect which variables are captured by the call
    # arguments.
    self.assertItemsEqual(['a', 'b'], print_args_scope.used)
    self.assertItemsEqual([], print_args_scope.modified)
    self.assertItemsEqual([], print_args_scope.created)
예제 #12
0
 def visit_Call(self, node):
   """Copies type annotations onto callees that have no live value.

   A callee without a 'live_val' annotation is required to have the shape
   `name.attr`. The value stored in `self.scope` under `name` supplies the
   'type' and 'type_fqn' annotations that are set on the callee.

   Args:
     node: the call node being visited.

   Returns:
     The same node, after its children have been visited.

   Raises:
     ValueError: if the callee does not have the expected shape, if the
       scope holds nothing for the name, or if the stored value has no
       'type' annotation.
   """
   func = node.func
   if not anno.hasanno(func, 'live_val'):
     if not isinstance(func, gast.Attribute):
       # Suspecting this pattern would reach here:
       #   foo = bar
       #   foo()
       raise ValueError('Dont know how to handle dynamic functions.')
     if not isinstance(func.value, gast.Name):
       # Possible example of this kind:
       #   foo = module.Foo()
       #   foo.bar.baz()
       # TODO(mdan): This should be doable by using the FQN.
       raise ValueError('Dont know how to handle object properties yet.')

     owner_name = func.value.id
     if not self.scope.hasval(owner_name):
       # TODO(mdan): Figure out what could the user do to get past this.
       raise ValueError('No info on "%s". Is it dynamically built?' %
                        (owner_name))

     # In the example below, definition is 'tf.train.Optimizer()':
     #   opt = tf.train.Optimizer()
     #   opt.foo()
     definition = self.scope.getval(owner_name)
     if not anno.hasanno(definition, 'type'):
       raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                        (owner_name))
     owner_type = anno.getanno(definition, 'type')
     anno.setanno(func, 'type', owner_type)
     owner_type_fqn = anno.getanno(definition, 'type_fqn')
     anno.setanno(func, 'type_fqn', owner_type_fqn)
   self.generic_visit(node)
   return node
예제 #13
0
    def _rename_compilable_function(self, node):
        """Points a call at the compiled version of its target.

        Args:
            node: a call node whose `func` carries 'live_val' and 'fqn'
                annotations.

        Returns:
            The same node. When a rename applies, `node.func` has been
            replaced by a plain name; for bound methods the original object
            is prepended to the arguments.
        """
        callee = node.func
        assert anno.hasanno(callee, 'live_val')
        assert anno.hasanno(callee, 'fqn')
        target_entity = anno.getanno(callee, 'live_val')
        target_fqn = anno.getanno(callee, 'fqn')

        if not self._should_compile(node, target_fqn):
            return node

        if anno.hasanno(node, 'is_constructor'):
            # Constructor calls are always renamed to the compiled class.
            do_rename = True
            new_name = self.context.namer.compiled_class_name(
                target_fqn, live_entity=target_entity)
        else:
            owner_type = self._determine_function_owner(target_entity)
            new_name, do_rename = self.context.namer.compiled_function_name(
                target_fqn, live_entity=target_entity, owner_type=owner_type)

        if not do_rename:
            return node

        if target_entity is not None and tf_inspect.ismethod(target_entity):
            # The renaming process will transform it into a regular function.
            # TODO(mdan): Is this complete? How does it work with nested members?
            node.args = [callee.value] + node.args
        node.func = gast.Name(new_name, gast.Load(), None)
        return node
예제 #14
0
  def visit_Name(self, node):
    """Attaches live values to names that are being read.

    Names that are neither local nor parameters are resolved against
    `self.literals` and then `self.context.namespace`. Names not modified
    since entry are additionally resolved against `self.context.arg_values`.

    Args:
      node: the name node being visited.

    Returns:
      The same node.

    Raises:
      ValueError: if a non-local, non-parameter name cannot be resolved.
    """
    self.generic_visit(node)
    if not isinstance(node.ctx, gast.Load):
      return node

    # All three annotations are required on every name that is read.
    flags = {}
    for key in ('is_local', 'is_modified_since_entry', 'is_param'):
      assert anno.hasanno(node, key), node
      flags[key] = anno.getanno(node, key)

    name = node.id
    if not (flags['is_local'] or flags['is_param']):
      if name in self.literals:
        anno.setanno(node, 'live_val', self.literals[name])
        # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
      elif name in self.context.namespace:
        live_obj = self.context.namespace[name]
        anno.setanno(node, 'live_val', live_obj)
        anno.setanno(node, 'fqn', (live_obj.__name__,))
      else:
        raise ValueError('Could not resolve symbol "%s".' % name)
    # TODO(mdan): Attempt to trace its value through the local chain.
    # TODO(mdan): Use type annotations as fallback.

    if not flags['is_modified_since_entry']:
      if name in self.context.arg_values:
        arg_obj = self.context.arg_values[name]
        anno.setanno(node, 'live_val', arg_obj)
        anno.setanno(node, 'fqn', (arg_obj.__class__.__name__,))
    return node
예제 #15
0
  def _rename_compilable_function(self, node):
    """Points a call at the compiled version of its target.

    Args:
      node: a call node whose `func` carries 'live_val' and 'fqn'
        annotations.

    Returns:
      The same node. When a rename applies, `node.func` has been replaced by
      a node built from the new name; for bound methods the original object
      is prepended to the arguments.
    """
    callee = node.func
    assert anno.hasanno(callee, 'live_val')
    assert anno.hasanno(callee, 'fqn')
    target_entity = anno.getanno(callee, 'live_val')
    target_fqn = anno.getanno(callee, 'fqn')

    if not self._should_compile(node, target_fqn):
      return node

    if anno.hasanno(node, 'is_constructor'):
      # Constructor calls are always renamed to the compiled class.
      do_rename = True
      new_name = self.context.namer.compiled_class_name(
          target_fqn, live_entity=target_entity)
    else:
      owner_type = self._determine_function_owner(target_entity)
      new_name, do_rename = self.context.namer.compiled_function_name(
          target_fqn, live_entity=target_entity, owner_type=owner_type)

    if not do_rename:
      return node

    if target_entity is not None and tf_inspect.ismethod(target_entity):
      # The renaming process will transform it into a regular function.
      # TODO(mdan): Is this complete? How does it work with nested members?
      node.args = [callee.value] + node.args
    replacement = templates.replace('func_name', func_name=new_name)
    node.func = replacement[0]
    return node
예제 #16
0
  def visit_Call(self, node):
    """Processes a call node, renaming or wrapping its target.

    Before the children are visited, a call whose resolved target is one of
    `self.nocompile_decorators` gets its first argument annotated as
    'graph_ready'. After the children are visited, resolved targets are
    either renamed (compilable functions) or wrapped (known numpy
    functions); unresolved targets get a dynamic conversion in recursive
    mode and are left alone otherwise.

    Args:
      node: the call node being visited.

    Returns:
      The node, possibly replaced by the result of one of the helpers.

    Raises:
      ValueError: if a marker decorator is called without arguments.
      NotImplementedError: if the resolved target is neither compilable nor
        a known numpy function.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.nocompile_decorators:
        if not node.args:
          # The message placeholder was previously never filled in.
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least an argument.' % (target_entity,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      # A live value does not guarantee an 'fqn' annotation. Default to None
      # so the lookup below cannot hit an unbound local; None is simply not
      # a known numpy function.
      target_fqn = None
      if anno.hasanno(node.func, 'fqn'):
        target_fqn = anno.getanno(node.func, 'fqn')
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      elif target_fqn in KNOWN_NUMPY_FUNCTIONS:
        # TODO(mdan): Should we replace these with equivalent TF ops instead?
        node = self._wrap_to_py_func_single_return(node, target_fqn)
      else:
        raise NotImplementedError(
            'py_func with return values (unknown function)')
    else:
      if self.context.recursive:
        node = self._insert_dynamic_conversion(node)
      else:
        # Unresolved functions are allowed in non-recursive mode.
        pass
    return node
예제 #17
0
  def _process_variable_assignment(self, source, targets):
    """Registers the value assigned to each target in the scope.

    A source that is a call to a resolved class is first annotated with
    'is_constructor', 'type' and 'type_fqn'. Name targets and elements of
    tuple targets are recorded in `self.scope`; attribute targets are only
    accepted on `self` and are not recorded.

    Args:
      source: the node of the value being assigned.
      targets: the target nodes receiving the value.

    Raises:
      ValueError: for attribute targets on objects other than `self`, and
        for any other unsupported target node.
    """
    if isinstance(source, gast.Call):
      callee = source.func
      if anno.hasanno(callee, 'live_val'):
        live_callee = anno.getanno(callee, 'live_val')
        if tf_inspect.isclass(live_callee):
          anno.setanno(source, 'is_constructor', True)
          anno.setanno(source, 'type', live_callee)
          callee_fqn = anno.getanno(callee, 'fqn')
          anno.setanno(source, 'type_fqn', callee_fqn)
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    for target in targets:
      if isinstance(target, gast.Tuple):
        # Each tuple element maps to a subscript of the source.
        for position, element in enumerate(target.elts):
          element_source = gast.Subscript(
              source, gast.Index(position), ctx=gast.Store())
          self.scope.setval(element.id, element_source)
      elif isinstance(target, gast.Name):
        self.scope.setval(target.id, source)
      elif isinstance(target, gast.Attribute):
        owner = target.value
        assigns_to_self = isinstance(owner, gast.Name) and owner.id == 'self'
        if not assigns_to_self:
          raise ValueError(
              'Dont know how to handle assignment to attributes of objects'
              ' other than "self": [%s].%s' % (owner, target.attr))
      else:
        raise ValueError('Dont know how to handle assignment to %s' % target)
예제 #18
0
  def test_call_with_composite_names(self):

    def foo(*_):
      pass

    def test_fn(a):
      foo(a.b, a.c)
      if a > 0:
        a.b = 2
      else:
        d = 2
        d.e = a.c
        f = d.e + 1
        a.c = f

    node = self._parse_and_analyze(test_fn)
    fn_body = node.body[0].body

    # First statement: the call to foo with attribute arguments.
    call_node = fn_body[0].value
    args_scope = anno.getanno(call_node, NodeAnno.ARGS_SCOPE)
    self.assertScopeIs(args_scope, ('a', 'a.b', 'a.c'), (), ())

    # Second statement: the conditional and its two branches.
    if_node = fn_body[1]
    body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    self.assertScopeIs(body_scope, ('a',), ('a.b',), ())
    orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
    self.assertScopeIs(
        orelse_scope,
        ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f'))
예제 #19
0
  def visit_Name(self, node):
    """Attaches live values to names that are being read.

    Names that are neither local nor parameters are resolved against
    `self.literals` and then `self.context.namespace`; names found in
    neither are left without annotation. Names not modified since entry are
    additionally resolved against `self.context.arg_values`.

    Args:
      node: the name node being visited.

    Returns:
      The same node.
    """
    self.generic_visit(node)
    if not isinstance(node.ctx, gast.Load):
      return node

    assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
    is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
    assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
    is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
    assert anno.hasanno(node, NodeAnno.IS_PARAM), node
    is_param = anno.getanno(node, NodeAnno.IS_PARAM)

    name = node.id
    if not (is_local or is_param):
      if name in self.literals:
        anno.setanno(node, 'live_val', self.literals[name])
        # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
      elif name in self.context.namespace:
        live_obj = self.context.namespace[name]
        anno.setanno(node, 'live_val', live_obj)
        anno.setanno(node, 'fqn', (live_obj.__name__,))
      # TODO(mdan): Should we raise an error when the name is in neither?
      # Can encounter this when:
      #  * a symbol truly lacks reference
      #  * a symbol is new, like the new name of a function we just renamed.
    # TODO(mdan): Attempt to trace its value through the local chain.
    # TODO(mdan): Use type annotations as fallback.

    if not is_modified and name in self.context.arg_values:
      arg_obj = self.context.arg_values[name]
      anno.setanno(node, 'live_val', arg_obj)
      anno.setanno(node, 'fqn', (arg_obj.__class__.__name__,))
    return node
예제 #20
0
 def visit_Attribute(self, node):
   """Resolves an attribute access against its annotated parent.

   If the parent has a live value, the attribute is read from it. Otherwise,
   if the parent has a 'type' annotation and the type has the attribute, it
   is read from the type. On success the node receives 'parent_type',
   'live_val' and 'fqn' annotations.

   Args:
     node: the attribute node being visited.

   Returns:
     The same node.

   Raises:
     AttributeError: if the parent's live value lacks the attribute.
   """
   self.generic_visit(node)
   parent = node.value
   attr_name = node.attr

   if anno.hasanno(parent, 'live_val'):
     assert anno.hasanno(parent, 'fqn')
     parent_object = anno.getanno(parent, 'live_val')
     if not hasattr(parent_object, attr_name):
       raise AttributeError('%s has no attribute %s' % (parent_object,
                                                        attr_name))
     anno.setanno(node, 'parent_type', type(parent_object))
     anno.setanno(node, 'live_val', getattr(parent_object, attr_name))
     anno.setanno(node, 'fqn', anno.getanno(parent, 'fqn') + (attr_name,))
     return node

   # TODO(mdan): Investigate the role built-in annotations can play here.
   if anno.hasanno(parent, 'type'):
     parent_type = anno.getanno(parent, 'type')
     if hasattr(parent_type, attr_name):
       # This should hold for static members like methods.
       # This would not hold for dynamic members like function attributes.
       # For the dynamic case, we simply leave the node without an annotation,
       # and let downstream consumers figure out what to do.
       anno.setanno(node, 'parent_type', parent_type)
       anno.setanno(node, 'live_val', getattr(parent_type, attr_name))
       anno.setanno(node, 'fqn',
                    anno.getanno(parent, 'type_fqn') + (attr_name,))
     return node

   if isinstance(parent, gast.Name):
     # All nonlocal symbols should be fully resolved.
     assert anno.hasanno(parent, NodeAnno.IS_LOCAL), parent
     # TODO(mdan): Figure out what to do when calling attribute on local object
     # Maybe just leave as-is?
   return node
예제 #21
0
    def _rename_compilable_function(self, node):
        """Points a call at the compiled version of its target.

        Args:
            node: a call node whose `func` carries 'live_val' and 'fqn'
                annotations.

        Returns:
            The same node. When a rename applies, `node.func` has been
            replaced by a node built from the new name; for bound methods
            the original object is prepended to the arguments.
        """
        callee = node.func
        assert anno.hasanno(callee, 'live_val')
        assert anno.hasanno(callee, 'fqn')
        target_entity = anno.getanno(callee, 'live_val')
        target_fqn = anno.getanno(callee, 'fqn')

        if not self._should_compile(node, target_fqn):
            return node

        if anno.hasanno(node, 'is_constructor'):
            # Constructor calls are always renamed to the compiled class.
            do_rename = True
            new_name = self.context.namer.compiled_class_name(
                target_fqn, live_entity=target_entity)
        else:
            # Prefer the annotated parent type; the inspection-based lookup
            # is only a fallback and is not reliable.
            owner_type = (
                anno.getanno(callee, 'parent_type')
                if anno.hasanno(callee, 'parent_type')
                else inspect_utils.getmethodclass(target_entity))
            new_name, do_rename = self.context.namer.compiled_function_name(
                target_fqn, live_entity=target_entity, owner_type=owner_type)

        if not do_rename:
            return node

        if target_entity is not None and tf_inspect.ismethod(target_entity):
            # The renaming process will transform it into a regular function.
            # TODO(mdan): Is this complete? How does it work with nested members?
            node.args = [callee.value] + node.args
        replacement = templates.replace('func_name', func_name=new_name)
        node.func = replacement[0]
        return node
예제 #22
0
    def visit_Call(self, node):
        """Processes a call node, renaming or wrapping its target.

        Before the children are visited, a call whose resolved target is one
        of `self.nocompile_decorators` gets its first argument annotated as
        'graph_ready'. After the children are visited, resolved targets are
        either renamed (compilable functions) or wrapped (known numpy
        functions); unresolved targets get a dynamic conversion in recursive
        mode and are left alone otherwise.

        Args:
            node: the call node being visited.

        Returns:
            The node, possibly replaced by the result of one of the helpers.

        Raises:
            ValueError: if a marker decorator is called without arguments.
            NotImplementedError: if the resolved target is neither
                compilable nor a known numpy function.
        """
        # If the function is wrapped by one of the marker decorators,
        # consider it graph ready.
        if anno.hasanno(node.func, 'live_val'):
            target_entity = anno.getanno(node.func, 'live_val')
            if target_entity in self.nocompile_decorators:
                if not node.args:
                    # The message placeholder was previously never filled in.
                    raise ValueError(
                        'Found call to decorator function "%s", but it had no arguments. '
                        'A decorator needs at least an argument.' %
                        (target_entity,))
                anno.setanno(node.args[0], 'graph_ready', True)

        self.generic_visit(node)
        if anno.hasanno(node.func, 'live_val'):
            target_entity = anno.getanno(node.func, 'live_val')
            # A live value does not guarantee an 'fqn' annotation. Default to
            # None so the lookup below cannot hit an unbound local; None is
            # simply not a known numpy function.
            target_fqn = None
            if anno.hasanno(node.func, 'fqn'):
                target_fqn = anno.getanno(node.func, 'fqn')
            if self._function_is_compilable(target_entity):
                node = self._rename_compilable_function(node)
            elif target_fqn in KNOWN_NUMPY_FUNCTIONS:
                # TODO(mdan): Should we replace these with equivalent TF ops instead?
                node = self._wrap_to_py_func_single_return(node, target_fqn)
            else:
                raise NotImplementedError(
                    'py_func with return values (unknown function)')
        else:
            if self.context.recursive:
                node = self._insert_dynamic_conversion(node)
            else:
                # Unresolved functions are allowed in non-recursive mode.
                pass
        return node
예제 #23
0
    def visit_Call(self, node):
        """Processes a call node, renaming compilable targets.

        Before the children are visited, a call whose resolved target is one
        of `self.nocompile_decorators` gets its first argument annotated as
        'graph_ready'. After the children are visited, a resolved compilable
        target is renamed.

        Args:
            node: the call node being visited.

        Returns:
            The node, possibly rewritten by the renaming helper.

        Raises:
            ValueError: if a marker decorator is called without arguments.
            NotImplementedError: if the resolved target is not compilable,
                or if the target is unresolved in recursive mode.
        """
        # If the function is wrapped by one of the marker decorators,
        # consider it graph ready.
        if anno.hasanno(node.func, 'live_val'):
            target_entity = anno.getanno(node.func, 'live_val')
            if target_entity in self.nocompile_decorators:
                if not node.args:
                    # The message placeholder was previously never filled in.
                    raise ValueError(
                        'Found call to decorator function "%s", but it had no arguments. '
                        'A decorator needs at least an argument.' %
                        (target_entity,))
                anno.setanno(node.args[0], 'graph_ready', True)

        self.generic_visit(node)
        if anno.hasanno(node.func, 'live_val'):
            target_entity = anno.getanno(node.func, 'live_val')
            if self._function_is_compilable(target_entity):
                node = self._rename_compilable_function(node)
            else:
                raise NotImplementedError('py_func with return values')
        else:
            if self.context.recursive:
                raise NotImplementedError('Could not resolve target function.')
            else:
                # TODO(mdan): Double check. Is this reachable code?
                pass
        return node
예제 #24
0
  def test_if(self):

    def test_fn(x):
      if x > 0:
        x = -x
        y = 2 * x
        z = -y
      else:
        x = 2 * x
        y = -x
        u = -y
      return z, u

    node = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]

    # Expected symbol tuples shared by both enclosing-scope checks.
    enclosing_first = ('x', 'z', 'u')
    all_symbols = ('x', 'y', 'z', 'u')

    # Scope of the `if` body, then the scope that encloses it.
    body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    self.assertScopeIsRmc(body_scope, ('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))
    # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
    self.assertScopeIsRmc(body_scope.parent, enclosing_first, all_symbols,
                          all_symbols)

    # Scope of the `else` body, then the scope that encloses it.
    orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
    self.assertScopeIsRmc(orelse_scope, ('x', 'y'), ('x', 'y', 'u'),
                          ('y', 'u'))
    self.assertScopeIsRmc(orelse_scope.parent, enclosing_first, all_symbols,
                          all_symbols)
예제 #25
0
 def visit_Call(self, node):
   """Applies the element type given by a type-annotation marker call.

   When the callee's live value is `self.context.type_annotation_func`, the
   call must have two arguments: a symbol and a statically resolvable type.
   The type is stored as 'element_type' on both the call node and the
   symbol's definition found in `self.scope`.

   Args:
     node: the call node being visited.

   Returns:
     The result of `self.generic_visit(node)`.

   Raises:
     ValueError: if the marker call does not have exactly two arguments, if
       the first has no qualified name, or if the second has no live value.
   """
   is_type_marker = (
       anno.hasanno(node.func, 'live_val') and
       anno.getanno(node.func, 'live_val') is
       self.context.type_annotation_func)
   if is_type_marker:
     # Symbols targeted by the "set_type" marker function are assigned the data
     # type that it specified.
     # Expecting the actual type to be the second argument.
     if len(node.args) != 2:
       raise ValueError('"%s" must have exactly two parameters'
                        % self.context.type_annotation_func)
     symbol_arg, type_arg = node.args
     if not anno.hasanno(symbol_arg, anno.Basic.QN):
       raise ValueError('the first argument of "%s" must by a symbol'
                        % self.context.type_annotation_func)
     if not anno.hasanno(type_arg, 'live_val'):
       raise ValueError(
           'the second argument of "%s" must be statically resolvable' %
           self.context.type_annotation_func)
     target_symbol = anno.getanno(symbol_arg, anno.Basic.QN)
     element_type = anno.getanno(type_arg, 'live_val')
     # Find the definition of this symbol and annotate it with the given
     # data type. That in turn will cause future uses of the symbol
     # to receive the same type annotation.
     definition = self.scope.getval(target_symbol)
     for annotated in (node, definition):
       anno.setanno(annotated, 'element_type', element_type)
     # TODO(mdan): Should we update references between definition and here?
   return self.generic_visit(node)
예제 #26
0
  def test_function_def(self):

    def test_fn(a):

      def f(x):
        y = x * x
        return y

      b = a
      for i in a:
        c = b
        b -= f(i)
      return b, c

    node = self._parse_and_analyze(test_fn)
    fndef_node = node.body[0].body[0]
    inner_scope = anno.getanno(fndef_node, NodeAnno.BODY_SCOPE)

    # The scope enclosing the nested function definition is checked first,
    # then the nested function's own body scope.
    self.assertScopeIsRmc(
        inner_scope.parent,
        ('b', 'i', 'f', 'c', 'a'),
        ('f', 'b', 'c', 'i'),
        ('f', 'a', 'b', 'c', 'i'))
    self.assertScopeIsRmc(inner_scope, ('x', 'y'), ('y',), ('x', 'y'))
예제 #27
0
    def _wrap_to_py_func_no_return(self, node):
        """Builds a `tf.py_func` wrapper for a call node.

        A wrapper function is generated from a template: it forwards its
        arguments to the original callee and returns the constant 1. The
        original call arguments are handed to the wrapper through
        `tf.py_func`.

        Args:
            node: the call node to wrap. Its `func` is read for a QN
                annotation and the node itself for an ARGS_SCOPE annotation.

        Returns:
            A tuple (wrapper_def, call_expr) with the two nodes produced by
            the template expansion. The wrapper definition is marked
            SKIP_PROCESSING.
        """
        callee_qn = anno.getanno(node.func, anno.Basic.QN)
        args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
        wrapper_name = self.context.namer.new_symbol(callee_qn.ssf(),
                                                     args_scope.referenced)

        wrapper_args = []
        for call_arg in node.args:
            # Arguments without a qualified name fall back to the generic
            # 'arg'.
            has_qn = anno.hasanno(call_arg, anno.Basic.QN)
            arg_qn = (anno.getanno(call_arg, anno.Basic.QN)
                      if has_qn else qual_names.QN('arg'))
            fresh_name = self.context.namer.new_symbol(arg_qn.ssf(),
                                                       args_scope.referenced)
            wrapper_args.append(fresh_name)
        # TODO(mdan): Properly handle varargs, kwargs, etc.
        # TODO(mdan): This is best handled as a dynamic dispatch.
        # That way we can separate tensors from non-tensor args.
        template = """
      def wrapper(wrapper_args):
        call(wrapper_args)
        return 1
      tf.py_func(wrapper, original_args, [tf.int64])
    """
        original_args = gast.List(elts=node.args, ctx=None)
        wrapper_def, call_expr = templates.replace(
            template,
            call=node.func,
            wrapper=wrapper_name,
            original_args=original_args,
            wrapper_args=wrapper_args)
        anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)

        return (wrapper_def, call_expr)
예제 #28
0
  def test_call_with_composite_names(self):
    """Checks scopes involving dotted names such as `a.b` and `d.e`."""

    def foo(*_):
      pass

    def test_fn(a):
      foo(a.b, a.c)
      if a > 0:
        a.b = 2
      else:
        d = 2
        d.e = a.c
        f = d.e + 1
        a.c = f

    tree = self._parse_and_analyze(test_fn)
    fn_body = tree.body[0].body

    # Argument scope of the `foo(a.b, a.c)` call.
    call_node = fn_body[0].value
    args_scope = anno.getanno(call_node, NodeAnno.ARGS_SCOPE)
    self.assertScopeIsRmc(args_scope, ('a', 'a.b', 'a.c'), (), ())

    # Scopes of the two branches of the `if` statement.
    if_node = fn_body[1]
    then_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    self.assertScopeIsRmc(then_scope, ('a',), ('a.b',), ())

    else_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
    self.assertScopeIsRmc(
        else_scope,
        ('a', 'a.c', 'd', 'd.e', 'f'),
        ('a.c', 'd', 'd.e', 'f'),
        ('d', 'f'))
예제 #29
0
  def visit_Call(self, node):
    """Marks decorator-wrapped arguments and renames compilable call targets.

    Args:
      node: the call node being visited.

    Returns:
      The call node, with its target renamed when the target is compilable.

    Raises:
      ValueError: if a no-compile decorator is called without arguments.
      NotImplementedError: if the resolved target is not compilable, or if the
        target cannot be resolved while `self.context.recursive` is set.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.nocompile_decorators:
        if not node.args:
          # The tuple wrapper keeps the formatting correct even if the entity
          # itself happens to be a tuple.
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least an argument.' % (target_entity,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    # Re-read the annotation after visiting the children.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      else:
        raise NotImplementedError('py_func with return values')
    elif self.context.recursive:
      raise NotImplementedError('Could not resolve target function.')
    # TODO(mdan): Double check. Is the unresolved, non-recursive case reachable?
    return node
예제 #30
0
 def visit_Call(self, node):
     """Applies the element type given by a call to the type annotation marker.

     When the call target's live value is `self.context.type_annotation_func`,
     the type given as second argument is recorded as 'element_type' on both
     the call node and the tracked definition of the first argument's symbol.

     Args:
       node: the call node being visited.

     Returns:
       The result of `self.generic_visit(node)`.

     Raises:
       ValueError: if the marker call does not have exactly two arguments, if
         the first argument is not a symbol, or if the second argument has no
         'live_val' annotation.
     """
     if anno.hasanno(node.func, 'live_val'):
         # Symbols targeted by the "set_type" marker function are assigned the data
         # type that it specified.
         if (anno.getanno(node.func, 'live_val') is
                 self.context.type_annotation_func):
             # Expecting the actual type to be the second argument.
             if len(node.args) != 2:
                 raise ValueError('"%s" must have exactly two parameters' %
                                  self.context.type_annotation_func)
             if not anno.hasanno(node.args[0], anno.Basic.QN):
                 raise ValueError(
                     'the first argument of "%s" must be a symbol' %
                     self.context.type_annotation_func)
             if not anno.hasanno(node.args[1], 'live_val'):
                 raise ValueError(
                     'the second argument of "%s" must be statically resolvable'
                     % self.context.type_annotation_func)
             target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
             element_type = anno.getanno(node.args[1], 'live_val')
             # Find the definition of this symbol and annotate it with the given
             # data type. That in turn will cause future uses of the symbol
             # to receive the same type annotation.
             definition = self.scope.getval(target_symbol)
             anno.setanno(node, 'element_type', element_type)
             anno.setanno(definition, 'element_type', element_type)
             # TODO(mdan): Should we update references between definition and here?
     return self.generic_visit(node)
예제 #31
0
 def visit_Attribute(self, node):
   """Annotates an attribute node with its live value and qualified name.

   The annotations are derived from the attribute's owner: from the owner's
   'live_val' when present, otherwise from the owner's 'type'.

   Args:
     node: the attribute node being visited.

   Returns:
     The same node, possibly carrying new 'live_val' and 'fqn' annotations.

   Raises:
     AttributeError: if the owner's live value lacks the attribute.
   """
   self.generic_visit(node)
   owner = node.value
   attr_name = node.attr

   if anno.hasanno(owner, 'live_val'):
     assert anno.hasanno(owner, 'fqn')
     parent_object = anno.getanno(owner, 'live_val')
     if not hasattr(parent_object, attr_name):
       raise AttributeError(
           '%s has no attribute %s' % (parent_object, attr_name))
     anno.setanno(node, 'live_val', getattr(parent_object, attr_name))
     anno.setanno(node, 'fqn', anno.getanno(owner, 'fqn') + (attr_name,))
     return node

   # TODO(mdan): Investigate the role built-in annotations can play here.
   if anno.hasanno(owner, 'type'):
     parent_type = anno.getanno(owner, 'type')
     # This should hold for static members like methods.
     # This would not hold for dynamic members like function attributes.
     # For the dynamic case, we simply leave the node without an annotation,
     # and let downstream consumers figure out what to do.
     if hasattr(parent_type, attr_name):
       anno.setanno(node, 'live_val', getattr(parent_type, attr_name))
       anno.setanno(node, 'fqn',
                    anno.getanno(owner, 'type_fqn') + (attr_name,))
     return node

   if isinstance(owner, gast.Name):
     # All nonlocal symbols should be fully resolved.
     assert anno.hasanno(owner, 'is_local'), owner
     # TODO(mdan): Figure out what to do when calling attribute on local object
     # Maybe just leave as-is?
   return node
예제 #32
0
 def _try_resolve_target(self, node):
   """Resolves a call target; works for methods of objects of known type.

   Returns the node's 'live_val' annotation when present. Otherwise, for a
   `gast.Attribute` carrying a 'type' annotation, returns the attribute
   looked up on that type. Returns None when neither applies.
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')

   typed_attribute = (
       isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'))
   if not typed_attribute:
     return None

   owner_type = anno.getanno(node, 'type')
   return getattr(owner_type, node.attr)
예제 #33
0
  def test_basic(self):
    """Exercises hasanno/getanno/setanno on a fresh AST node."""
    key = 'foo'
    target = ast.Name()

    # Before the annotation is set: absent, and reading it raises.
    self.assertFalse(anno.hasanno(target, key))
    with self.assertRaises(AttributeError):
      anno.getanno(target, key)

    # After it is set: present, and the stored value is returned.
    anno.setanno(target, key, 3)
    self.assertTrue(anno.hasanno(target, key))
    self.assertEqual(3, anno.getanno(target, key))
예제 #34
0
    def test_attribute_names(self):
        """Checks the 'live_val' and 'fqn' annotations of an attribute callee."""
        def test_fn():
            return constant_op.constant(0)

        node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
        # Callee of the call in `return constant_op.constant(0)`.
        func_node = node.body[0].body[0].value.func
        # assertEqual is used because the assertEquals alias is deprecated and
        # no longer exists in Python 3.12.
        self.assertEqual(constant_op.constant,
                         anno.getanno(func_node, 'live_val'))
        self.assertEqual((constant_op.__name__, 'constant'),
                         anno.getanno(func_node, 'fqn'))
예제 #35
0
  def test_attribute_names(self):
    """Checks the 'live_val' and 'fqn' annotations of an attribute callee."""

    def test_fn():
      return constant_op.constant(0)

    node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
    # Callee of the call in `return constant_op.constant(0)`.
    func_node = node.body[0].body[0].value.func
    # assertEqual is used because the assertEquals alias is deprecated and no
    # longer exists in Python 3.12.
    self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
    self.assertEqual((constant_op.__name__, 'constant'),
                     anno.getanno(func_node, 'fqn'))
예제 #36
0
    def test_constructor_detection(self):
        """Checks the 'type' and 'type_fqn' annotations of a constructor call."""
        def test_fn():
            opt = training.GradientDescentOptimizer(0.1)
            return opt

        node = self._parse_and_analyze(test_fn, {'training': training})
        # Right-hand side of the `opt = ...` assignment.
        call_node = node.body[0].body[0].value
        # assertEqual is used because the assertEquals alias is deprecated and
        # no longer exists in Python 3.12.
        self.assertEqual(training.GradientDescentOptimizer,
                         anno.getanno(call_node, 'type'))
        self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                         anno.getanno(call_node, 'type_fqn'))
예제 #37
0
    def test_namespace(self):
        """Checks that a name supplied via the namespace gets live annotations."""
        def foo():
            return 'bar'

        def test_fn():
            return foo()

        node = self._parse_and_analyze(test_fn, {'foo': foo})
        # Callee of the call in `return foo()`.
        func_node = node.body[0].body[0].value.func
        # assertEqual is used because the assertEquals alias is deprecated and
        # no longer exists in Python 3.12.
        self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
        self.assertEqual(('foo', ), anno.getanno(func_node, 'fqn'))
예제 #38
0
  def test_constructor_detection(self):
    """Checks the 'type' and 'type_fqn' annotations of a constructor call."""

    def test_fn():
      opt = training.GradientDescentOptimizer(0.1)
      return opt

    node = self._parse_and_analyze(test_fn, {'training': training})
    # Right-hand side of the `opt = ...` assignment.
    call_node = node.body[0].body[0].value
    # assertEqual is used because the assertEquals alias is deprecated and no
    # longer exists in Python 3.12.
    self.assertEqual(training.GradientDescentOptimizer,
                     anno.getanno(call_node, 'type'))
    self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                     anno.getanno(call_node, 'type_fqn'))
예제 #39
0
 def _try_resolve_target(self, node):
   """Resolves a call target; works for methods of objects of known type.

   Returns the node's 'live_val' annotation when present. Otherwise, for a
   `gast.Attribute` carrying a 'type' annotation, returns the attribute
   looked up on that type. Returns None when neither applies.

   Raises:
     ValueError: if the annotated type has no attribute named `node.attr`.
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')
   if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
     owner_type = anno.getanno(node, 'type')
     if not hasattr(owner_type, node.attr):
       raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                        (owner_type, node.attr))
     return getattr(owner_type, node.attr)
   return None
예제 #40
0
  def test_namespace(self):
    """Checks that a name supplied via the namespace gets live annotations."""

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    node = self._parse_and_analyze(test_fn, {'foo': foo})
    # Callee of the call in `return foo()`.
    func_node = node.body[0].body[0].value.func
    # assertEqual is used because the assertEquals alias is deprecated and no
    # longer exists in Python 3.12.
    self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
예제 #41
0
    def _rename_compilable_function(self, node):
        """Points a call at the compiled version of its target function.

        Args:
          node: a call node whose `func` carries 'live_val' and 'fqn'
            annotations.

        Returns:
          The same node. Its `func` is replaced by a plain name when
          `_should_compile` accepts the target's fully qualified name.
        """
        callee = node.func
        assert anno.hasanno(callee, 'live_val')
        assert anno.hasanno(callee, 'fqn')
        target_obj = anno.getanno(callee, 'live_val')
        target_fqn = anno.getanno(callee, 'fqn')

        if self._should_compile(target_fqn):
            dotted_name = '.'.join(target_fqn)
            new_name = self.namer.compiled_function_name(
                dotted_name, live_object=target_obj)
            node.func = gast.Name(
                id=new_name, ctx=gast.Load(), annotation=None)
        return node
예제 #42
0
  def test_attribute_names(self):
    """Checks the 'live_val' and 'fqn' annotations of an attribute callee."""

    def test_fn():
      return constant_op.constant(0)

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {'constant_op': constant_op}, {})

    # Callee of the call in `return constant_op.constant(0)`.
    func_node = node.body[0].body[0].value.func
    # assertEqual is used because the assertEquals alias is deprecated and no
    # longer exists in Python 3.12.
    self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
    self.assertEqual((constant_op.__name__, 'constant'),
                     anno.getanno(func_node, 'fqn'))
예제 #43
0
    def test_attribute_names(self):
        """Checks the 'live_val' and 'fqn' annotations of an attribute callee."""
        def test_fn():
            return constant_op.constant(0)

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {'constant_op': constant_op}, {})

        # Callee of the call in `return constant_op.constant(0)`.
        func_node = node.body[0].body[0].value.func
        # assertEqual is used because the assertEquals alias is deprecated and
        # no longer exists in Python 3.12.
        self.assertEqual(constant_op.constant,
                         anno.getanno(func_node, 'live_val'))
        self.assertEqual((constant_op.__name__, 'constant'),
                         anno.getanno(func_node, 'fqn'))
예제 #44
0
  def _rename_compilable_function(self, node):
    """Points a call at the compiled version of its target function.

    Args:
      node: a call node whose `func` carries 'live_val' and 'fqn' annotations.

    Returns:
      The same node. Its `func` is replaced by a plain name when
      `_should_compile` accepts the target's fully qualified name.
    """
    callee = node.func
    for required in ('live_val', 'fqn'):
      assert anno.hasanno(callee, required)
    target_obj = anno.getanno(callee, 'live_val')
    target_fqn = anno.getanno(callee, 'fqn')

    if self._should_compile(target_fqn):
      dotted_name = '.'.join(target_fqn)
      new_name = self.namer.compiled_function_name(
          dotted_name, live_object=target_obj)
      node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
    return node
예제 #45
0
    def test_type_annotation(self):
        """Checks 'element_type' on the definition and on the marker call."""
        class Foo(object):
            pass

        def test_fn():
            f = []
            f = utils.set_element_type(f, Foo)
            return f

        namespace = {'Foo': Foo, 'utils': utils}
        tree = self._parse_and_analyze(test_fn, namespace)
        statements = tree.body[0].body
        # Statement 0 is `f = []`; statement 1 is the set_element_type call.
        for index in (0, 1):
            assigned_value = statements[index].value
            self.assertEqual(
                anno.getanno(assigned_value, 'element_type'), Foo)
예제 #46
0
    def test_while(self):
        """Checks the scopes annotated on a `while` loop."""
        def test_fn(a):
            b = a
            while b > 0:
                c = b
                b -= 1
            return b, c

        tree = self._parse_and_analyze(test_fn)
        # Second statement of test_fn's body is the `while` loop.
        loop = tree.body[0].body[1]

        body_scope = anno.getanno(loop, 'body_scope')
        self.assertScopeIs(body_scope, ('b', ), ('b', 'c'), ('c', ))

        parent_scope = anno.getanno(loop, 'body_parent_scope')
        self.assertScopeIs(
            parent_scope, ('a', 'b', 'c'), ('b', 'c'), ('a', 'b', 'c'))
예제 #47
0
    def test_namespace(self):
        """Checks that a name supplied via the namespace gets live annotations."""
        def foo():
            return 'bar'

        def test_fn():
            return foo()

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {'foo': foo}, {})

        # Callee of the call in `return foo()`.
        func_node = node.body[0].body[0].value.func
        # assertEqual is used because the assertEquals alias is deprecated and
        # no longer exists in Python 3.12.
        self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
        self.assertEqual(('foo', ), anno.getanno(func_node, 'fqn'))
예제 #48
0
 def visit_Attribute(self, node):
     """Annotates an attribute node from its owner's live value.

     Args:
       node: the attribute node being visited.

     Returns:
       The same node; it gains 'live_val' and 'fqn' annotations when its
       owner has a 'live_val' annotation.

     Raises:
       AttributeError: if the owner's live value lacks the attribute.
     """
     self.generic_visit(node)
     owner = node.value
     if not anno.hasanno(owner, 'live_val'):
         # TODO (mdan): Figure out what to do when calling attribute on local object. id:1287 gh:1288
         # Maybe just leave as-is?
         return node

     assert anno.hasanno(owner, 'fqn')
     parent_object = anno.getanno(owner, 'live_val')
     attr_name = node.attr
     if not hasattr(parent_object, attr_name):
         raise AttributeError(
             '%s has no attribute %s' % (parent_object, attr_name))
     anno.setanno(node, 'live_val', getattr(parent_object, attr_name))
     anno.setanno(node, 'fqn', anno.getanno(owner, 'fqn') + (attr_name, ))
     return node
예제 #49
0
 def visit_Name(self, node):
     """Processes parameter names and propagates type info to loaded names.

     Args:
       node: the name node being visited.

     Returns:
       The same node; a loaded name whose tracked source has a 'type'
       annotation receives copies of its 'type' and 'type_fqn' annotations.
     """
     self.generic_visit(node)
     context = node.ctx
     if isinstance(context, gast.Param):
         self._process_function_arg(node.id)
         return node

     is_tracked_load = (
         isinstance(context, gast.Load) and self.scope.hasval(node.id))
     if not is_tracked_load:
         return node

     # E.g. if we had
     # a = b
     # then for future references to `a` we should have traced_source = `b`
     traced_source = self.scope.getval(node.id)
     if anno.hasanno(traced_source, 'type'):
         for key in ('type', 'type_fqn'):
             anno.setanno(node, key, anno.getanno(traced_source, key))
     return node
예제 #50
0
 def visit_Name(self, node):
   """Processes parameter names and propagates type info to loaded names.

   Args:
     node: the name node being visited; it must carry a QN annotation.

   Returns:
     The same node; a loaded name whose tracked source has a 'type' annotation
     receives copies of its 'type' and 'type_fqn' annotations.
   """
   self.generic_visit(node)
   qn = anno.getanno(node, anno.Basic.QN)
   context = node.ctx
   if isinstance(context, gast.Param):
     self._process_function_arg(qn)
     return node

   is_tracked_load = isinstance(context, gast.Load) and self.scope.hasval(qn)
   if not is_tracked_load:
     return node

   # E.g. if we had
   # a = b
   # then for future references to `a` we should have traced_source = `b`
   traced_source = self.scope.getval(qn)
   if anno.hasanno(traced_source, 'type'):
     for key in ('type', 'type_fqn'):
       anno.setanno(node, key, anno.getanno(traced_source, key))
   return node
예제 #51
0
  def test_type_annotation(self):
    """Checks 'element_type' on the definition and on the marker call."""

    class Foo(object):
      pass

    def test_fn():
      f = []
      f = utils.set_element_type(f, Foo)
      return f

    namespace = {'Foo': Foo, 'utils': utils}
    tree = self._parse_and_analyze(test_fn, namespace)
    statements = tree.body[0].body

    # Value assigned by `f = []`.
    definition_value = statements[0].value
    self.assertEqual(anno.getanno(definition_value, 'element_type'), Foo)

    # Value assigned by the set_element_type call.
    marker_value = statements[1].value
    self.assertEqual(anno.getanno(marker_value, 'element_type'), Foo)
예제 #52
0
  def test_namespace(self):
    """Checks that a name supplied via the namespace gets live annotations."""

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {'foo': foo}, {})

    # Callee of the call in `return foo()`.
    func_node = node.body[0].body[0].value.func
    # assertEqual is used because the assertEquals alias is deprecated and no
    # longer exists in Python 3.12.
    self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
예제 #53
0
  def test_class_members_in_with_stmt(self):
    """Checks annotations for a constructor used as a `with` context manager.

    The constructor call must be annotated with its type, and a method called
    on the `as` target inside the block must resolve to the class's method.
    """

    def test_fn(x):
      with session.Session() as sess:
        sess.run(x)

    node = self._parse_and_analyze(test_fn, {'session': session})
    with_node = node.body[0].body[0]

    # `session.Session()` in the with-item.
    constructor_call = with_node.items[0].context_expr
    # assertEqual is used because the assertEquals alias is deprecated and no
    # longer exists in Python 3.12.
    self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
    self.assertEqual((session.__name__, 'Session'),
                     anno.getanno(constructor_call, 'type_fqn'))

    # `sess.run` inside the with-body.
    method_call = with_node.body[0].value.func
    self.assertEqual(session.Session.run,
                     anno.getanno(method_call, 'live_val'))
예제 #54
0
  def visit_For(self, node):
    """Rewrites a `for` loop as an index-based `while` loop via a template.

    When the node carries an 'extra_cond' annotation, that condition is and-ed
    into the loop test.

    Args:
      node: the `for` node being visited; it must carry a 'body_scope'
        annotation.

    Returns:
      The result of expanding the template with `templates.replace`.
    """
    self.generic_visit(node)
    body_scope = anno.getanno(node, 'body_scope')

    # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
    # Or maybe we should replace range with tf.range?

    def fresh_name(stem):
      # New symbol that avoids the names referenced in the loop body.
      return gast.Name(
          self.namer.new_symbol(stem, body_scope.referenced), None, None)

    # Replacements shared by both templates; `i` is generated before `n`.
    replacements = dict(
        loop_iter=node.iter,
        target=node.target,
        body=node.body,
        i=fresh_name('i'),
        n=fresh_name('n'))

    if anno.hasanno(node, 'extra_cond'):

      def template(loop_iter, target, body, i, n, extra_cond):  # pylint:disable=unused-argument
        i = 0
        n = len(loop_iter)  # pylint:disable=undefined-variable
        while i < n and extra_cond:
          # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
          target = loop_iter[i]
          body  # pylint:disable=pointless-statement
          i += 1

      return templates.replace(
          template,
          extra_cond=anno.getanno(node, 'extra_cond'),
          **replacements)
    else:

      def template(loop_iter, target, body, i, n):  # pylint:disable=unused-argument
        i = 0
        n = len(loop_iter)  # pylint:disable=undefined-variable
        while i < n:
          # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
          target = loop_iter[i]
          body  # pylint:disable=pointless-statement
          i += 1

      return templates.replace(template, **replacements)