Пример #1
0
 def visit_Call(self, node):
   """Applies the type annotation marker function to its target symbol.

   A call `<type_annotation_func>(symbol, type)` annotates both the call node
   and the scope definition of `symbol` with 'element_type' = `type`, so that
   later uses of the symbol can receive the same type annotation.

   Args:
     node: the gast.Call node being visited.

   Returns:
     The result of `self.generic_visit(node)`.

   Raises:
     ValueError: if the marker call does not have exactly two arguments, if
       its first argument is not a symbol (has no QN), or if its second
       argument is not statically resolvable (has no live value).
   """
   if anno.hasanno(node.func, 'live_val'):
     # Symbols targeted by the "set_type" marker function are assigned the data
     # type that it specified.
     if (anno.getanno(node.func, 'live_val') is
         self.context.type_annotation_func):
       # Expecting the actual type to be the second argument.
       if len(node.args) != 2:
         raise ValueError('"%s" must have exactly two parameters'
                          % self.context.type_annotation_func)
       if not anno.hasanno(node.args[0], anno.Basic.QN):
         raise ValueError('the first argument of "%s" must be a symbol'
                          % self.context.type_annotation_func)
       if not anno.hasanno(node.args[1], 'live_val'):
         raise ValueError(
             'the second argument of "%s" must be statically resolvable' %
             self.context.type_annotation_func)
       target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
       element_type = anno.getanno(node.args[1], 'live_val')
       # Find the definition of this symbol and annotate it with the given
       # data type. That in turn will cause future uses of the symbol
       # to receive the same type annotation.
       definition = self.scope.getval(target_symbol)
       anno.setanno(node, 'element_type', element_type)
       anno.setanno(definition, 'element_type', element_type)
       # TODO(mdan): Should we update references between definition and here?
   return self.generic_visit(node)
Пример #2
0
  def _wrap_to_py_func_no_return(self, node):
    """Wraps a call whose result is not used into a `tf.py_func` call.

    The generated wrapper calls `node.func` with the names found in the
    call's argument scope (`args_scope.used`) and returns a dummy 1. The
    wrapper is then invoked through `tf.py_func` with a `tf.int64` output.

    Args:
      node: the gast.Call node; must carry an 'args_scope' annotation.

    Returns:
      A `(wrapper_def, call_expr)` tuple as produced by `templates.replace`.
    """
    scope_of_args = anno.getanno(node, 'args_scope')
    # TODO(mdan): Properly handle varargs, kwargs, etc.
    arg_nodes = tuple(
        gast.Name(symbol, gast.Load(), None) for symbol in scope_of_args.used)

    # pylint:disable=undefined-variable,unused-argument,function-redefined

    def template(call, wrapper, args):

      def wrapper(args):
        call(args)
        return 1

      tf.py_func(wrapper, [args], [tf.int64])

    # pylint:enable=undefined-variable,unused-argument,function-redefined

    wrapper_ref = gast.Name(
        self.namer.compiled_function_name(node.func.id), gast.Load(), None)
    wrapper_def, call_expr = templates.replace(
        template, call=node.func, wrapper=wrapper_ref, args=arg_nodes)
    # Attach the argument scope to the py_func call, and flag the generated
    # wrapper with 'skip_processing'.
    anno.setanno(call_expr.value, 'args_scope', scope_of_args)
    anno.setanno(wrapper_def, 'skip_processing', True)
    return wrapper_def, call_expr
Пример #3
0
  def visit_Call(self, node):
    """Renames calls to compilable functions and flags decorated arguments.

    A call to one of `self.nocompile_decorators` marks its first positional
    argument as 'graph_ready'. After the children are visited, a call whose
    target has a live value is renamed if that target is compilable.

    Args:
      node: the gast.Call node being visited.

    Returns:
      The (possibly renamed) call node.

    Raises:
      ValueError: if a call to a nocompile decorator has no positional
        arguments.
      NotImplementedError: if a resolved target is not compilable, or if an
        unresolved target is found in recursive mode.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.nocompile_decorators:
        if len(node.args) < 1:
          # The message has a %s placeholder, so it must be formatted.
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least an argument.' % (target_entity,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      else:
        raise NotImplementedError('py_func with return values')
    else:
      if self.context.recursive:
        raise NotImplementedError('Could not resolve target function.')
      else:
        # TODO(mdan): Double check. Is this reachable code?
        pass
    return node
Пример #4
0
  def visit_Call(self, node):
    """Dispatches a call to renaming, py_func wrapping or dynamic conversion.

    A call to one of `self.nocompile_decorators` marks its first positional
    argument as 'graph_ready'. After the children are visited, a call with a
    resolved target (live value) is renamed if the target is compilable, or
    wrapped in a py_func if its FQN is in KNOWN_NUMPY_FUNCTIONS. An
    unresolved target is dynamically converted in recursive mode and left
    untouched otherwise.

    Args:
      node: the gast.Call node being visited.

    Returns:
      The (possibly replaced) call node.

    Raises:
      ValueError: if a call to a nocompile decorator has no positional
        arguments.
      NotImplementedError: if a resolved target is neither compilable nor a
        known numpy function.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.nocompile_decorators:
        if len(node.args) < 1:
          # The message has a %s placeholder, so it must be formatted.
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least an argument.' % (target_entity,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      # Not every live value has an FQN; default to None so the lookup below
      # cannot hit an unbound local.
      target_fqn = None
      if anno.hasanno(node.func, 'fqn'):
        target_fqn = anno.getanno(node.func, 'fqn')
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      elif target_fqn is not None and target_fqn in KNOWN_NUMPY_FUNCTIONS:
        # TODO(mdan): Should we replace these with equivalent TF ops instead?
        node = self._wrap_to_py_func_single_return(node, target_fqn)
      else:
        raise NotImplementedError(
            'py_func with return values (unknown function)')
    else:
      if self.context.recursive:
        node = self._insert_dynamic_conversion(node)
      else:
        # Unresolved functions are allowed in non-recursive mode.
        pass
    return node
Пример #5
0
 def visit_Call(self, node):
   """Copies the receiver's type onto an unresolved method call target.

   For a call `obj.method(...)` whose target has no live value, the
   definition of `obj` is looked up in the current scope. Its 'type' and
   'type_fqn' annotations are then copied onto the call target.

   Raises:
     ValueError: if the target is not an attribute of a plain name, if the
       name has no value recorded in the scope, or if that value has no
       'type' annotation.
   """
   callee = node.func
   if not anno.hasanno(callee, 'live_val'):
     if not isinstance(callee, gast.Attribute):
       # Suspected pattern:
       #   foo = bar
       #   foo()
       raise ValueError('Dont know how to handle dynamic functions.')
     receiver = callee.value
     if not isinstance(receiver, gast.Name):
       # Possible example:
       #   foo = module.Foo()
       #   foo.bar.baz()
       # TODO(mdan): This should be doable by using the FQN.
       raise ValueError('Dont know how to handle object properties yet.')
     if not self.scope.hasval(receiver.id):
       # TODO(mdan): Figure out what could the user do to get past this.
       raise ValueError('No info on "%s". Is it dynamically built?' %
                        (receiver.id))
     # For `opt = tf.train.Optimizer(); opt.foo()`, the receiver's value is
     # the `tf.train.Optimizer()` expression.
     receiver_source = self.scope.getval(receiver.id)
     if not anno.hasanno(receiver_source, 'type'):
       raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                        (receiver.id))
     for key in ('type', 'type_fqn'):
       anno.setanno(callee, key, anno.getanno(receiver_source, key))
   self.generic_visit(node)
   return node
Пример #6
0
  def visit_Name(self, node):
    """Annotates loaded symbols with live values where they can be resolved.

    For symbols that are neither local nor parameters, the live value comes
    from `self.literals` or, failing that, from `self.context.namespace`.
    Separately, a symbol that was not modified since function entry and that
    appears in `self.context.arg_values` receives that argument value
    (overwriting any earlier 'live_val'/'fqn') and an FQN made of its class
    name.

    Args:
      node: the gast.Name node being visited. In Load context it must carry
        the IS_LOCAL, IS_MODIFIED_SINCE_ENTRY and IS_PARAM annotations.

    Returns:
      The same node, possibly annotated with 'live_val' and 'fqn'.
    """
    self.generic_visit(node)
    if isinstance(node.ctx, gast.Load):
      assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
      symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
      assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
      symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
      assert anno.hasanno(node, NodeAnno.IS_PARAM), node
      symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)

      if not symbol_is_local and not symbol_is_param:
        if node.id in self.literals:
          anno.setanno(node, 'live_val', self.literals[node.id])
          # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
        elif node.id in self.context.namespace:
          obj = self.context.namespace[node.id]
          anno.setanno(node, 'live_val', obj)
          # Values such as numbers or strings have no __name__. Reading it
          # unconditionally would raise AttributeError, so they get no FQN.
          if hasattr(obj, '__name__'):
            anno.setanno(node, 'fqn', (obj.__name__,))
        else:
          pass
          # TODO(mdan): Should we raise an error here?
          # Can encounter this when:
          #  * a symbol truly lacks reference
          #  * a symbol is new, like the new name of a function we just renamed.
      else:
        pass
        # TODO(mdan): Attempt to trace its value through the local chain.
        # TODO(mdan): Use type annotations as fallback.

      if not symbol_is_modified:
        if node.id in self.context.arg_values:
          obj = self.context.arg_values[node.id]
          anno.setanno(node, 'live_val', obj)
          anno.setanno(node, 'fqn', (obj.__class__.__name__,))
    return node
Пример #7
0
  def visit_Name(self, node):
    """Annotates loaded symbols with live values, failing on unknown globals.

    For symbols that are neither local nor parameters, the live value comes
    from `self.literals` or, failing that, from `self.context.namespace`.
    Separately, a symbol not modified since function entry that appears in
    `self.context.arg_values` receives that argument value (overwriting any
    earlier 'live_val'/'fqn') and an FQN made of its class name.

    Args:
      node: the gast.Name node being visited. In Load context it must carry
        the 'is_local', 'is_modified_since_entry' and 'is_param' annotations.

    Returns:
      The same node, possibly annotated with 'live_val' and 'fqn'.

    Raises:
      ValueError: if a non-local, non-parameter symbol is in neither
        `self.literals` nor `self.context.namespace`.
    """
    self.generic_visit(node)
    if isinstance(node.ctx, gast.Load):
      assert anno.hasanno(node, 'is_local'), node
      symbol_is_local = anno.getanno(node, 'is_local')
      assert anno.hasanno(node, 'is_modified_since_entry'), node
      symbol_is_modified = anno.getanno(node, 'is_modified_since_entry')
      assert anno.hasanno(node, 'is_param'), node
      symbol_is_param = anno.getanno(node, 'is_param')

      if not symbol_is_local and not symbol_is_param:
        if node.id in self.literals:
          anno.setanno(node, 'live_val', self.literals[node.id])
          # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
        elif node.id in self.context.namespace:
          obj = self.context.namespace[node.id]
          anno.setanno(node, 'live_val', obj)
          # Values such as numbers or strings have no __name__. Reading it
          # unconditionally would raise AttributeError, so they get no FQN.
          if hasattr(obj, '__name__'):
            anno.setanno(node, 'fqn', (obj.__name__,))
        else:
          raise ValueError('Could not resolve symbol "%s".' % node.id)
      else:
        pass
        # TODO(mdan): Attempt to trace its value through the local chain.
        # TODO(mdan): Use type annotations as fallback.

      if not symbol_is_modified:
        if node.id in self.context.arg_values:
          obj = self.context.arg_values[node.id]
          anno.setanno(node, 'live_val', obj)
          anno.setanno(node, 'fqn', (obj.__class__.__name__,))
    return node
Пример #8
0
  def visit_Call(self, node):
    """Renames compilable calls and member calls on objects of known type.

    A call to one of `self.nocompile_decorators` marks its first positional
    argument as 'graph_ready'. After the children are visited, a call with a
    live-valued target is renamed if the target is compilable. A call whose
    target carries a 'type_fqn' is renamed as a member function of a known
    type.

    Args:
      node: the gast.Call node being visited.

    Returns:
      The (possibly renamed) call node.

    Raises:
      ValueError: if a call to a nocompile decorator has no positional
        arguments.
      NotImplementedError: if a resolved target is not compilable, or if the
        target has neither a live value nor a known type.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_obj = anno.getanno(node.func, 'live_val')
      if target_obj in self.nocompile_decorators:
        if len(node.args) < 1:
          # The message has a %s placeholder, so it must be formatted.
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least an argument.' % (target_obj,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_obj = anno.getanno(node.func, 'live_val')
      if self._function_is_compilable(target_obj):
        node = self._rename_compilable_function(node)
      else:
        raise NotImplementedError('py_func with return values')
    elif anno.hasanno(node.func, 'type_fqn'):
      node = self._rename_member_function_of_known_type(node)
    else:
      # For a member call, node.func is an Attribute, which has no `id`.
      # Reading `.id` unconditionally would mask this error with an
      # AttributeError.
      func_name = getattr(node.func, 'id', None)
      if func_name is None:
        func_name = getattr(node.func, 'attr', node.func)
      raise NotImplementedError(
          'Member function call (of unknown type): %s.' % (func_name,))
    return node
Пример #9
0
  def _wrap_to_py_func_no_return(self, node):
    """Wraps a call whose result is unused into a `tf.py_func` call.

    The generated wrapper has one freshly named parameter per positional
    argument, forwards them to `node.func` and returns 1. It is invoked
    through `tf.py_func` with the call's original arguments and a `tf.int64`
    output. Fresh names come from `self.context.namer.new_symbol`, which is
    given the names referenced in the call's argument scope (presumably as
    names to avoid).

    Args:
      node: the gast.Call node. It must carry a NodeAnno.ARGS_SCOPE
        annotation, and its `func` must carry a QN.

    Returns:
      A `(wrapper_def, call_expr)` tuple as produced by `templates.replace`.
    """
    namer = self.context.namer
    func_qn = anno.getanno(node.func, anno.Basic.QN)
    args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    wrapper_name = namer.new_symbol(func_qn.ssf(), args_scope.referenced)

    def param_base_name(arg):
      # Symbol arguments lend their own name; other expressions use 'arg'.
      if anno.hasanno(arg, anno.Basic.QN):
        return anno.getanno(arg, anno.Basic.QN).ssf()
      return qual_names.QN('arg').ssf()

    wrapper_args = [
        namer.new_symbol(param_base_name(arg), args_scope.referenced)
        for arg in node.args
    ]
    # TODO(mdan): Properly handle varargs, kwargs, etc.
    # TODO(mdan): This is best handled as a dynamic dispatch.
    # That way we can separate tensors from non-tensor args.
    template = """
      def wrapper(wrapper_args):
        call(wrapper_args)
        return 1
      tf.py_func(wrapper, original_args, [tf.int64])
    """
    wrapper_def, call_expr = templates.replace(
        template,
        call=node.func,
        wrapper=wrapper_name,
        original_args=gast.List(elts=node.args, ctx=None),
        wrapper_args=wrapper_args)
    anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
    return wrapper_def, call_expr
Пример #10
0
 def visit_With(self, node):
   """Visits a `with` statement inside a fresh, non-isolated child scope.

   The child scope is attached to the node as NodeAnno.BODY_SCOPE, and the
   enclosing scope is restored afterwards.
   """
   enclosing = self.scope
   body_scope = Scope(enclosing, isolated=False)
   self.scope = body_scope
   self.generic_visit(node)
   anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
   self.scope = enclosing
   return node
Пример #11
0
 def _as_function(self, func_name, args):
   """Builds the expression `func_name(args)`, marked SAFE_BOOLEAN_OPERAND.

   Args:
     func_name: source text of the callee, parsed with
       `parser.parse_expression`.
     args: the argument node(s) substituted for `args` in the template.
   """
   template = """
     func_name(args)
   """
   callee = parser.parse_expression(func_name)
   call_expr = templates.replace_as_expression(
       template, func_name=callee, args=args)
   anno.setanno(call_expr, SAFE_BOOLEAN_OPERAND, True)
   return call_expr
Пример #12
0
 def visit_Print(self, node):
   """Visits the printed values inside a dedicated child scope.

   The child scope is attached to the node as 'args_scope', and the
   enclosing scope is restored afterwards.
   """
   saved_scope = self.scope
   print_scope = Scope(saved_scope)
   self.scope = print_scope
   for value in node.values:
     self.visit(value)
   anno.setanno(node, 'args_scope', print_scope)
   self.scope = saved_scope
   return node
Пример #13
0
 def _inline_tf_op(self, op_name, args):
   """Builds the expression `tf.<op_name>(args)`, marked SAFE_BOOLEAN_OPERAND.

   Returns:
     The value of the first statement produced by `templates.replace`.
   """
   template = """
     tf.op_name(args)
   """
   statements = templates.replace(template, op_name=op_name, args=args)
   # The template is a single expression statement; unwrap its value.
   op_call = statements[0].value
   anno.setanno(op_call, SAFE_BOOLEAN_OPERAND, True)
   return op_call
Пример #14
0
 def _process_block_node(self, node, block, scope_name):
   """Visits the statements of `block` in a fresh, non-isolated child scope.

   The child scope is attached to `node` under the annotation key
   '<scope_name>_scope', and the enclosing scope is restored afterwards.
   """
   annotation_key = '%s_scope' % scope_name
   outer = self.scope
   inner = Scope(outer, isolated=False)
   self.scope = inner
   for statement in block:
     self.visit(statement)
   anno.setanno(node, annotation_key, inner)
   self.scope = outer
   return node
Пример #15
0
  def test_basic(self):
    """Round-trips an annotation through setanno/hasanno/getanno."""
    name_node = ast.Name()

    # A fresh node has no annotation, and reading one raises.
    self.assertFalse(anno.hasanno(name_node, 'foo'))
    self.assertRaises(AttributeError, anno.getanno, name_node, 'foo')

    anno.setanno(name_node, 'foo', 3)
    self.assertTrue(anno.hasanno(name_node, 'foo'))
    self.assertEqual(anno.getanno(name_node, 'foo'), 3)
Пример #16
0
  def test_copyanno(self):
    """copyanno copies annotations that exist and skips those that do not."""
    source = ast.Name()
    anno.setanno(source, 'foo', 3)

    destination = ast.Name()
    for key in ('foo', 'bar'):
      anno.copyanno(source, destination, key)

    self.assertTrue(anno.hasanno(destination, 'foo'))
    self.assertFalse(anno.hasanno(destination, 'bar'))
Пример #17
0
 def _process_function_arg(self, arg_name):
   """Binds a typed argument of the outermost function to a type holder.

   Applies only when `self.function_level == 1` and the argument appears in
   `self.context.arg_types`. A node is forged from `arg_name.ast()` and
   annotated with the declared 'type' and its dotted 'type_fqn'. It is then
   stored as the argument's value in the scope, so method calls on the
   argument can resolve its type.
   """
   key = str(arg_name)
   if self.function_level != 1 or key not in self.context.arg_types:
     return
   holder = arg_name.ast()
   type_string, type_obj = self.context.arg_types[key]
   anno.setanno(holder, 'type', type_obj)
   anno.setanno(holder, 'type_fqn', tuple(type_string.split('.')))
   self.scope.setval(arg_name, holder)
Пример #18
0
 def visit_FunctionDef(self, node):
   """Visits a function definition inside a new, isolated scope.

   If there is an enclosing scope, the function's name is marked as written
   there. The definition is then visited in `Scope(..., isolated=True)`,
   which is attached to the node as NodeAnno.BODY_SCOPE.
   """
   outer_scope = self.scope
   if outer_scope:
     outer_scope.mark_write(QN(node.name))
   function_scope = Scope(outer_scope, isolated=True)
   self.scope = function_scope
   self.generic_visit(node)
   anno.setanno(node, NodeAnno.BODY_SCOPE, function_scope)
   self.scope = outer_scope
   return node
Пример #19
0
 def visit_Call(self, node):
   """Visits call arguments in a child scope recorded as NodeAnno.ARGS_SCOPE.

   Positional and keyword arguments are visited inside a non-isolated child
   scope. The callee expression is visited only after the enclosing scope is
   restored, so it is not recorded in the argument scope.
   """
   caller_scope = self.scope
   call_args_scope = Scope(caller_scope, isolated=False)
   self.scope = call_args_scope
   # TODO(mdan): Account starargs, kwargs
   for positional in node.args:
     self.visit(positional)
   for keyword in node.keywords:
     self.visit(keyword)
   anno.setanno(node, NodeAnno.ARGS_SCOPE, call_args_scope)
   self.scope = caller_scope
   self.visit(node.func)
   return node
Пример #20
0
 def visit_Name(self, node):
   """Records function parameters in the current scope.

   Each parameter is bound to a fresh Load-context Name node. When
   `self.function_level == 1` and the parameter has an entry in
   `self.context.arg_types`, that binding is then replaced by a node
   annotated with the declared 'type' and 'type_fqn', so method calls on the
   parameter can resolve its type.
   """
   self.generic_visit(node)
   if not isinstance(node.ctx, gast.Param):
     return node
   param = node.id
   self.scope.setval(param, gast.Name(param, gast.Load(), None))
   if self.function_level == 1 and param in self.context.arg_types:
     typed_holder = gast.Name(param, gast.Load(), None)
     type_string, type_obj = self.context.arg_types[param]
     anno.setanno(typed_holder, 'type', type_obj)
     anno.setanno(typed_holder, 'type_fqn', tuple(type_string.split('.')))
     self.scope.setval(param, typed_holder)
   return node
Пример #21
0
 def visit_Call(self, node):
   """Visits call arguments in a child scope recorded as 'args_scope'.

   Positional and keyword arguments are visited inside `Scope(<current>)`.
   The callee expression is visited only after the enclosing scope is
   restored.
   """
   outer = self.scope
   arguments_scope = Scope(outer)
   self.scope = arguments_scope
   # TODO(mdan): Account starargs, kwargs
   for positional in node.args:
     self.visit(positional)
   for keyword in node.keywords:
     self.visit(keyword)
   anno.setanno(node, 'args_scope', arguments_scope)
   self.scope = outer
   self.visit(node.func)
   return node
Пример #22
0
 def visit_Name(self, node):
   """Processes parameters and propagates type info to loaded names.

   Parameters are handed to `_process_function_arg`. For a loaded name with
   a value recorded in the scope, the value's 'type' and 'type_fqn' are
   copied onto the node when the value has a 'type'.
   """
   self.generic_visit(node)
   qn = anno.getanno(node, anno.Basic.QN)
   if isinstance(node.ctx, gast.Param):
     self._process_function_arg(qn)
     return node
   if not isinstance(node.ctx, gast.Load) or not self.scope.hasval(qn):
     return node
   # E.g. after `a = b`, later references to `a` are traced to `b`.
   traced_source = self.scope.getval(qn)
   if anno.hasanno(traced_source, 'type'):
     for key in ('type', 'type_fqn'):
       anno.setanno(node, key, anno.getanno(traced_source, key))
   return node
Пример #23
0
 def visit_While(self, node):
   """Visits a while loop, recording its parent and body scopes.

   The test is visited in the current scope, which is attached as
   'parent_scope'. The body is visited in a non-isolated child scope, which
   is attached as 'body_scope'.

   Raises:
     NotImplementedError: if the loop has an `else` clause.
   """
   self.visit(node.test)
   current_scope = self.scope
   anno.setanno(node, 'parent_scope', current_scope)
   body_scope = Scope(current_scope, isolated=False)
   self.scope = body_scope
   for n in node.body:
     self.visit(n)
   anno.setanno(node, 'body_scope', body_scope)
   # Restore the enclosing scope before the possible raise below, so the
   # visitor is not left pointing at the loop body's scope.
   self.scope = current_scope
   if node.orelse:
     # TODO(mdan): Add support for orelse.
     raise NotImplementedError('while/else is not yet supported.')
   return node
Пример #24
0
 def visit_Attribute(self, node):
   """Resolves attribute accesses on values with a known live value or type.

   - If the base has a live value, the attribute is read from it: the node
     gets 'live_val', and its 'fqn' extends the base's FQN.
   - Otherwise, if the base has a 'type' that has the attribute (e.g. a
     method), the attribute is read from the type and the FQN extends the
     base's 'type_fqn'.
   - Otherwise nothing is annotated; a plain Name base must carry 'is_local'.

   Raises:
     AttributeError: if the base's live value lacks the attribute.
   """
   self.generic_visit(node)
   base = node.value
   attr = node.attr

   if anno.hasanno(base, 'live_val'):
     assert anno.hasanno(base, 'fqn')
     owner = anno.getanno(base, 'live_val')
     if not hasattr(owner, attr):
       raise AttributeError('%s has no attribute %s' % (owner, attr))
     anno.setanno(node, 'live_val', getattr(owner, attr))
     anno.setanno(node, 'fqn', anno.getanno(base, 'fqn') + (attr,))
   # TODO(mdan): Investigate the role built-in annotations can play here.
   elif anno.hasanno(base, 'type'):
     owner_type = anno.getanno(base, 'type')
     # Static members such as methods are found on the type. Dynamic members
     # such as function attributes are not; those are left unannotated for
     # downstream consumers to handle.
     if hasattr(owner_type, attr):
       anno.setanno(node, 'live_val', getattr(owner_type, attr))
       anno.setanno(node, 'fqn', anno.getanno(base, 'type_fqn') + (attr,))
   elif isinstance(base, gast.Name):
     # All nonlocal symbols should be fully resolved.
     assert anno.hasanno(base, 'is_local'), base
     # TODO(mdan): Figure out what to do when calling attribute on local object
     # Maybe just leave as-is?
   return node
 def _inline_tf_op(self, op_name, args):
   """Builds `tf.<op>(args)` or `py2tf_utils.<op>(args)` as an expression.

   Op names containing 'py2tf_utils' use the py2tf_utils template, with the
   'py2tf_utils.' prefix stripped from the name. The result is marked
   SAFE_BOOLEAN_OPERAND.
   """
   uses_utils = 'py2tf_utils' in op_name
   if uses_utils:
     # TODO(alexbw): explicitly spelling out the attribute function name
     # until fix for issue highlighted in cl/188931581 lands.
     template = """
     py2tf_utils.op_name(args)
   """
     bare_op = op_name.replace('py2tf_utils.', '')
   else:
     template = """
       tf.op_name(args)
     """
     bare_op = op_name
   op_expr = templates.replace_as_expression(
       template, op_name=bare_op, args=args)
   anno.setanno(op_expr, SAFE_BOOLEAN_OPERAND, True)
   return op_expr
Пример #26
0
 def visit_Attribute(self, node):
   """Resolves attributes of expressions that have a live value.

   If the base has a live value, the attribute is read from it: the node
   gets 'live_val', and its 'fqn' extends the base's FQN. Otherwise, a plain
   Name base is asserted to be a local symbol.

   Raises:
     AttributeError: if the base's live value lacks the attribute.
   """
   self.generic_visit(node)
   base = node.value
   if anno.hasanno(base, 'live_val'):
     assert anno.hasanno(base, 'fqn')
     owner = anno.getanno(base, 'live_val')
     if not hasattr(owner, node.attr):
       raise AttributeError('%s has no attribute %s' % (owner, node.attr))
     anno.setanno(node, 'live_val', getattr(owner, node.attr))
     anno.setanno(node, 'fqn', anno.getanno(base, 'fqn') + (node.attr,))
     return node
   if isinstance(base, gast.Name):
     # All nonlocal symbols should be fully resolved.
     assert anno.hasanno(base, 'is_local'), base
     assert anno.getanno(base, 'is_local'), base
     # TODO(mdan): Figure out what to do when calling attribute on local object
     # Maybe just leave as-is?
   return node
Пример #27
0
 def visit_Subscript(self, node):
   """Assigns a qualified name (QN) to subscripts with simple indices.

   `x[0]`, `x['k']` and `x[i]` (where `i` has a QN) get a QN built from the
   QN of `x` and the index. Slices, multi-dimensional indices and index
   expressions without a QN of their own (e.g. `x[i + 1]`) are left
   unannotated.

   Returns:
     The node returned by `generic_visit`, possibly annotated with a QN.
   """
   node = self.generic_visit(node)
   s = node.slice
   if not isinstance(s, gast.Index):
     # TODO(mdan): Support range and multi-dimensional indices.
     # Continuing silently because some demos use these.
     return node
   if isinstance(s.value, gast.Num):
     subscript = QN(NumberLiteral(s.value.n))
   elif isinstance(s.value, gast.Str):
     subscript = QN(StringLiteral(s.value.s))
   elif anno.hasanno(s.value, anno.Basic.QN):
     subscript = anno.getanno(s.value, anno.Basic.QN)
   else:
     # Compound index expressions have no QN; getanno would raise here, so
     # skip them like the other unsupported index forms above.
     return node
   if anno.hasanno(node.value, anno.Basic.QN):
     anno.setanno(node, anno.Basic.QN,
                  QN(anno.getanno(node.value, anno.Basic.QN),
                     subscript=subscript))
   return node
Пример #28
0
 def visit_Name(self, node):
   """Records reads and writes of a bare name in the current scope.

   Loads annotate the node with 'is_local' (whether the scope already has
   the name) and mark a read. Stores and parameter definitions mark a write.

   Raises:
     ValueError: for any other expression context.
   """
   # TODO(mdan): This is insufficient for object fields, e.g. hp.learning_rate.
   self.generic_visit(node)
   ctx = node.ctx
   name = node.id
   if isinstance(ctx, gast.Load):
     anno.setanno(node, 'is_local', self.scope.has(name))
     self.scope.mark_read(name)
   elif isinstance(ctx, (gast.Store, gast.Param)):
     # Param contexts appear in function defs, where they define a variable,
     # so they are treated as writes.
     # TODO(mdan): This may be incorrect with nested functions.
     # For nested functions, we'll have to add the notion of hiding args from
     # the parent scope, not writing to them.
     self.scope.mark_write(name)
   else:
     raise ValueError('Unknown context %s for node %s.' % (type(ctx), name))
   return node
Пример #29
0
  def _wrap_to_py_func_no_return(self, node):
    """Wraps a call whose result is unused into a `tf.py_func` call.

    The generated wrapper calls `node.func` with the names in the call's
    argument scope (`args_scope.used`) and returns 1. It is invoked through
    `tf.py_func` with a `tf.int64` output.

    Args:
      node: the gast.Call node; must carry an 'args_scope' annotation.

    Returns:
      A `(wrapper_def, call_expr)` tuple as produced by `templates.replace`.
    """
    args_scope = anno.getanno(node, 'args_scope')
    # TODO(mdan): Properly handle varargs, kwargs, etc.
    template = """
      def wrapper(args):
        call(args)
        return 1
      tf.py_func(wrapper, [args], [tf.int64])
    """
    wrapper_name = self.context.namer.compiled_function_name(node.func.id)[0]
    arg_names = tuple(
        gast.Name(name, gast.Load(), None) for name in args_scope.used)
    wrapper_def, call_expr = templates.replace(
        template, call=node.func, wrapper=wrapper_name, args=arg_names)
    anno.setanno(call_expr.value, 'args_scope', args_scope)
    # TODO(mdan): Rename this annotation to 'graph_ready'
    anno.setanno(wrapper_def, 'skip_processing', True)
    return wrapper_def, call_expr
Пример #30
0
  def visit_For(self, node):
    """Prepares a for loop for `break` conversion.

    A fresh 'break_requested' symbol is pushed onto `self.break_uses` while
    the body is visited. If the entry's flag is set afterwards (presumably by
    the body visit when it meets a `break`), the loop gets an 'extra_cond'
    annotation of `not <break_var>` and is returned after the break
    variable's initialization.

    Returns:
      The loop node, or `[break_init, node]` when a break was recorded.
    """
    self.generic_visit(node.target)
    self.generic_visit(node.iter)
    body_scope = anno.getanno(node, 'body_scope')

    break_var = self.namer.new_symbol('break_requested', body_scope.referenced)
    self.break_uses.append([False, break_var])
    node.body = self._manual_visit_list(node.body)
    if self.break_uses[-1][0]:
      not_requested = gast.UnaryOp(
          gast.Not(), gast.Name(break_var, gast.Load(), None))
      anno.setanno(node, 'extra_cond', not_requested)
      result = [self._create_break_init(), node]
    else:
      result = node
    # Pop only after _create_break_init, which may read the top entry.
    self.break_uses.pop()

    for else_stmt in node.orelse:
      self.generic_visit(else_stmt)
    return result
    def visit_For(self, node):
        """Prepares a for loop for `break` conversion.

        A fresh 'break_requested' symbol is pushed onto `self.break_uses` while
        the body is visited. If the entry's flag is set afterwards (presumably
        by the body visit when it meets a `break`), the loop gets an
        'extra_cond' annotation of `not <break_var>` and is returned after the
        break variable's initialization.

        Returns:
            The loop node, or `[break_init, node]` when a break was recorded.
        """
        self.generic_visit(node.target)
        self.generic_visit(node.iter)
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)

        break_var = self.context.namer.new_symbol('break_requested',
                                                  body_scope.referenced)
        self.break_uses.append([False, break_var])
        node.body = self._manual_visit_list(node.body)
        break_used = self.break_uses[-1][0]
        if break_used:
            not_requested = gast.UnaryOp(
                gast.Not(), gast.Name(break_var, gast.Load(), None))
            anno.setanno(node, 'extra_cond', not_requested)
            result = [self._create_break_init(), node]
        else:
            result = node
        # Pop only after _create_break_init, which may read the top entry.
        self.break_uses.pop()

        for else_stmt in node.orelse:
            self.generic_visit(else_stmt)
        return result
Пример #32
0
 def generic_visit(self, node):
   """Returns a structural copy of `node`.

   Child AST nodes are copied by calling `self.generic_visit` recursively.
   Lists and tuples are copied element by element. Elements that are not AST
   nodes are kept as-is rather than passed to `generic_visit`, which would
   fail on them: for example the identifier strings in `Global.names`, or
   the `None` entries in `Dict.keys` used for `**` unpacking. Other field
   values are treated as value types and shared. Fields whose names start
   with '__', and fields missing on `node`, are skipped. Of all annotations,
   only `anno.Basic.SKIP_PROCESSING` is carried over to the copy.

   Args:
     node: a gast or ast node.

   Returns:
     A new node of the same type.
   """
   def copy_value(value):
     # Containers may hold AST nodes, plain values, or nested containers.
     if isinstance(value, list):
       return [copy_value(item) for item in value]
     if isinstance(value, tuple):
       return tuple(copy_value(item) for item in value)
     if isinstance(value, (gast.AST, ast.AST)):
       return self.generic_visit(value)
     # Assume everything else is a value type.
     return value

   new_fields = {}
   for f in node._fields:
     if f.startswith('__'):
       continue
     if not hasattr(node, f):
       continue
     new_fields[f] = copy_value(getattr(node, f))
   new_node = type(node)(**new_fields)
   if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
     anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True)
   return new_node
Пример #33
0
    def visit_Name(self, node):
        """Annotates loaded symbols with live values where they can be resolved.

        Non-local, non-parameter symbols are looked up in `self.literals`,
        then in `self.context.namespace`. Namespace values without a
        `__name__` (e.g. primitives) get no FQN. Separately, a symbol not
        modified since function entry that appears in
        `self.context.arg_values` receives that value (overwriting any earlier
        'live_val'/'fqn') and an FQN made of its class name.

        Returns:
            The same node, possibly annotated with 'live_val' and 'fqn'.
        """
        self.generic_visit(node)
        if not isinstance(node.ctx, gast.Load):
            return node

        assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
        is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
        assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
        is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
        assert anno.hasanno(node, NodeAnno.IS_PARAM), node
        is_param = anno.getanno(node, NodeAnno.IS_PARAM)

        name = node.id
        if not (is_local or is_param):
            if name in self.literals:
                anno.setanno(node, 'live_val', self.literals[name])
            elif name in self.context.namespace:
                value = self.context.namespace[name]
                anno.setanno(node, 'live_val', value)
                # Primitives, for example, have no __name__ and get no FQN.
                if hasattr(value, '__name__'):
                    anno.setanno(node, 'fqn', (value.__name__, ))
            # TODO(mdan): Should we raise an error for unresolved symbols?
            # Can encounter this when:
            #  * a symbol truly lacks reference
            #  * a symbol is new, like the new name of a function we just
            #    renamed.
        # TODO(mdan): For locals and params, attempt to trace the value through
        # the local chain, and use type annotations as fallback.

        if not is_modified and name in self.context.arg_values:
            value = self.context.arg_values[name]
            anno.setanno(node, 'live_val', value)
            anno.setanno(node, 'fqn', (value.__class__.__name__, ))
        return node
    def visit_Compare(self, node):
        """Converts a (possibly chained) comparison into TF op calls.

        `a < b` becomes the matching TF op applied to `(a, b)`. A chain such
        as `a < b < c` is rewritten as the conjunction `a < b and b < c` using
        `logical_and`. Comparisons whose operands are both plain names are
        marked SAFE_BOOLEAN_OPERAND.

        Every inner operand of a chain (`b` above) appears in two comparisons.
        `_expect_simple_symbol` is called on exactly those operands. The
        previous code checked the right-hand operand instead, so the
        duplicated one was never checked and the last one was checked
        needlessly.

        Returns:
            The expression tree that replaces the comparison.
        """
        node = self.generic_visit(node)
        left = node.left
        op_tree = None

        # Repeated comparisons are converted to conjunctions:
        #   a < b < c   ->   a < b and b < c
        for op, right in zip(node.ops, node.comparators):
            if op_tree is not None:
                # `left` was the right operand of the previous comparison
                # and is used again here.
                self._expect_simple_symbol(left)
            binary_comparison = self._inline_tf_op(self._matching_tf_op(op),
                                                   (left, right))
            if isinstance(left, gast.Name) and isinstance(right, gast.Name):
                anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
            if op_tree is None:
                op_tree = binary_comparison
            else:
                op_tree = self._inline_tf_op('logical_and',
                                             (binary_comparison, op_tree))
            left = right
        assert op_tree is not None
        return op_tree
Пример #35
0
 def visit_Call(self, node):
   """Copies the receiver's 'type_fqn' onto an unresolved method call target.

   For a call `obj.method(...)` whose target has no live value, the value of
   `obj` is fetched from the current scope. That value's 'type_fqn' is
   copied onto the call target.

   Raises:
     ValueError: if the target is not an attribute of a plain name, or if
       the receiver's value has no 'type' annotation.
   """
   callee = node.func
   if not anno.hasanno(callee, 'live_val'):
     if not isinstance(callee, gast.Attribute):
       # Suspected pattern:
       #   foo = bar
       #   foo()
       raise ValueError('Dont know how to handle dynamic functions.')
     receiver = callee.value
     if not isinstance(receiver, gast.Name):
       # Possible example:
       #   foo = module.Foo()
       #   foo.bar.baz()
       # TODO(mdan): This should be doable by using the FQN.
       raise ValueError('Dont know how to handle object properties yet.')
     # For `opt = tf.train.Optimizer(); opt.foo()`, the receiver's value is
     # the `tf.train.Optimizer()` expression.
     receiver_source = self.scope.getval(receiver.id)
     if not anno.hasanno(receiver_source, 'type'):
       raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                        (receiver.id))
     anno.setanno(callee, 'type_fqn',
                  anno.getanno(receiver_source, 'type_fqn'))
   self.generic_visit(node)
   return node
Пример #36
0
  def visit_Assign(self, node):
    """Records assignment sources in the current scope.

    If the assigned value is a call to a statically known class, the call is
    annotated as a constructor with the class ('type') and the callee's FQN
    ('type_fqn'). Each plain-name target is then bound in `self.scope` to
    its source expression. For a tuple target, element `i` is bound to a
    synthetic `value[i]` subscript.

    Targets that are not plain names (attributes, subscripts, nested tuples,
    starred elements) have no `id` and are not tracked; reading `.id` on them
    used to raise AttributeError.
    """
    self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      target = node.value.func
      if anno.hasanno(target, 'live_val'):
        target_obj = anno.getanno(target, 'live_val')
        if tf_inspect.isclass(target_obj):
          # This is then a constructor.
          anno.setanno(node.value, 'type', target_obj)
          anno.setanno(node.value, 'type_fqn', anno.getanno(target, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    for n in node.targets:
      if isinstance(n, gast.Tuple):
        for i, e in enumerate(n.elts):
          if isinstance(e, gast.Name):
            # The synthetic `value[i]` is read, not assigned, so it takes a
            # Load context; its index is a proper Num node, not a raw int.
            self.scope.setval(
                e.id,
                gast.Subscript(
                    node.value, gast.Index(gast.Num(i)), ctx=gast.Load()))
      elif isinstance(n, gast.Name):
        self.scope.setval(n.id, node.value)

    return node
Пример #37
0
 def visit_Print(self, node):
     """Replaces a `print` statement node with an equivalent call expression.

     The printed values become the arguments of a call to `print`. The
     callee is annotated with the `print` builtin as its live value and the
     FQN `('print',)`, and the call inherits the statement's 'args_scope'.

     Returns:
       A gast.Expr wrapping the new call.
     """
     self.generic_visit(node)
     # NOTE(review): values are given a Param context; the reason is not
     # visible here.
     for n in node.values:
         n.ctx = gast.Param()
     call_node = gast.Call(func=gast.Name('print', gast.Load(), None),
                           args=node.values,
                           keywords=[])
     anno.setanno(call_node.func, 'live_val', print)
     # FQNs are tuples of name components (e.g. `(obj.__name__,)`, extended
     # with `fqn + (attr,)`). A bare string would break tuple concatenation
     # and comparisons.
     anno.setanno(call_node.func, 'fqn', ('print',))
     anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope'))
     node = gast.Expr(call_node)
     return node
Пример #38
0
 def visit_Name(self, node):
     """Processes parameters and propagates type info to loaded names.

     Parameters are handed to `_process_function_arg`. For a loaded name with
     a value recorded in the scope, the value's 'type'/'type_fqn' (when it
     has a 'type') and its 'element_type' (when present) are copied onto the
     node.
     """
     self.generic_visit(node)
     qn = anno.getanno(node, anno.Basic.QN)
     if isinstance(node.ctx, gast.Param):
         self._process_function_arg(qn)
         return node
     if not isinstance(node.ctx, gast.Load) or not self.scope.hasval(qn):
         return node
     # E.g. after `a = b`, later references to `a` have definition `b`.
     definition = self.scope.getval(qn)
     if anno.hasanno(definition, 'type'):
         for key in ('type', 'type_fqn'):
             anno.setanno(node, key, anno.getanno(definition, key))
     if anno.hasanno(definition, 'element_type'):
         element_type = anno.getanno(definition, 'element_type')
         anno.setanno(node, 'element_type', element_type)
     return node
Пример #39
0
 def visit_Name(self, node):
     """Annotates loaded non-local symbols with their live values.

     Non-local symbols are looked up in `self.literals`, then in
     `self.namespace`. Namespace values without a `__name__` (e.g. numbers
     or strings) get a live value but no FQN. Local symbols are left
     unannotated.

     Args:
       node: the gast.Name node being visited. In Load context it must carry
         an 'is_local' annotation.

     Returns:
       The same node, possibly annotated with 'live_val' and 'fqn'.

     Raises:
       ValueError: if a non-local symbol is found in neither lookup.
     """
     self.generic_visit(node)
     if isinstance(node.ctx, gast.Load):
         assert anno.hasanno(node, 'is_local'), node
         symbol_is_local = anno.getanno(node, 'is_local')
         if not symbol_is_local:
             if node.id in self.literals:
                 anno.setanno(node, 'live_val', self.literals[node.id])
                 # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
             elif node.id in self.namespace:
                 obj = self.namespace[node.id]
                 anno.setanno(node, 'live_val', obj)
                 # Reading __name__ unconditionally raises AttributeError
                 # for values that do not define it.
                 if hasattr(obj, '__name__'):
                     anno.setanno(node, 'fqn', (obj.__name__, ))
             else:
                 raise ValueError('Could not find global symbol %s.' %
                                  node.id)
         else:
             pass
             # TODO(mdan): Attempt to trace its value through the local chain.
             # TODO(mdan): Use type annotations as fallback.
     return node
Пример #40
0
    def _process_variable_assignment(self, source, targets):
        """Bind each assignment target to the assigned value in the scope.

        If ``source`` is a call whose callee carries a ``live_val`` that is a
        class, the call is marked with ``is_constructor`` and annotated with
        the class (``type``) and the callee's ``fqn`` (``type_fqn``).

        Each target is then processed as follows:

        * Tuple targets are delegated to ``self._process_tuple_assignment``.
        * Name and Attribute targets are recorded in ``self.scope`` under
          their qualified name, with ``source`` as the value.

        Args:
          source: the node for the assigned value.
          targets: the assignment target nodes.

        Raises:
          ValueError: for any target that is not a Tuple, Name or Attribute.
        """
        is_call = isinstance(source, gast.Call)
        if is_call and anno.hasanno(source.func, 'live_val'):
            callee = anno.getanno(source.func, 'live_val')
            if tf_inspect.isclass(callee):
                # Calling a class builds an instance: remember its type.
                anno.setanno(source, 'is_constructor', True)
                anno.setanno(source, 'type', callee)
                anno.setanno(source, 'type_fqn',
                             anno.getanno(source.func, 'fqn'))
                # TODO(mdan): Raise an error if constructor has side effects.
                # We can have a whitelist of no-side-effects constructors.
                # We can also step inside the constructor and further analyze.

        for target in targets:
            if isinstance(target, gast.Tuple):
                # Nested tuples need recursion, e.g. `a, (b, c) = f()`.
                self._process_tuple_assignment(source, target)
                continue
            if not isinstance(target, (gast.Name, gast.Attribute)):
                raise ValueError('Dont know how to handle assignment to %s' %
                                 target)
            self.scope.setval(anno.getanno(target, anno.Basic.QN), source)
Пример #41
0
 def visit_While(self, node):
     """Snapshot the enclosing scope, then visit the while loop.

     A copy of ``self.scope`` is taken before the loop's children are visited.
     It is stored on the node as ``parent_scope_values``.

     Returns:
       The same node, annotated in place.
     """
     scope_before_loop = self.scope.copy()
     anno.setanno(node, 'parent_scope_values', scope_before_loop)
     self.generic_visit(node)
     return node
Пример #42
0
 def visit_Attribute(self, node):
   """Statically resolve an attribute access where possible.

   Resolution is tried in this order:

   * If the owner expression has a ``live_val``, the attribute is read from
     that live object.
   * Otherwise, if the owner has a static ``type``, it is read from the type.

   On success the node gets ``parent_type``, ``live_val`` and ``fqn``
   annotations. Unresolvable attributes are left unannotated.

   Args:
     node: the ``gast.Attribute`` node being visited.

   Returns:
     The same node, annotated in place.

   Raises:
     AttributeError: if the owner's live value lacks the attribute.
   """
   self.generic_visit(node)
   owner = node.value

   if anno.hasanno(owner, 'live_val'):
     assert anno.hasanno(owner, 'fqn')
     live_owner = anno.getanno(owner, 'live_val')
     if not hasattr(live_owner, node.attr):
       raise AttributeError('%s has no attribute %s' % (live_owner,
                                                        node.attr))
     anno.setanno(node, 'parent_type', type(live_owner))
     anno.setanno(node, 'live_val', getattr(live_owner, node.attr))
     anno.setanno(node, 'fqn', anno.getanno(owner, 'fqn') + (node.attr,))
     return node

   # TODO(mdan): Investigate the role built-in annotations can play here.
   if anno.hasanno(owner, 'type'):
     static_type = anno.getanno(owner, 'type')
     # Static members (like methods) are visible on the type; dynamic ones
     # (like function attributes) are not. Those are left unannotated for
     # downstream consumers to handle.
     if hasattr(static_type, node.attr):
       anno.setanno(node, 'parent_type', static_type)
       anno.setanno(node, 'live_val', getattr(static_type, node.attr))
       anno.setanno(node, 'fqn',
                    anno.getanno(owner, 'type_fqn') + (node.attr,))
     return node

   if isinstance(owner, gast.Name):
     # All nonlocal symbols should be fully resolved.
     assert anno.hasanno(owner, NodeAnno.IS_LOCAL), owner
     # TODO(mdan): Figure out what to do when calling attribute on local
     # object. Maybe just leave as-is?
   return node
Пример #43
0
 def visit_ClassDef(self, node):
   """Attach the live class object to a class definition node.

   After the class body is visited, the object stored under the class's name
   in ``self.context.namespace`` becomes the node's ``live_val``.
   """
   self.generic_visit(node)
   live_class = self.context.namespace[node.name]
   anno.setanno(node, 'live_val', live_class)
   return node
Пример #44
0
 def visit_Attribute(self, node):
   """Annotate an attribute access with its qualified name.

   The QN is built from the owner expression's QN and the attribute name, so
   the owner must already have been annotated (children are visited first).
   """
   self.generic_visit(node)
   owner_qn = anno.getanno(node.value, anno.Basic.QN)
   anno.setanno(node, anno.Basic.QN, QN(owner_qn, node.attr))
   return node
Пример #45
0
 def visit_Name(self, node):
   """Annotate a bare name with a qualified name built from its identifier."""
   self.generic_visit(node)
   name_qn = QN(node.id)
   anno.setanno(node, anno.Basic.QN, name_qn)
   return node
Пример #46
0
    def visit_Expr(self, node):
        """Wrap a single-call expression statement in a control-deps guard.

        After visiting children, an expression statement whose value is a
        call (e.g. ``opt.minimize(loss)``) is replaced by the ``with``
        statement produced from a template of the form
        ``with py2tf_utils.control_dependency_on_returns(tf, call): ...``.

        Guarded args are the non-composite symbols used by the call's
        arguments. When there are any:

        * The ``with`` body re-binds them through
          ``py2tf_utils.alias_tensors``.
        * Those not modified in the parent scope get fresh names; the
          old-to-new mapping is kept in ``alias_map``.

        With no guarded args, the ``with`` body is left empty and
        ``alias_map`` is ``{}``.

        The resulting node is annotated with
        ``anno.Basic.INDENT_BLOCK_REMAINDER = (node.body, alias_map)``.
        NOTE(review): presumably a later pass (``_visit_and_reindent``,
        mentioned below) moves the following statements into that body and
        applies the renames. Confirm against that pass.

        Args:
          node: the expression statement being visited. A call value must
            carry the ``NodeAnno.ARGS_SCOPE`` annotation.

        Returns:
          The new ``with`` statement node if the value was a call, otherwise
          ``node`` unchanged.
        """
        self.generic_visit(node)
        if isinstance(node.value, gast.Call):
            # Patterns of single function calls, like:
            #   opt.minimize(loss)
            # or:
            #   tf.py_func(...)

            # First, attempt to gate future evaluation of args. If that's not
            # possible, gate all remaining statements (and that may fail too, see
            # _visit_and_reindent).
            args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
            # NOTE: We can't guard object attributes because they may not be writable.
            guarded_args = tuple(s for s in args_scope.used
                                 if not s.is_composite())

            # TODO(mdan): Include all arguments which depended on guarded_args too.
            # For example, the following will still cause a race:
            #   tf.assign(a, a + 1)
            #   b = a + 1
            #   tf.assign(a, a + 1)  # Control deps here should include `b`
            #   c = b + 1
            # Or maybe we should just raise an "unsafe assign" error?

            if guarded_args:
                # The aliases may need new names to avoid incorrectly making them local.
                # TODO(mdan): This is brutal. It will even rename modules - any fix?
                need_alias = tuple(s for s in guarded_args
                                   if s not in args_scope.parent.modified)
                # Fresh names are chosen to avoid clashing with any symbol the
                # parent scope already references.
                aliased_new_names = tuple(
                    qual_names.QN(
                        self.context.namer.new_symbol(
                            s.ssf(), args_scope.parent.referenced))
                    for s in need_alias)
                alias_map = dict(zip(need_alias, aliased_new_names))
                if len(guarded_args) == 1:
                    # A single arg is substituted as a bare (possibly aliased)
                    # QN rather than a one-element tuple.
                    s, = guarded_args
                    aliased_guarded_args = alias_map.get(s, s)
                else:
                    # Several args become a tuple target; unaliased symbols
                    # keep their original names.
                    aliased_guarded_args = gast.Tuple(
                        [alias_map.get(s, s).ast() for s in guarded_args],
                        None)

                template = """
          with py2tf_utils.control_dependency_on_returns(tf, call):
            aliased_guarded_args = py2tf_utils.alias_tensors(tf, guarded_args)
        """
                # Keep only the last generated statement: the `with` node.
                control_deps_guard = templates.replace(
                    template,
                    call=node.value,
                    aliased_guarded_args=aliased_guarded_args,
                    guarded_args=guarded_args)[-1]
            else:
                alias_map = {}

                template = """
          with py2tf_utils.control_dependency_on_returns(tf, call):
            pass
        """
                control_deps_guard = templates.replace(template,
                                                       call=node.value)[-1]
                # Drop the placeholder `pass` so the guard body starts empty.
                control_deps_guard.body = []

            node = control_deps_guard
            # Record where the remaining statements should go and which
            # symbols were renamed.
            anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                         (node.body, alias_map))
        return node