def visit_Print(self, node):
   """Visits the values of a `Print` node inside a dedicated scope.

   A scope is entered with `self._enter_scope(False)`, the printed values are
   visited as a block, and the resulting scope is attached to the node under
   both `anno.Static.SCOPE` and `NodeAnno.ARGS_SCOPE` before being exited.
   """
   self._enter_scope(False)
   node.values = self.visit_block(node.values)
   values_scope = self.scope
   for scope_key in (anno.Static.SCOPE, NodeAnno.ARGS_SCOPE):
      anno.setanno(node, scope_key, values_scope)
   self._exit_scope()
   return node
Example #2
0
  def visit_For(self, node):
    """Replaces `break` statements in a `for` loop with a boolean flag.

    The body is processed by `_track_body` using a freshly named `break__`
    symbol. If that reports the flag was used:
      * the flag is initialized to False ahead of the loop;
      * the `else` block is guarded by `_guard_if_present`;
      * the loop receives an `extra_test` annotation holding `not <flag>`.

    Returns:
      The original node when no break was tracked. Otherwise, the result of
      `templates.replace`, whose element at index 1 is the loop statement.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.context.namer.new_symbol('break__', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._track_body(node.body, break_var)
    # A break in the else clause applies to the containing scope, so the else
    # block is visited normally instead of being tracked.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      return node

    node.orelse = self._guard_if_present(node.orelse, break_var)
    template = """
        var_name = False
        for_stmt
      """
    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    statements = templates.replace(
        template, var_name=break_var, for_stmt=node)
    loop_condition = templates.replace_as_expression(
        'not var_name', var_name=break_var)
    anno.setanno(statements[1], 'extra_test', loop_condition)
    return statements
Example #3
0
 def visit_Attribute(self, node):
     """Annotates an attribute node with its qualified name, when derivable.

     Children are visited first. If the attribute's base (`node.value`) has an
     `anno.Basic.QN` annotation, this node is annotated with a QN built from
     that parent QN and `node.attr`. Otherwise it is left unannotated.
     """
     node = self.generic_visit(node)
     base = node.value
     if not anno.hasanno(base, anno.Basic.QN):
         return node
     parent_qn = anno.getanno(base, anno.Basic.QN)
     anno.setanno(node, anno.Basic.QN, QN(parent_qn, attr=node.attr))
     return node
    def visit_Call(self, node):
        """Converts a call node according to what its callee resolves to.

        Before the children are visited, a call to one of
        `self.nocompile_decorators` gets its first argument marked
        `graph_ready`. Afterwards:
          * resolved callees for which `_function_is_compilable` holds are
            renamed with `_rename_compilable_function`;
          * resolved callees whose `fqn` is in `KNOWN_NUMPY_FUNCTIONS` are
            passed to `_wrap_to_py_func_single_return` with that entry's
            `dtype`;
          * any other resolved callee raises NotImplementedError;
          * unresolved callees go through `_insert_dynamic_conversion` when
            `self.context.recursive` is set, and are left as-is otherwise.

        Args:
          node: the call node being visited.

        Returns:
          The (possibly replaced) call node.

        Raises:
          ValueError: if a marker decorator is called without arguments.
          NotImplementedError: if a resolved callee cannot be converted.
        """
        # If the function is wrapped by one of the marker decorators,
        # consider it graph ready.
        if anno.hasanno(node.func, 'live_val'):
            target_entity = anno.getanno(node.func, 'live_val')
            if target_entity in self.nocompile_decorators:
                if not node.args:
                    # The "%s" placeholder was previously never filled in.
                    # Wrapping the value in a 1-tuple keeps formatting safe
                    # even if the entity itself is a tuple.
                    raise ValueError(
                        'Found call to decorator function "%s", but it had no '
                        'arguments. A decorator needs at least one positional '
                        'argument.' % (target_entity,))
                anno.setanno(node.args[0], 'graph_ready', True)

        self.generic_visit(node)
        # node.func is read again here in case visiting the children changed it.
        if not anno.hasanno(node.func, 'live_val'):
            # Unresolved functions are allowed in non-recursive mode.
            if self.context.recursive:
                node = self._insert_dynamic_conversion(node)
            return node

        target_entity = anno.getanno(node.func, 'live_val')
        if anno.hasanno(node.func, 'fqn'):
            target_fqn = anno.getanno(node.func, 'fqn')
        else:
            target_fqn = None
        if self._function_is_compilable(target_entity):
            return self._rename_compilable_function(node)
        if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
            # TODO(mdan): Should we replace these with equivalent TF ops instead?
            return self._wrap_to_py_func_single_return(
                node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
        raise NotImplementedError(
            'py_func with return values (unknown function)')
 def _aggregate_predecessors_defined_in(self, node):
     """Records the names defined on entry to `node`.

     The keys of `current_analyzer.out[pred].value` are unioned across every
     CFG predecessor `pred` of `node` (taken from `graph.stmt_prev`). The
     result is stored as a frozenset under `anno.Static.DEFINED_VARS_IN`.
     """
     analyzer = self.current_analyzer
     defined_on_entry = frozenset(
         name
         for pred in analyzer.graph.stmt_prev[node]
         for name in analyzer.out[pred].value.keys())
     anno.setanno(node, anno.Static.DEFINED_VARS_IN, defined_on_entry)
  def _track_symbol(self,
                    node,
                    composite_writes_alter_parent=False,
                    writes_create_symbol=False):
    """Records how the symbol named by `node` is used in the current scope.

    Store and Param contexts mark the QN as written; Load contexts mark it as
    read. The node is then annotated with `NodeAnno.IS_LOCAL` according to
    `self.scope.has(qn)`. The QN is also marked returned while inside a
    return statement. Nodes without a QN annotation are ignored.

    Args:
      node: the node to track. It needs a `ctx` attribute whenever it carries
        an `anno.Basic.QN` annotation.
      composite_writes_alter_parent: when True, a store to a composite QN also
        marks the QN's parent as written.
      writes_create_symbol: when True, a store also marks the QN as created.

    Raises:
      ValueError: if `node.ctx` is not a Store, Load or Param context.
    """
    # A QN may be missing when we have an attribute (or subscript) on a
    # function call. Example: a().b
    if not anno.hasanno(node, anno.Basic.QN):
      return
    qn = anno.getanno(node, anno.Basic.QN)
    scope = self.scope
    ctx = node.ctx

    if isinstance(ctx, gast.Store):
      scope.mark_write(qn)
      if qn.is_composite and composite_writes_alter_parent:
        scope.mark_write(qn.parent)
      if writes_create_symbol:
        scope.mark_creation(qn, writes_create_symbol=True)
      if self._in_aug_assign:
        # Inside an augmented assignment the target is read as well.
        scope.mark_read(qn)
    elif isinstance(ctx, gast.Load):
      scope.mark_read(qn)
    elif isinstance(ctx, gast.Param):
      # Param contexts appear in function defs, so they have the meaning of
      # defining a variable.
      scope.mark_write(qn)
      scope.mark_param(qn, self.enclosing_entities[-1])
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(ctx), qn))

    anno.setanno(node, NodeAnno.IS_LOCAL, scope.has(qn))

    if self._in_return_statement:
      scope.mark_returned(qn)
Example #7
0
 def visit_Call(self, node):
   """Propagates the type given by a `type_annotation_func` marker call.

   For a call `f(symbol, type)` where `f` is `self.context.type_annotation_func`,
   the statically resolved `type` (its `live_val`) is attached as
   `element_type` both to the call node and to the value bound to `symbol` in
   `self.scope`. In all cases, the children are visited afterwards.

   Raises:
     ValueError: if the marker call does not have exactly two arguments, if the
       first one is not a symbol, or if the second one is not statically
       resolvable.
   """
   if (anno.hasanno(node.func, 'live_val') and
       anno.getanno(node.func, 'live_val') is
       self.context.type_annotation_func):
     # Expecting the actual type to be the second argument.
     if len(node.args) != 2:
       raise ValueError('"%s" must have exactly two parameters'
                        % self.context.type_annotation_func)
     target_arg, type_arg = node.args
     if not anno.hasanno(target_arg, anno.Basic.QN):
       # Message typo fixed: was "must by a symbol".
       raise ValueError('the first argument of "%s" must be a symbol'
                        % self.context.type_annotation_func)
     if not anno.hasanno(type_arg, 'live_val'):
       raise ValueError(
           'the second argument of "%s" must be statically resolvable' %
           self.context.type_annotation_func)
     target_symbol = anno.getanno(target_arg, anno.Basic.QN)
     element_type = anno.getanno(type_arg, 'live_val')
     # Find the definition of this symbol and annotate it with the given
     # data type. That in turn will cause future uses of the symbol
     # to receive the same type annotation.
     definition = self.scope.getval(target_symbol)
     anno.setanno(node, 'element_type', element_type)
     anno.setanno(definition, 'element_type', element_type)
     # TODO(mdan): Should we update references between definition and here?
   return self.generic_visit(node)
Example #8
0
    def visit_Call(self, node):
        """Propagates the type given by a `type_annotation_func` marker call.

        For a call `f(symbol, type)` where `f` is
        `self.context.type_annotation_func`, the element type is taken from
        `type` in one of three ways:
          * the `.s` of a `gast.Str`;
          * the `.n` of a `gast.Num`;
          * otherwise, its statically resolved `live_val`.
        It is attached as `element_type` to the call node and to the value
        bound to `symbol` in `self.scope`. In all cases, the children are
        visited afterwards.

        Raises:
          ValueError: if the marker call does not have exactly two arguments,
            if the first one is not a symbol, or if the second one is neither
            a string/number literal nor statically resolvable.
        """
        if (anno.hasanno(node.func, 'live_val') and
                anno.getanno(node.func, 'live_val') is
                self.context.type_annotation_func):

            if len(node.args) != 2:
                raise ValueError('"%s" must have exactly two parameters' %
                                 self.context.type_annotation_func)
            target_arg, type_arg = node.args
            if not anno.hasanno(target_arg, anno.Basic.QN):
                # Message typo fixed: was "must by a symbol".
                raise ValueError(
                    'the first argument of "%s" must be a symbol' %
                    self.context.type_annotation_func)
            if isinstance(type_arg, gast.Str):
                element_type = type_arg.s
            elif isinstance(type_arg, gast.Num):
                element_type = type_arg.n
            else:
                if not anno.hasanno(type_arg, 'live_val'):
                    raise ValueError(
                        'the second argument of "%s" must be statically resolvable'
                        % self.context.type_annotation_func)
                element_type = anno.getanno(type_arg, 'live_val')

            target_symbol = anno.getanno(target_arg, anno.Basic.QN)
            # Find the definition of this symbol and annotate it with the given
            # data type. That in turn will cause future uses of the symbol
            # to receive the same type annotation.
            definition = self.scope.getval(target_symbol)
            anno.setanno(node, 'element_type', element_type)
            anno.setanno(definition, 'element_type', element_type)
            # TODO(mdan): Should we update references between definition and here?
        return self.generic_visit(node)
Example #9
0
  def visit_Call(self, node):
    """Converts a call node according to what its callee resolves to.

    Before the children are visited, a call to one of
    `self.nocompile_decorators` gets its first argument marked `graph_ready`.
    Afterwards:
      * resolved callees for which `_function_is_compilable` holds are renamed
        with `_rename_compilable_function`;
      * resolved callees whose `fqn` is in `KNOWN_NUMPY_FUNCTIONS` are passed
        to `_wrap_to_py_func_single_return` with that entry's `dtype`;
      * any other resolved callee raises NotImplementedError;
      * unresolved callees go through `_insert_dynamic_conversion` when
        `self.context.recursive` is set, and are left as-is otherwise.

    Args:
      node: the call node being visited.

    Returns:
      The (possibly replaced) call node.

    Raises:
      ValueError: if a marker decorator is called without arguments.
      NotImplementedError: if a resolved callee cannot be converted.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.nocompile_decorators:
        if not node.args:
          # The "%s" placeholder was previously never filled in. Wrapping the
          # value in a 1-tuple keeps formatting safe even if it is a tuple.
          raise ValueError(
              'Found call to decorator function "%s", but it had no '
              'arguments. A decorator needs at least one positional '
              'argument.' % (target_entity,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    # node.func is read again here in case visiting the children changed it.
    if not anno.hasanno(node.func, 'live_val'):
      # Unresolved functions are allowed in non-recursive mode.
      if self.context.recursive:
        node = self._insert_dynamic_conversion(node)
      return node

    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None
    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)
    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    raise NotImplementedError(
        'py_func with return values (unknown function)')
Example #10
0
 def _aggregate_successors_live_in(self, node):
   """Stores the variables live on exit from `node`, then visits children.

   The analyzer's `in_` sets of every CFG successor of `node` (taken from
   `graph.stmt_next`) are unioned. The result is stored as a frozenset under
   `anno.Static.LIVE_VARS_OUT`.
   """
   analyzer = self.current_analyzer
   live_out = frozenset().union(
       *(analyzer.in_[succ] for succ in analyzer.graph.stmt_next[node]))
   anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
   return self.generic_visit(node)
Example #11
0
 def _as_function(self, func_name, args):
     """Builds the expression `func_name(args)`, marked as a safe boolean operand.

     Args:
       func_name: source text of the callee, parsed with
         `parser.parse_expression`.
       args: what is substituted for `args` in the call template.

     Returns:
       The expression from `templates.replace_as_expression`, annotated with
       `SAFE_BOOLEAN_OPERAND` set to True.
     """
     callee = parser.parse_expression(func_name)
     template = """
   func_name(args)
 """
     call_expr = templates.replace_as_expression(
         template, func_name=callee, args=args)
     anno.setanno(call_expr, SAFE_BOOLEAN_OPERAND, True)
     return call_expr
Example #12
0
 def visit_With(self, node):
     """Propagates dataflow labels across a `with` statement.

     The node's `in_label` combines (with `|`) the first body statement's
     `in_label` and the `in_label` of every `with` item. Its `out_label` is
     the last body statement's `out_label`.
     """
     self.generic_visit(node)
     # Use `x = x | y` instead of `x |= y`. If the label is a mutable set,
     # `|=` would modify the first body statement's own annotation in place
     # and leave it shared with this node's label.
     incoming = anno.getanno(node.body[0], self.in_label)
     for item in node.items:
         incoming = incoming | anno.getanno(item, self.in_label)
     outgoing = anno.getanno(node.body[-1], self.out_label)
     anno.setanno(node, self.in_label, incoming)
     anno.setanno(node, self.out_label, outgoing)
Example #13
0
 def visit_With(self, node):
   """Visits a `with` statement inside a non-isolated child scope.

   The child scope is created with `Scope(<current>, isolated=False)` and is
   made current while the node's children are visited. It is then attached
   under `NodeAnno.BODY_SCOPE`, and the previous scope is restored.
   """
   enclosing_scope = self.scope
   body_scope = Scope(enclosing_scope, isolated=False)

   self.scope = body_scope
   self.generic_visit(node)
   anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
   self.scope = enclosing_scope

   return node
Example #14
0
 def visit_With(self, node):
   """Propagates dataflow labels across a `with` statement.

   The node's `in_label` combines (with `|`) the first body statement's
   `in_label` and the `in_label` of every `with` item. Its `out_label` is the
   last body statement's `out_label`.
   """
   self.generic_visit(node)
   # Use `x = x | y` instead of `x |= y`. If the label is a mutable set, `|=`
   # would modify the first body statement's own annotation in place and
   # leave it shared with this node's label.
   incoming = anno.getanno(node.body[0], self.in_label)
   for item in node.items:
     incoming = incoming | anno.getanno(item, self.in_label)
   outgoing = anno.getanno(node.body[-1], self.out_label)
   anno.setanno(node, self.in_label, incoming)
   anno.setanno(node, self.out_label, outgoing)
 def _as_function(self, func_name, args):
   """Builds the expression `func_name(args)`, marked as a safe boolean operand.

   Args:
     func_name: source text of the callee, parsed with
       `parser.parse_expression`.
     args: what is substituted for `args` in the call template.

   Returns:
     The expression from `templates.replace_as_expression`, annotated with
     `SAFE_BOOLEAN_OPERAND` set to True.
   """
   callee = parser.parse_expression(func_name)
   template = """
     func_name(args)
   """
   call_expr = templates.replace_as_expression(
       template, func_name=callee, args=args)
   anno.setanno(call_expr, SAFE_BOOLEAN_OPERAND, True)
   return call_expr
Example #16
0
 def visit_With(self, node):
     """Visits a `with` statement inside a non-isolated child scope.

     The child scope is created with `Scope(<current>, isolated=False)` and is
     made current while the node's children are visited. It is then attached
     under `NodeAnno.BODY_SCOPE`, and the previous scope is restored.
     """
     enclosing_scope = self.scope
     body_scope = Scope(enclosing_scope, isolated=False)

     self.scope = body_scope
     self.generic_visit(node)
     anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
     self.scope = enclosing_scope

     return node
 def _aggregate_successors_live_in(self, node):
     """Stores the variables live on exit from `node`, then visits children.

     The analyzer's `in_` sets of every CFG successor of `node` (taken from
     `graph.stmt_next`) are unioned. The result is stored as a frozenset
     under `anno.Static.LIVE_VARS_OUT`.
     """
     analyzer = self.current_analyzer
     live_out = frozenset().union(
         *(analyzer.in_[succ] for succ in analyzer.graph.stmt_next[node]))
     anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
     return self.generic_visit(node)
 def compiled_fn(self, test_fn, add_origin=False):
     """Parses and analyzes `test_fn`, applies error_handlers, then compiles it.

     Args:
       test_fn: the function to convert.
       add_origin: if True, the first statement of the parsed module gets an
         `anno.Basic.ORIGIN` annotation whose path is this file.

     Returns:
       Whatever `self.compiled` returns for the transformed node.
     """
     node = self.parse_and_analyze(test_fn, {})
     if add_origin:
         origin = origin_info.OriginInfo(__file__, None, None, None, None)
         anno.setanno(node.body[0], anno.Basic.ORIGIN, origin)
     node = error_handlers.transform(node, self.ctx)
     return self.compiled(node)
 def visit_Call(self, node):
   """Visits a call, recording the scope its arguments are evaluated in.

   Positional and keyword arguments are visited inside a scope entered with
   `self._enter_scope(False)`. That scope is stored on the node under
   `NodeAnno.ARGS_SCOPE`. The callee expression is visited only after the
   argument scope has been exited.
   """
   self._enter_scope(False)
   for field in ('args', 'keywords'):
     setattr(node, field, self.visit_block(getattr(node, field)))
   # TODO(mdan): Account starargs, kwargs
   anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
   self._exit_scope()

   node.func = self.visit(node.func)
   return node
Example #20
0
 def visit_Print(self, node):
   """Visits the printed values inside a child scope and records that scope.

   A `Scope` whose parent is the current scope is made current while each
   value is visited. It is attached under `NodeAnno.ARGS_SCOPE`, and then the
   previous scope is restored.
   """
   enclosing_scope = self.scope
   values_scope = Scope(enclosing_scope)

   self.scope = values_scope
   for value in node.values:
     self.visit(value)
   anno.setanno(node, NodeAnno.ARGS_SCOPE, values_scope)
   self.scope = enclosing_scope

   return node
Example #21
0
 def visit_Print(self, node):
     """Visits the printed values inside a child scope and records that scope.

     A `Scope` whose parent is the current scope is made current while each
     value is visited. It is attached under `NodeAnno.ARGS_SCOPE`, and then
     the previous scope is restored.
     """
     enclosing_scope = self.scope
     values_scope = Scope(enclosing_scope)

     self.scope = values_scope
     for value in node.values:
         self.visit(value)
     anno.setanno(node, NodeAnno.ARGS_SCOPE, values_scope)
     self.scope = enclosing_scope

     return node
Example #22
0
 def _process_block_node(self, node, block, scope_name):
   """Visits `block` inside a non-isolated child scope and records that scope.

   Args:
     node: the node that receives the scope annotation.
     block: the nodes to visit, in order.
     scope_name: the annotation key under which the child scope is stored.

   Returns:
     `node`, after annotation.
   """
   parent_scope = self.scope
   child_scope = Scope(parent_scope, isolated=False)

   self.scope = child_scope
   for child in block:
     self.visit(child)
   anno.setanno(node, scope_name, child_scope)
   self.scope = parent_scope

   return node
 def _process_statement_directive(self, call_node, directive):
     """Attaches a statement-level directive call to the enclosing loop node.

     The arguments of `call_node`, as mapped by `_map_args`, are stored under
     key `directive` in the `converter.AgAnno.DIRECTIVES` dict of the node
     held in the `ENCLOSING_LOOP` local. The dict is created if absent.

     Returns:
       `call_node`, unchanged.

     Raises:
       ValueError: if `self.local_scope_level` is below 1.
     """
     if self.local_scope_level < 1:
         raise ValueError('"%s" must be used inside a statement' %
                          directive.__name__)
     loop_node = self.get_local(ENCLOSING_LOOP)
     directives = anno.getanno(loop_node, converter.AgAnno.DIRECTIVES, {})
     directives[directive] = _map_args(call_node, directive)
     anno.setanno(loop_node, converter.AgAnno.DIRECTIVES, directives)
     return call_node
Example #24
0
 def _process_statement_directive(self, call_node, directive):
   """Attaches a statement-level directive call to the enclosing loop node.

   The arguments of `call_node`, as mapped by `_map_args`, are stored under
   key `directive` in the `converter.AgAnno.DIRECTIVES` dict of the node held
   in the `ENCLOSING_LOOP` local. The dict is created if absent.

   Returns:
     `call_node`, unchanged.

   Raises:
     ValueError: if `self.local_scope_level` is below 1.
   """
   if self.local_scope_level < 1:
     raise ValueError(
         '"%s" must be used inside a statement' % directive.__name__)
   loop_node = self.get_local(ENCLOSING_LOOP)
   directives = anno.getanno(loop_node, converter.AgAnno.DIRECTIVES, {})
   directives[directive] = _map_args(call_node, directive)
   anno.setanno(loop_node, converter.AgAnno.DIRECTIVES, directives)
   return call_node
Example #25
0
 def _process_block_node(self, node, block, scope_name):
     """Visits `block` inside a non-isolated child scope and records that scope.

     Args:
       node: the node that receives the scope annotation.
       block: the nodes to visit, in order.
       scope_name: the annotation key under which the child scope is stored.

     Returns:
       `node`, after annotation.
     """
     parent_scope = self.scope
     child_scope = Scope(parent_scope, isolated=False)

     self.scope = child_scope
     for child in block:
         self.visit(child)
     anno.setanno(node, scope_name, child_scope)
     self.scope = parent_scope

     return node
Example #26
0
 def _process_function_arg(self, arg_name):
   """Binds a typed placeholder node for a top-level function argument.

   Applies only when `self.function_level == 1` and `str(arg_name)` is a key
   of `self.context.arg_types`. In that case a node built by `arg_name.ast()`
   is annotated in two ways:
     * 'type': the declared type object;
     * 'type_fqn': the dotted type name, split into a tuple.
   It is then bound to `arg_name` in the current scope.
   """
   str_name = str(arg_name)
   if self.function_level != 1 or str_name not in self.context.arg_types:
     return
   # Forge a node to hold the type information, so that method calls on
   # it can resolve the type.
   type_holder = arg_name.ast()
   type_string, type_obj = self.context.arg_types[str_name]
   anno.setanno(type_holder, 'type', type_obj)
   anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
   self.scope.setval(arg_name, type_holder)
Example #27
0
 def visit_For(self, node):
     """Propagates dataflow labels across a `for` loop.

     The loop's `in_label` is the first body statement's `in_label` with the
     loop target's QN removed, stored as a frozenset. Its `out_label` is the
     last body statement's `out_label`. When an `else` block exists, that
     value is combined with the last `else` statement's `out_label` via
     `self.transfer_fn`.
     """
     self.generic_visit(node)
     incoming = set(anno.getanno(node.body[0], self.in_label))
     incoming.discard(anno.getanno(node.target, anno.Basic.QN))
     outgoing = anno.getanno(node.body[-1], self.out_label)
     if node.orelse:
         outgoing = self.transfer_fn(
             outgoing, anno.getanno(node.orelse[-1], self.out_label))
     anno.setanno(node, self.in_label, frozenset(incoming))
     anno.setanno(node, self.out_label, outgoing)
Example #28
0
 def visit_While(self, node):
     """Propagates dataflow labels across a `while` loop.

     The loop's `in_label` combines (with `|`) the first body statement's and
     the test's `in_label`. Its `out_label` is the last body statement's
     `out_label`. When an `else` block exists, that value is combined with the
     last `else` statement's `out_label` via `self.transfer_fn`.
     """
     self.generic_visit(node)
     # Use `|` instead of `|=`. If the label is a mutable set, `|=` would
     # modify the first body statement's own annotation in place.
     incoming = (anno.getanno(node.body[0], self.in_label) |
                 anno.getanno(node.test, self.in_label))
     outgoing = anno.getanno(node.body[-1], self.out_label)
     if node.orelse:
         orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
         outgoing = self.transfer_fn(outgoing, orelse_outgoing)
     anno.setanno(node, self.in_label, incoming)
     anno.setanno(node, self.out_label, outgoing)
Example #29
0
  def test_rename_symbols_annotations(self):
    """rename_symbols keeps the same annotation objects on the result."""
    tree = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(tree, 'foo', 'bar')
    expected = anno.getanno(tree, 'foo')

    renamed = ast_util.rename_symbols(
        tree, {qual_names.QN('a'): qual_names.QN('b')})

    self.assertIs(anno.getanno(renamed, 'foo'), expected)
Example #30
0
 def visit_While(self, node):
   """Propagates dataflow labels across a `while` loop.

   The loop's `in_label` combines (with `|`) the first body statement's and
   the test's `in_label`. Its `out_label` is the last body statement's
   `out_label`. When an `else` block exists, that value is combined with the
   last `else` statement's `out_label` via `self.transfer_fn`.
   """
   self.generic_visit(node)
   # Use `|` instead of `|=`. If the label is a mutable set, `|=` would modify
   # the first body statement's own annotation in place.
   incoming = (anno.getanno(node.body[0], self.in_label) |
               anno.getanno(node.test, self.in_label))
   outgoing = anno.getanno(node.body[-1], self.out_label)
   if node.orelse:
     orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
     outgoing = self.transfer_fn(outgoing, orelse_outgoing)
   anno.setanno(node, self.in_label, incoming)
   anno.setanno(node, self.out_label, outgoing)
 def visit_For(self, node):
   """Visits a `for` loop, giving its header and its blocks separate scopes.

   The target and iterable are visited inside a scope entered with
   `self._enter_scope(False)`. That scope is attached to `node.iter` under
   `anno.Static.SCOPE`. The body and `else` blocks are then handled by
   `_process_parallel_blocks`, under `NodeAnno.BODY_SCOPE` and
   `NodeAnno.ORELSE_SCOPE` respectively.
   """
   self._enter_scope(False)
   node.target = self.visit(node.target)
   node.iter = self.visit(node.iter)
   anno.setanno(node.iter, anno.Static.SCOPE, self.scope)
   self._exit_scope()

   blocks = (
       (node.body, NodeAnno.BODY_SCOPE),
       (node.orelse, NodeAnno.ORELSE_SCOPE),
   )
   return self._process_parallel_blocks(node, blocks)
Example #32
0
    def test_rename_symbols_annotations(self):
        """rename_symbols keeps the same annotation objects on the result."""
        tree = qual_names.resolve(parser.parse_str('a[i]'))
        anno.setanno(tree, 'foo', 'bar')
        expected = anno.getanno(tree, 'foo')

        renamed = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('b')})

        self.assertIs(anno.getanno(renamed, 'foo'), expected)
Example #33
0
  def test_copy(self):
    """copyanno copies annotations that exist and skips ones that do not."""
    source = ast.Name()
    anno.setanno(source, 'foo', 3)

    dest = ast.Name()
    for key in ('foo', 'bar'):
      anno.copyanno(source, dest, key)

    self.assertTrue(anno.hasanno(dest, 'foo'))
    self.assertFalse(anno.hasanno(dest, 'bar'))
 def visit_While(self, node):
   """Visits a `while` loop, giving its condition and blocks separate scopes.

   The test is visited inside a scope entered with `self._enter_scope(False)`.
   That scope is attached to the node under `NodeAnno.COND_SCOPE` and to
   `node.test` under `anno.Static.SCOPE`. The body and `else` blocks are then
   handled by `_process_parallel_blocks`.
   """
   self._enter_scope(False)
   node.test = self.visit(node.test)
   cond_scope = self.scope
   anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
   anno.setanno(node.test, anno.Static.SCOPE, cond_scope)
   self._exit_scope()

   return self._process_parallel_blocks(
       node, ((node.body, NodeAnno.BODY_SCOPE),
              (node.orelse, NodeAnno.ORELSE_SCOPE)))
Example #35
0
 def visit_For(self, node):
   """Propagates dataflow labels across a `for` loop.

   The loop's `in_label` is the first body statement's `in_label` with the
   loop target's QN removed, stored as a frozenset. Its `out_label` is the
   last body statement's `out_label`. When an `else` block exists, that value
   is combined with the last `else` statement's `out_label` via
   `self.transfer_fn`.
   """
   self.generic_visit(node)
   incoming = set(anno.getanno(node.body[0], self.in_label))
   incoming.discard(anno.getanno(node.target, anno.Basic.QN))
   outgoing = anno.getanno(node.body[-1], self.out_label)
   if node.orelse:
     outgoing = self.transfer_fn(
         outgoing, anno.getanno(node.orelse[-1], self.out_label))
   anno.setanno(node, self.in_label, frozenset(incoming))
   anno.setanno(node, self.out_label, outgoing)
Example #36
0
    def test_copyanno(self):
        """copyanno copies annotations that exist and skips ones that do not."""
        source = ast.Name()
        anno.setanno(source, 'foo', 3)

        dest = ast.Name()
        for key in ('foo', 'bar'):
            anno.copyanno(source, dest, key)

        self.assertTrue(anno.hasanno(dest, 'foo'))
        self.assertFalse(anno.hasanno(dest, 'bar'))
def resolve(nodes, source, function=None):
    """Adds origin information to all nodes inside the body of a function.

    Every node reached by `gast.walk` that has a `lineno` attribute receives
    an `OriginInfo` annotation under key `anno.Basic.ORIGIN`. It is built from
    the node's location, the function name, the corresponding line of
    `source`, and the comment on that line, if any.

    Args:
      nodes: a single AST node, or a list/tuple of AST nodes, parsed from
        `source`. Node line numbers are interpreted as lines of `source`.
      source: Text, the source code string for the function whose body nodes
        will be annotated.
      function: Callable, the function that `source` was taken from. If given,
        its source file, starting line and `__name__` are used to place the
        nodes in that file. If None, only line numbers relative to `source`
        and column offsets are set; the file path and function name are None.

    Returns:
      None. The nodes are annotated in place.
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes, )

    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
        function_name = function.__name__
    else:
        function_lineno = None
        function_filepath = None
        function_name = None

    # TODO(mdan): Pull this to a separate utility.
    # Maps each 1-based line of `source` that holds a comment to the comment
    # text, without its leading '#'.
    code_reader = six.StringIO(source)
    comment_map = {}
    for token in tokenize.generate_tokens(code_reader.readline):
        tok_type, tok_string, loc, _, _ = token
        srow, _ = loc
        if tok_type == tokenize.COMMENT:
            comment_map[srow] = tok_string.strip()[1:].strip()

    source_lines = source.split('\n')
    for node in nodes:
        for n in gast.walk(node):
            if not hasattr(n, 'lineno'):
                continue

            lineno_in_body = n.lineno
            source_code_line = source_lines[lineno_in_body - 1]
            if function:
                # getsourcelines reports the 1-based file line on which the
                # function's source starts. Assuming `source` starts on that
                # same line, line k of `source` is file line
                # function_lineno + k - 1.
                source_lineno = function_lineno + lineno_in_body - 1
            else:
                source_lineno = lineno_in_body

            location = Location(function_filepath, source_lineno, n.col_offset)
            # comment_map is keyed by lines of `source`, so look it up with
            # the body-relative line rather than the file line.
            origin = OriginInfo(location, function_name, source_code_line,
                                comment_map.get(lineno_in_body))
            anno.setanno(n, anno.Basic.ORIGIN, origin)
Example #38
0
def resolve(nodes, source, function=None):
  """Adds origin information to all nodes inside the body of a function.

  Every node reached by `gast.walk` that has a `lineno` attribute receives an
  `OriginInfo` annotation under key `anno.Basic.ORIGIN`. It is built from the
  node's location, the function name, the corresponding line of `source`,
  and the comment on that line, if any.

  Args:
    nodes: a single AST node, or a list/tuple of AST nodes, parsed from
      `source`. Node line numbers are interpreted as lines of `source`.
    source: Text, the source code string for the function whose body nodes
      will be annotated.
    function: Callable, the function that `source` was taken from. If given,
      its source file, starting line and `__name__` are used to place the
      nodes in that file. If None, only line numbers relative to `source` and
      column offsets are set; the file path and function name are None.

  Returns:
    None. The nodes are annotated in place.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
    function_name = function.__name__
  else:
    function_lineno = None
    function_filepath = None
    function_name = None

  # TODO(mdan): Pull this to a separate utility.
  # Maps each 1-based line of `source` that holds a comment to the comment
  # text, without its leading '#'.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for node in nodes:
    for n in gast.walk(node):
      if not hasattr(n, 'lineno'):
        continue

      lineno_in_body = n.lineno
      source_code_line = source_lines[lineno_in_body - 1]
      if function:
        # getsourcelines reports the 1-based file line on which the function's
        # source starts. Assuming `source` starts on that same line, line k
        # of `source` is file line function_lineno + k - 1.
        source_lineno = function_lineno + lineno_in_body - 1
      else:
        source_lineno = lineno_in_body

      location = Location(function_filepath, source_lineno, n.col_offset)
      # comment_map is keyed by lines of `source`, so look it up with the
      # body-relative line rather than the file line.
      origin = OriginInfo(location, function_name,
                          source_code_line, comment_map.get(lineno_in_body))
      anno.setanno(n, anno.Basic.ORIGIN, origin)
Example #39
0
 def _process_function_arg(self, arg_name):
     """Binds a typed placeholder node for a top-level function argument.

     Applies only when `self.function_level == 1` and `str(arg_name)` is a
     key of `self.context.arg_types`. In that case a node built by
     `arg_name.ast()` is annotated in two ways:
       * 'type': the declared type object;
       * 'type_fqn': the dotted type name, split into a tuple.
     It is then bound to `arg_name` in the current scope.
     """
     str_name = str(arg_name)
     if self.function_level != 1 or str_name not in self.context.arg_types:
         return
     # Forge a node to hold the type information, so that method calls on
     # it can resolve the type.
     type_holder = arg_name.ast()
     type_string, type_obj = self.context.arg_types[str_name]
     anno.setanno(type_holder, 'type', type_obj)
     anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
     self.scope.setval(arg_name, type_holder)
Example #40
0
 def test_copy_clean_preserves_annotations(self):
     """copy_clean keeps only the annotations listed in preserve_annos."""
     source = textwrap.dedent("""
   def f(a):
     return a + 1
 """)
     node = parser.parse_str(source)
     fn_def = node.body[0]
     anno.setanno(fn_def, 'foo', 'bar')
     anno.setanno(fn_def, 'baz', 1)

     copied = ast_util.copy_clean(node, preserve_annos={'foo'})

     self.assertEqual(anno.getanno(copied.body[0], 'foo'), 'bar')
     self.assertFalse(anno.hasanno(copied.body[0], 'baz'))
Example #41
0
 def visit_FunctionDef(self, node):
   """Visits a function definition inside its own isolated scope.

   If there is an enclosing scope, the function's name is marked as written
   in it. The body scope is created with `Scope(<current>, isolated=True)`
   and is current while the node's children are visited. It is then attached
   under `NodeAnno.BODY_SCOPE`, and the enclosing scope is restored.
   """
   enclosing_scope = self.scope
   if enclosing_scope:
     enclosing_scope.mark_write(QN(node.name))

   body_scope = Scope(enclosing_scope, isolated=True)
   self.scope = body_scope
   self.generic_visit(node)
   anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
   self.scope = enclosing_scope

   return node
Example #42
0
 def visit_FunctionDef(self, node):
     """Visits a function definition inside its own isolated scope.

     If there is an enclosing scope, the function's name is marked as written
     in it. The body scope is created with `Scope(<current>, isolated=True)`
     and is current while the node's children are visited. It is then
     attached under `NodeAnno.BODY_SCOPE`, and the enclosing scope is
     restored.
     """
     enclosing_scope = self.scope
     if enclosing_scope:
         enclosing_scope.mark_write(QN(node.name))

     body_scope = Scope(enclosing_scope, isolated=True)
     self.scope = body_scope
     self.generic_visit(node)
     anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
     self.scope = enclosing_scope

     return node
 def _process_function_arg(self, arg_node):
   """Binds an argument node in the current scope and attaches declared types.

   `arg_node` is bound under its QN. When exactly one entity encloses it and
   the argument's name is a key of `self.entity_info.arg_types`, the argument
   node itself is annotated in two ways:
     * 'type': the declared type object;
     * 'type_fqn': the dotted type name, split into a tuple.
   This lets later uses of the argument resolve its type.
   """
   qn = anno.getanno(arg_node, anno.Basic.QN)
   arg_name = str(qn)
   self.scope.setval(qn, arg_node)

   is_top_level = len(self.enclosing_entities) == 1
   if not (is_top_level and arg_name in self.entity_info.arg_types):
     return
   type_string, type_obj = self.entity_info.arg_types[arg_name]
   anno.setanno(arg_node, 'type', type_obj)
   anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.')))
Example #44
0
 def test_copy_clean_preserves_annotations(self):
   """copy_clean keeps only the annotations listed in preserve_annos."""
   source = textwrap.dedent("""
     def f(a):
       return a + 1
   """)
   node = parser.parse_str(source)
   fn_def = node.body[0]
   anno.setanno(fn_def, 'foo', 'bar')
   anno.setanno(fn_def, 'baz', 1)

   copied = ast_util.copy_clean(node, preserve_annos={'foo'})

   self.assertEqual(anno.getanno(copied.body[0], 'foo'), 'bar')
   self.assertFalse(anno.hasanno(copied.body[0], 'baz'))
  def visit_Name(self, node):
    """Annotates a name with the definitions recorded for it.

    When `self.current_analyzer` is None, the node is returned untouched.
    Otherwise the entry for the name's QN in `self.current_stmt_defs` (an
    empty tuple if there is none) is stored as a tuple under
    `anno.Static.DEFINITIONS`.
    """
    if self.current_analyzer is not None:
      qn = anno.getanno(node, anno.Basic.QN)
      assert self.current_stmt_defs is not None, (
          'name node outside of any statement?')
      known_defs = self.current_stmt_defs.get(qn, ())
      anno.setanno(node, anno.Static.DEFINITIONS, tuple(known_defs))
    # Names may appear outside function defs - for example in class
    # definitions - in which case there is nothing to annotate.
    return node
Example #46
0
  def visit_While(self, node):
    """Visits a `while` loop, giving its condition and blocks separate scopes.

    The test is visited in a `Scope(<current>, isolated=False)` that is
    attached under `NodeAnno.COND_SCOPE`; the previous scope is then restored.
    The body and `else` blocks are handled by `_process_parallel_blocks`.
    """
    outer_scope = self.scope
    self.scope = cond_scope = Scope(outer_scope, isolated=False)
    self.visit(node.test)
    anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
    self.scope = outer_scope

    blocks = ((node.body, NodeAnno.BODY_SCOPE),
              (node.orelse, NodeAnno.ORELSE_SCOPE))
    return self._process_parallel_blocks(node, blocks)
Example #47
0
    def visit_While(self, node):
        """Visits a `while` loop, giving its condition and blocks separate scopes.

        The test is visited in a `Scope(<current>, isolated=False)` that is
        attached under `NodeAnno.COND_SCOPE`; the previous scope is then
        restored. The body and `else` blocks are handled by
        `_process_parallel_blocks`.
        """
        outer_scope = self.scope
        self.scope = cond_scope = Scope(outer_scope, isolated=False)
        self.visit(node.test)
        anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
        self.scope = outer_scope

        blocks = ((node.body, NodeAnno.BODY_SCOPE),
                  (node.orelse, NodeAnno.ORELSE_SCOPE))
        return self._process_parallel_blocks(node, blocks)
Example #48
0
    def test_basic(self):
        """The transformed function raises GraphConstructionError when called."""
        def test_fn():
            raise ValueError()

        node, ctx = self.prepare(test_fn, {})
        origin = origin_info.OriginInfo('test_path', None, None, None, None)
        anno.setanno(node.body[0], anno.Basic.ORIGIN, origin)
        node = error_handlers.transform(node, ctx)
        with self.compiled(node, {}) as result, self.assertRaises(
                errors.GraphConstructionError):
            result.test_fn()
Example #49
0
 def visit_Call(self, node):
   """Visits a call, recording the scope its arguments are evaluated in.

   Positional and keyword arguments are visited inside a
   `Scope(<current>, isolated=False)`, which is attached under
   `NodeAnno.ARGS_SCOPE`. The previous scope is then restored, and the callee
   expression is visited last.
   """
   enclosing_scope = self.scope
   args_scope = Scope(enclosing_scope, isolated=False)

   self.scope = args_scope
   for positional in node.args:
     self.visit(positional)
   # TODO(mdan): Account starargs, kwargs
   for keyword in node.keywords:
     self.visit(keyword)
   anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
   self.scope = enclosing_scope

   self.visit(node.func)
   return node
Example #50
0
  def visit_Call(self, node):
    """Converts a call node according to what its callee resolves to.

    Before the children are visited, a call to one of
    `self.ctx.program.autograph_decorators` gets its first argument marked
    `graph_ready`. Afterwards, for a statically resolved callee (`live_val`):
      * compilable functions are renamed via `_rename_compilable_function`;
      * functions whose `fqn` is in `KNOWN_NUMPY_FUNCTIONS` are passed to
        `_wrap_to_py_func_single_return` with that entry's `dtype`;
      * anything else raises NotImplementedError.
    Unresolved callees are left alone when they are `dict` or a `super(_)`
    call. Otherwise they go through `_insert_dynamic_conversion` when
    `self.ctx.program.recursive` is set.

    Returns:
      The (possibly replaced) call node.

    Raises:
      ValueError: if a marker decorator is called without arguments.
      NotImplementedError: if a resolved callee cannot be converted.
    """
    # If the function call is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      decorator = anno.getanno(node.func, 'live_val')
      if decorator in self.ctx.program.autograph_decorators:
        if not node.args:
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              decorator)
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)

    if not anno.hasanno(node.func, 'live_val'):
      if anno.hasanno(node.func, anno.Basic.QN):
        # Special-case a few builtins that otherwise go undetected. This
        # normally doesn't pose a problem, but the dict built-in doesn't
        # work with inspect.getargspec which is required for dynamic functions.
        # Note: expecting this is resilient to aliasing (e.g.
        # dict = an_evil_dict), because in those cases the regular mechanisms
        # process a simple user function.
        # Add items to this list as needed.
        if str(anno.getanno(node.func, anno.Basic.QN)) in ('dict',):
          return node

      if ast_util.matches(node, 'super(_)'):
        # super() calls are preserved. The class conversion mechanism will
        # ensure that they return the correct value.
        return node

      if self.ctx.program.recursive:
        return self._insert_dynamic_conversion(node)
      return node

    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None

    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)
    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    raise NotImplementedError(
        'py_func with return values (unknown function)')
Example #51
0
  def visit(self, cfg_node):
    """Propagates backward dataflow information from `cfg_node`.

    The node's `out_label` is the `self.transfer_fn` reduction of its
    successors' `in_label` annotations (an empty frozenset if no successor
    has one yet). Its gen/kill sets come from `get_gen_kill`, and its
    `in_label` becomes `(incoming - kill) | gen`. Predecessors are visited
    again whenever the `in_label` changed.

    Args:
      cfg_node: the CFG node to process. Its `value` is falsy for the exit
        node, which only forwards the visit to its predecessors.
    """
    # cfg_node.value is None for the exit node, which will be visited only once
    if not cfg_node.value:
      for pred in cfg_node.prev:
        self.visit(pred)
      return

    # Keep the previous label itself, not its hash. Equal hashes do not imply
    # equal sets, so a hash collision would stop propagation too early.
    if anno.hasanno(cfg_node.value, self.in_label):
      before = anno.getanno(cfg_node.value, self.in_label)
    else:
      before = None
    succs = [
        anno.getanno(succ.value, self.in_label)
        for succ in cfg_node.next
        if anno.hasanno(succ.value, self.in_label)
    ]
    if succs:
      incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
    else:
      incoming = frozenset()
    anno.setanno(cfg_node.value, self.out_label, incoming)
    gen, kill = self.get_gen_kill(cfg_node, incoming)
    anno.setanno(cfg_node.value, self.gen_label, gen)
    anno.setanno(cfg_node.value, self.kill_label, kill)
    after = (incoming - kill) | gen
    anno.setanno(cfg_node.value, self.in_label, after)
    if after != before:
      for pred in cfg_node.prev:
        self.visit(pred)
Example #52
0
  def visit(self, node):
    """Depth-first walking the CFG, applying dataflow info propagation.

    The node's `in_label` is the `self.transfer_fn` reduction of its
    predecessors' `out_label` annotations (an empty frozenset if no
    predecessor has one yet). Its gen/kill sets come from `get_gen_kill`, and
    its `out_label` becomes `(incoming - kill) | gen`. Successors are visited
    again whenever the `out_label` changed.

    Args:
      node: the CFG node to process. Its `value` is None only for the exit
        node, which is skipped.
    """
    # node.value is None only for the exit CfgNode.
    if not node.value:
      return

    # Keep the previous label itself, not its hash. Equal hashes do not imply
    # equal sets, so a hash collision would stop propagation too early.
    if anno.hasanno(node.value, self.out_label):
      before = anno.getanno(node.value, self.out_label)
    else:
      before = None
    preds = [
        anno.getanno(pred.value, self.out_label)
        for pred in node.prev
        if anno.hasanno(pred.value, self.out_label)
    ]
    if preds:
      incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
    else:
      incoming = frozenset()
    anno.setanno(node.value, self.in_label, incoming)
    gen, kill = self.get_gen_kill(node, incoming)
    anno.setanno(node.value, self.gen_label, gen)
    anno.setanno(node.value, self.kill_label, kill)
    after = (incoming - kill) | gen
    anno.setanno(node.value, self.out_label, after)

    if after != before:
      for succ in node.next:
        self.visit(succ)
  def test_basic(self):
    """The transformed function raises GraphConstructionError when called."""

    def test_fn():
      raise ValueError()

    node, ctx = self.prepare(test_fn, {})
    origin = origin_info.OriginInfo(None, None, None)
    anno.setanno(node, anno.Basic.ORIGIN, origin)
    node = error_handlers.transform(node, ctx)
    # Here we just assert that the handler works. Its correctness is
    # verified by errors_test.py.
    with self.compiled(node, {}) as result, self.assertRaises(
        errors.GraphConstructionError):
      result.test_fn()