Ejemplo n.º 1
0
  def visit(self, cfg_node):
    """Propagates dataflow annotations backwards starting at `cfg_node`.

    The node's out-set is the fold, via `self.transfer_fn`, of the in-sets of
    its successors (an empty frozenset if none of them is annotated yet).
    Gen/kill sets come from `self.get_gen_kill`, and the in-set becomes
    `(out - kill) | gen`. When the in-set is new or has changed, every
    predecessor is visited again, so the walk stops once no in-set changes.

    Args:
      cfg_node: The CFG node to process. Its `value` is None only for the exit
        node, which just starts the walk over its predecessors.
    """
    # cfg_node.value is None for the exit node, which will be visited only once
    if not cfg_node.value:
      for pred in cfg_node.prev:
        self.visit(pred)
      return

    # Keep the previous in-set itself rather than its hash: two different sets
    # can share a hash, which would wrongly end the propagation early.
    if anno.hasanno(cfg_node.value, self.in_label):
      before = anno.getanno(cfg_node.value, self.in_label)
    else:
      before = None
    succs = [
        anno.getanno(succ.value, self.in_label)
        for succ in cfg_node.next
        if anno.hasanno(succ.value, self.in_label)
    ]
    if succs:
      incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
    else:
      incoming = frozenset()
    anno.setanno(cfg_node.value, self.out_label, incoming)
    gen, kill = self.get_gen_kill(cfg_node, incoming)
    anno.setanno(cfg_node.value, self.gen_label, gen)
    anno.setanno(cfg_node.value, self.kill_label, kill)
    after = (incoming - kill) | gen
    anno.setanno(cfg_node.value, self.in_label, after)
    if after != before:
      for pred in cfg_node.prev:
        self.visit(pred)
Ejemplo n.º 2
0
def _build_source_map(node, code):
  """Maps line numbers of the regenerated code to the original origin info.

  The newly generated code is reparsed to obtain its final line numbers. The
  original and reparsed ASTs are then walked in parallel, pairing each
  original node with its regenerated counterpart.

  Args:
    node: An AST node of the original generated code, before the source code is
      generated.
    code: The string representation of the source code for the newly generated
      code.

  Returns:
    A dict whose keys are the `line_number` values of the ORIGIN annotations
    found on the reparsed AST (presumably attached by `origin_info.resolve`).
    Its values are the ORIGIN annotations of the corresponding nodes in
    `node`. Only pairs where both nodes carry an ORIGIN annotation are
    recorded. When several pairs share a line number, the pair visited last
    by the parallel walk wins.
  """
  # After we have the final generated code we reparse it to get the final line
  # numbers. Then we walk through the generated and original ASTs in parallel
  # to build the mapping between the user and generated code.
  new_node = parser.parse_str(code)
  origin_info.resolve(new_node, code)
  source_mapping = {}
  for before, after in ast_util.parallel_walk(node, new_node):
    # Need both checks because if origin information is ever copied over to new
    # nodes then we need to rely on the fact that only the original user code
    # has the origin annotation.
    if (anno.hasanno(before, anno.Basic.ORIGIN) and
        anno.hasanno(after, anno.Basic.ORIGIN)):
      source_info = anno.getanno(before, anno.Basic.ORIGIN)
      new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
      source_mapping[new_line_number] = source_info
  return source_mapping
Ejemplo n.º 3
0
  def visit(self, node):
    """Depth-first walking the CFG, applying dataflow info propagation.

    The node's in-set is the fold, via `self.transfer_fn`, of the out-sets of
    its predecessors (an empty frozenset if none of them is annotated yet).
    Gen/kill sets come from `self.get_gen_kill`, and the out-set becomes
    `(in - kill) | gen`. When the out-set is new or has changed, every
    successor is visited again, so the walk stops once no out-set changes.

    Args:
      node: The CFG node to process. Its `value` is None only for the exit
        node, where the walk stops.
    """
    # node.value is None only for the exit CfgNode.
    if not node.value:
      return

    # Keep the previous out-set itself rather than its hash: two different
    # sets can share a hash, which would wrongly end the propagation early.
    if anno.hasanno(node.value, self.out_label):
      before = anno.getanno(node.value, self.out_label)
    else:
      before = None
    preds = [
        anno.getanno(pred.value, self.out_label)
        for pred in node.prev
        if anno.hasanno(pred.value, self.out_label)
    ]
    if preds:
      incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
    else:
      incoming = frozenset()
    anno.setanno(node.value, self.in_label, incoming)
    gen, kill = self.get_gen_kill(node, incoming)
    anno.setanno(node.value, self.gen_label, gen)
    anno.setanno(node.value, self.kill_label, kill)
    after = (incoming - kill) | gen
    anno.setanno(node.value, self.out_label, after)

    if after != before:
      for succ in node.next:
        self.visit(succ)
Ejemplo n.º 4
0
  def _rename_compilable_function(self, node):
    """Points a call at the compiled counterpart of its target.

    `node.func` must carry 'live_val' and 'fqn' annotations. Calls for which
    `self._should_compile` is false are returned untouched. Otherwise
    `self.ctx.namer` picks the new name: constructors (nodes annotated
    'is_constructor') always get the compiled class name, while other callees
    get a compiled function name plus a flag saying whether to rename at all.

    Args:
      node: The call node to rewrite.

    Returns:
      The call node, with `func` replaced when a rename applies.
    """
    assert anno.hasanno(node.func, 'live_val')
    assert anno.hasanno(node.func, 'fqn')
    entity = anno.getanno(node.func, 'live_val')
    fqn = anno.getanno(node.func, 'fqn')

    if not self._should_compile(node, fqn):
      return node

    if anno.hasanno(node, 'is_constructor'):
      new_name = self.ctx.namer.compiled_class_name(fqn, live_entity=entity)
      do_rename = True
    else:
      if anno.hasanno(node.func, 'parent_type'):
        owner = anno.getanno(node.func, 'parent_type')
      else:
        owner = inspect_utils.getmethodclass(entity)  # Fallback - not reliable.
      new_name, do_rename = self.ctx.namer.compiled_function_name(
          fqn, live_entity=entity, owner_type=owner)

    if not do_rename:
      return node

    if entity is not None and tf_inspect.ismethod(entity):
      # The renaming process will transform it into a regular function, so the
      # attribute's owner expression is passed as an explicit first argument.
      # TODO(mdan): Is this complete? How does it work with nested members?
      node.args = [node.func.value] + node.args
    node.func = templates.replace('func_name', func_name=new_name)[0]
    return node
Ejemplo n.º 5
0
 def visit_Call(self, node):
   """Applies the element type requested by the type-annotation marker call.

   For a call `self.context.type_annotation_func(symbol, type)`, both the call
   node and the definition of `symbol` (looked up via `self.scope`) receive an
   'element_type' annotation holding the 'live_val' of `type`.

   Args:
     node: The call node being visited.

   Returns:
     The result of `self.generic_visit(node)`.

   Raises:
     ValueError: If a marker call does not have exactly two arguments, its
       first argument is not a symbol, or its second argument has no
       'live_val' annotation.
   """
   if anno.hasanno(node.func, 'live_val'):
     # Symbols targeted by the "set_type" marker function are assigned the data
     # type that it specified.
     if (anno.getanno(node.func, 'live_val') is
         self.context.type_annotation_func):
       # Expecting the actual type to be the second argument.
       if len(node.args) != 2:
         raise ValueError('"%s" must have exactly two parameters'
                          % self.context.type_annotation_func)
       if not anno.hasanno(node.args[0], anno.Basic.QN):
         raise ValueError('the first argument of "%s" must be a symbol'
                          % self.context.type_annotation_func)
       if not anno.hasanno(node.args[1], 'live_val'):
         raise ValueError(
             'the second argument of "%s" must be statically resolvable' %
             self.context.type_annotation_func)
       target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
       element_type = anno.getanno(node.args[1], 'live_val')
       # Find the definition of this symbol and annotate it with the given
       # data type. That in turn will cause future uses of the symbol
       # to receive the same type annotation.
       definition = self.scope.getval(target_symbol)
       anno.setanno(node, 'element_type', element_type)
       anno.setanno(definition, 'element_type', element_type)
       # TODO(mdan): Should we update references between definition and here?
   return self.generic_visit(node)
Ejemplo n.º 6
0
def _build_source_map(node, code):
    """Maps line numbers of the regenerated code to the original origin info.

    The newly generated code is reparsed to obtain its final line numbers.
    The original and reparsed ASTs are then walked in parallel, pairing each
    original node with its regenerated counterpart.

    Args:
        node: An AST node of the original generated code, before the source
            code is generated.
        code: The string representation of the source code for the newly
            generated code.

    Returns:
        A dict whose keys are the `line_number` values of the ORIGIN
        annotations found on the reparsed AST (presumably attached by
        `origin_info.resolve`). Its values are the ORIGIN annotations of the
        corresponding nodes in `node`. Only pairs where both nodes carry an
        ORIGIN annotation are recorded. When several pairs share a line
        number, the pair visited last by the parallel walk wins.
    """
    # After we have the final generated code we reparse it to get the final line
    # numbers. Then we walk through the generated and original ASTs in parallel
    # to build the mapping between the user and generated code.
    new_node = parser.parse_str(code)
    origin_info.resolve(new_node, code)
    source_mapping = {}
    for before, after in ast_util.parallel_walk(node, new_node):
        # Need both checks because if origin information is ever copied over to new
        # nodes then we need to rely on the fact that only the original user code
        # has the origin annotation.
        if (anno.hasanno(before, anno.Basic.ORIGIN)
                and anno.hasanno(after, anno.Basic.ORIGIN)):
            source_info = anno.getanno(before, anno.Basic.ORIGIN)
            new_line_number = anno.getanno(after,
                                           anno.Basic.ORIGIN).line_number
            source_mapping[new_line_number] = source_info
    return source_mapping
Ejemplo n.º 7
0
 def visit_Attribute(self, node):
   """Resolves an attribute access when its owner's value or type is known.

   Children are visited first. If the owner has a 'live_val', the attribute is
   looked up on that object; otherwise, if the owner has a 'type' that has the
   attribute, it is looked up on the type. On success the node gets
   'parent_type', 'live_val' and 'fqn' annotations. A plain-name owner with
   neither annotation must be marked as a local symbol.

   Raises:
     AttributeError: If the owner's live value lacks the attribute.
   """
   self.generic_visit(node)
   owner = node.value
   attr_name = node.attr

   if anno.hasanno(owner, 'live_val'):
     assert anno.hasanno(owner, 'fqn')
     owner_obj = anno.getanno(owner, 'live_val')
     if not hasattr(owner_obj, attr_name):
       raise AttributeError('%s has no attribute %s' % (owner_obj, attr_name))
     anno.setanno(node, 'parent_type', type(owner_obj))
     anno.setanno(node, 'live_val', getattr(owner_obj, attr_name))
     anno.setanno(node, 'fqn', anno.getanno(owner, 'fqn') + (attr_name,))
     return node

   # TODO(mdan): Investigate the role built-in annotations can play here.
   if anno.hasanno(owner, 'type'):
     owner_type = anno.getanno(owner, 'type')
     # Static members (e.g. methods) are found on the type; dynamic members
     # such as function attributes are not, and the node is then left without
     # an annotation for downstream consumers to handle.
     if hasattr(owner_type, attr_name):
       anno.setanno(node, 'parent_type', owner_type)
       anno.setanno(node, 'live_val', getattr(owner_type, attr_name))
       anno.setanno(node, 'fqn',
                    anno.getanno(owner, 'type_fqn') + (attr_name,))
     return node

   if isinstance(owner, gast.Name):
     # All nonlocal symbols should be fully resolved.
     # TODO(mdan): Figure out what to do when calling attribute on local object
     # Maybe just leave as-is?
     assert anno.hasanno(owner, NodeAnno.IS_LOCAL), owner
   return node
Ejemplo n.º 8
0
  def visit_Call(self, node):
    """Propagates element type/shape annotations from `set_element_type`.

    For a call `utils.set_element_type(target, type[, shape])`, both the call
    node and the definition of `target` (looked up via `self.scope`) receive
    'element_type' and 'element_shape' annotations. The annotations hold the
    argument nodes themselves; an omitted shape is represented by the parsed
    expression `None`.

    Args:
      node: The call node being visited.

    Returns:
      The result of `self.generic_visit(node)`.

    Raises:
      ValueError: If the marker call does not have two or three arguments, or
        its first argument is not a symbol.
    """
    if anno.hasanno(node.func, 'live_val'):
      # Symbols targeted by the "set_type" marker function are assigned the data
      # type that it specified.
      if anno.getanno(node.func, 'live_val') is utils.set_element_type:

        # Error messages name the marker that was actually matched above.
        if len(node.args) not in (2, 3):
          raise ValueError('"%s" must have either two or three parameters'
                           % utils.set_element_type)
        if len(node.args) == 2:
          target_arg, type_arg = node.args
          shape_arg = parser.parse_expression('None')
        else:
          target_arg, type_arg, shape_arg = node.args
        if not anno.hasanno(target_arg, anno.Basic.QN):
          raise ValueError('the first argument of "%s" must be a symbol' %
                           utils.set_element_type)
        # TODO(mdan): This is vulnerable to symbol renaming.
        element_type = type_arg
        element_shape = shape_arg

        target_symbol = anno.getanno(target_arg, anno.Basic.QN)
        # Find the definition of this symbol and annotate it with the given
        # data type. That in turn will cause future uses of the symbol
        # to receive the same type annotation.
        definition = self.scope.getval(target_symbol)
        anno.setanno(node, 'element_type', element_type)
        anno.setanno(node, 'element_shape', element_shape)
        anno.setanno(definition, 'element_type', element_type)
        anno.setanno(definition, 'element_shape', element_shape)
        # TODO(mdan): Should we update references between definition and here?
    return self.generic_visit(node)
Ejemplo n.º 9
0
    def visit(self, cfg_node):
        """Propagates dataflow annotations backwards starting at `cfg_node`.

        The node's out-set is the fold, via `self.transfer_fn`, of the in-sets
        of its successors (an empty frozenset if none of them is annotated
        yet). Gen/kill sets come from `self.get_gen_kill`, and the in-set
        becomes `(out - kill) | gen`. When the in-set is new or has changed,
        every predecessor is visited again, so the walk stops once no in-set
        changes.

        Args:
            cfg_node: The CFG node to process. Its `value` is None only for
                the exit node, which just starts the walk over its
                predecessors.
        """
        # cfg_node.value is None for the exit node, which will be visited only once
        if not cfg_node.value:
            for pred in cfg_node.prev:
                self.visit(pred)
            return

        # Keep the previous in-set itself rather than its hash: two different
        # sets can share a hash, which would wrongly end the propagation early.
        if anno.hasanno(cfg_node.value, self.in_label):
            before = anno.getanno(cfg_node.value, self.in_label)
        else:
            before = None
        succs = [
            anno.getanno(succ.value, self.in_label) for succ in cfg_node.next
            if anno.hasanno(succ.value, self.in_label)
        ]
        if succs:
            incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
        else:
            incoming = frozenset()
        anno.setanno(cfg_node.value, self.out_label, incoming)
        gen, kill = self.get_gen_kill(cfg_node, incoming)
        anno.setanno(cfg_node.value, self.gen_label, gen)
        anno.setanno(cfg_node.value, self.kill_label, kill)
        after = (incoming - kill) | gen
        anno.setanno(cfg_node.value, self.in_label, after)
        if after != before:
            for pred in cfg_node.prev:
                self.visit(pred)
Ejemplo n.º 10
0
    def visit_Call(self, node):
        """Applies the element type requested by the type-annotation marker.

        For a call `self.context.type_annotation_func(target, type)`, the call
        node and the definition of `target` (looked up via `self.scope`) both
        receive an 'element_type' annotation. A `gast.Str` or `gast.Num`
        literal `type` is stored as its literal value. Any other expression
        must carry a 'live_val' annotation, which is stored instead.

        Args:
            node: The call node being visited.

        Returns:
            The result of `self.generic_visit(node)`.

        Raises:
            ValueError: If the marker call does not have exactly two
                arguments, its first argument is not a symbol, or a
                non-literal second argument is not statically resolvable.
        """
        if anno.hasanno(node.func, 'live_val'):
            # Symbols targeted by the "set_type" marker function are assigned the data
            # type that it specified.
            if (anno.getanno(node.func, 'live_val') is
                    self.context.type_annotation_func):

                if len(node.args) != 2:
                    raise ValueError('"%s" must have exactly two parameters' %
                                     self.context.type_annotation_func)
                target_arg, type_arg = node.args
                if not anno.hasanno(target_arg, anno.Basic.QN):
                    raise ValueError(
                        'the first argument of "%s" must be a symbol' %
                        self.context.type_annotation_func)
                # NOTE(review): newer gast releases may no longer provide
                # gast.Str / gast.Num (literals become gast.Constant); confirm
                # against the pinned gast version.
                if isinstance(type_arg, gast.Str):
                    element_type = type_arg.s
                elif isinstance(type_arg, gast.Num):
                    element_type = type_arg.n
                else:
                    if not anno.hasanno(type_arg, 'live_val'):
                        raise ValueError(
                            'the second argument of "%s" must be statically resolvable'
                            % self.context.type_annotation_func)
                    element_type = anno.getanno(type_arg, 'live_val')

                target_symbol = anno.getanno(target_arg, anno.Basic.QN)
                # Find the definition of this symbol and annotate it with the given
                # data type. That in turn will cause future uses of the symbol
                # to receive the same type annotation.
                definition = self.scope.getval(target_symbol)
                anno.setanno(node, 'element_type', element_type)
                anno.setanno(definition, 'element_type', element_type)
                # TODO(mdan): Should we update references between definition and here?
        return self.generic_visit(node)
Ejemplo n.º 11
0
  def visit_Call(self, node):
    """Marks decorator arguments and rewrites calls to resolved functions.

    The first argument of a call to one of `self.nocompile_decorators` is
    annotated 'graph_ready'. After the children are visited, a call whose
    callee has a 'live_val' is handled as follows:
      * renamed to its compiled version when `self._function_is_compilable`
        accepts it;
      * otherwise wrapped as a py_func when its 'fqn' is in
        KNOWN_NUMPY_FUNCTIONS.
    A call with an unresolved callee gets a dynamic conversion when
    `self.context.recursive` is set, and is left unchanged otherwise.

    Args:
      node: The call node being visited.

    Returns:
      The (possibly replaced) call node.

    Raises:
      ValueError: If a no-compile decorator is called without arguments.
      NotImplementedError: If a resolved callee is neither compilable nor a
        known NumPy function.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.nocompile_decorators:
        if not node.args:
          # Fill the "%s" placeholder; the 1-tuple formats any entity safely.
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least an argument.' % (target_entity,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if anno.hasanno(node.func, 'fqn'):
        target_fqn = anno.getanno(node.func, 'fqn')
      else:
        target_fqn = None
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
        # TODO(mdan): Should we replace these with equivalent TF ops instead?
        node = self._wrap_to_py_func_single_return(
            node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
      else:
        raise NotImplementedError(
            'py_func with return values (unknown function)')
    elif self.context.recursive:
      node = self._insert_dynamic_conversion(node)
    # Unresolved functions are allowed in non-recursive mode and are left as-is.
    return node
Ejemplo n.º 12
0
    def visit(self, node):
        """Depth-first walking the CFG, applying dataflow information propagation.

        The node's in-set is the fold, via `self.transfer_fn`, of the out-sets
        of its predecessors (an empty frozenset if none of them is annotated
        yet). Gen/kill sets come from `self.get_gen_kill`, and the out-set
        becomes `(in - kill) | gen`. When the out-set is new or has changed,
        every successor is visited again, so the walk stops once no out-set
        changes.

        Args:
            node: The CFG node to process. Its `value` is None only for the
                exit node, where the walk stops.
        """
        # node.value is None only for the exit CfgNode.
        if not node.value:
            return

        # Keep the previous out-set itself rather than its hash: two different
        # sets can share a hash, which would wrongly end the propagation early.
        if anno.hasanno(node.value, self.out_label):
            before = anno.getanno(node.value, self.out_label)
        else:
            before = None
        preds = [
            anno.getanno(pred.value, self.out_label) for pred in node.prev
            if anno.hasanno(pred.value, self.out_label)
        ]
        if preds:
            incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
        else:
            incoming = frozenset()
        anno.setanno(node.value, self.in_label, incoming)
        gen, kill = self.get_gen_kill(node, incoming)
        anno.setanno(node.value, self.gen_label, gen)
        anno.setanno(node.value, self.kill_label, kill)
        after = (incoming - kill) | gen
        anno.setanno(node.value, self.out_label, after)

        if after != before:
            for succ in node.next:
                self.visit(succ)
    def _rename_compilable_function(self, node):
        """Rewrites a call so that it targets the compiled version of its callee.

        `node.func` must carry 'live_val' and 'fqn' annotations. When
        `self._should_compile` rejects the call it is returned untouched.
        Otherwise `self.context.namer` supplies the new name: constructors
        (nodes annotated 'is_constructor') are always renamed to the compiled
        class name, and other callees get a compiled function name along with
        a flag telling whether to rename.

        Args:
            node: The call node to rewrite.

        Returns:
            The call node, whose `func` is replaced when a rename applies.
        """
        assert anno.hasanno(node.func, 'live_val')
        assert anno.hasanno(node.func, 'fqn')
        live_entity = anno.getanno(node.func, 'live_val')
        fqn = anno.getanno(node.func, 'fqn')
        if not self._should_compile(node, fqn):
            return node

        namer = self.context.namer
        if anno.hasanno(node, 'is_constructor'):
            new_name = namer.compiled_class_name(fqn, live_entity=live_entity)
            do_rename = True
        else:
            # Prefer the 'parent_type' annotation; getmethodclass is a
            # fallback that is not reliable.
            owner_type = (anno.getanno(node.func, 'parent_type')
                          if anno.hasanno(node.func, 'parent_type') else
                          inspect_utils.getmethodclass(live_entity))
            new_name, do_rename = namer.compiled_function_name(
                fqn, live_entity=live_entity, owner_type=owner_type)

        if do_rename:
            is_method = (live_entity is not None and
                         tf_inspect.ismethod(live_entity))
            if is_method:
                # The renaming process will transform it into a regular
                # function, so the attribute's owner expression becomes the
                # first positional argument.
                # TODO(mdan): Is this complete? How does it work with nested members?
                node.args = [node.func.value] + node.args
            node.func = templates.replace('func_name', func_name=new_name)[0]
        return node
    def visit_Call(self, node):
        """Marks decorator arguments and rewrites calls to resolved functions.

        The first argument of a call to one of `self.nocompile_decorators` is
        annotated 'graph_ready'. After the children are visited, a call whose
        callee has a 'live_val' is handled as follows:
          * renamed to its compiled version when
            `self._function_is_compilable` accepts it;
          * otherwise wrapped as a py_func when its 'fqn' is in
            KNOWN_NUMPY_FUNCTIONS.
        A call with an unresolved callee gets a dynamic conversion when
        `self.context.recursive` is set, and is left unchanged otherwise.

        Args:
            node: The call node being visited.

        Returns:
            The (possibly replaced) call node.

        Raises:
            ValueError: If a no-compile decorator is called without arguments.
            NotImplementedError: If a resolved callee is neither compilable
                nor a known NumPy function.
        """
        # If the function is wrapped by one of the marker decorators,
        # consider it graph ready.
        if anno.hasanno(node.func, 'live_val'):
            target_entity = anno.getanno(node.func, 'live_val')
            if target_entity in self.nocompile_decorators:
                if not node.args:
                    # Fill the "%s" placeholder; the 1-tuple formats any
                    # entity safely.
                    raise ValueError(
                        'Found call to decorator function "%s", but it had no arguments. '
                        'A decorator needs at least an argument.' %
                        (target_entity,))
                anno.setanno(node.args[0], 'graph_ready', True)

        self.generic_visit(node)
        if anno.hasanno(node.func, 'live_val'):
            target_entity = anno.getanno(node.func, 'live_val')
            if anno.hasanno(node.func, 'fqn'):
                target_fqn = anno.getanno(node.func, 'fqn')
            else:
                target_fqn = None
            if self._function_is_compilable(target_entity):
                node = self._rename_compilable_function(node)
            elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
                # TODO(mdan): Should we replace these with equivalent TF ops instead?
                node = self._wrap_to_py_func_single_return(
                    node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
            else:
                raise NotImplementedError(
                    'py_func with return values (unknown function)')
        elif self.context.recursive:
            node = self._insert_dynamic_conversion(node)
        # Unresolved functions are allowed in non-recursive mode and are left
        # as-is.
        return node
Ejemplo n.º 15
0
 def visit_Attribute(self, node):
     """Attaches 'parent_type', 'live_val' and 'fqn' to resolvable attributes.

     After the children are visited, the attribute is resolved against the
     owner's 'live_val' object if present. Otherwise it is resolved against
     the owner's 'type', when that type has the attribute. Other attributes
     are left unannotated; a plain-name owner must then be a local symbol.

     Raises:
         AttributeError: If the owner's live value lacks the attribute.
     """
     self.generic_visit(node)
     owner = node.value
     fqn_suffix = (node.attr, )

     if anno.hasanno(owner, 'live_val'):
         assert anno.hasanno(owner, 'fqn')
         owner_value = anno.getanno(owner, 'live_val')
         if not hasattr(owner_value, node.attr):
             raise AttributeError('%s has no attribute %s' %
                                  (owner_value, node.attr))
         anno.setanno(node, 'parent_type', type(owner_value))
         anno.setanno(node, 'live_val', getattr(owner_value, node.attr))
         anno.setanno(node, 'fqn', anno.getanno(owner, 'fqn') + fqn_suffix)
     # TODO(mdan): Investigate the role built-in annotations can play here.
     elif anno.hasanno(owner, 'type'):
         owner_type = anno.getanno(owner, 'type')
         # Only static members (e.g. methods) are found on the type; dynamic
         # members such as function attributes are left unannotated for
         # downstream consumers to handle.
         if hasattr(owner_type, node.attr):
             anno.setanno(node, 'parent_type', owner_type)
             anno.setanno(node, 'live_val', getattr(owner_type, node.attr))
             anno.setanno(node, 'fqn',
                          anno.getanno(owner, 'type_fqn') + fqn_suffix)
     elif isinstance(owner, gast.Name):
         # All nonlocal symbols should be fully resolved.
         # TODO(mdan): Figure out what to do when calling attribute on local
         # object. Maybe just leave as-is?
         assert anno.hasanno(owner, NodeAnno.IS_LOCAL), owner
     return node
Ejemplo n.º 16
0
    def visit_Call(self, node):
        """Propagates element type/shape annotations from the marker call.

        For a call `self.context.type_annotation_func(target, type[, shape])`,
        the call node and the definition of `target` (looked up via
        `self.scope`) both receive 'element_type' and 'element_shape'
        annotations. These hold the argument nodes themselves; an omitted
        shape is represented by the parsed expression `None`.

        Args:
            node: The call node being visited.

        Returns:
            The result of `self.generic_visit(node)`.

        Raises:
            ValueError: If the marker call does not have two or three
                arguments, or its first argument is not a symbol.
        """
        if anno.hasanno(node.func, 'live_val'):
            # Symbols targeted by the "set_type" marker function are assigned the data
            # type that it specified.
            if (anno.getanno(node.func, 'live_val') is
                    self.context.type_annotation_func):

                if len(node.args) < 2 or len(node.args) > 3:
                    raise ValueError(
                        '"%s" must have either two or three parameters' %
                        self.context.type_annotation_func)
                if len(node.args) == 2:
                    target_arg, type_arg = node.args
                    shape_arg = parser.parse_expression('None')
                else:
                    target_arg, type_arg, shape_arg = node.args
                if not anno.hasanno(target_arg, anno.Basic.QN):
                    raise ValueError(
                        'the first argument of "%s" must be a symbol' %
                        self.context.type_annotation_func)
                # TODO(mdan): This is vulnerable to symbol renaming.
                element_type = type_arg
                element_shape = shape_arg

                target_symbol = anno.getanno(target_arg, anno.Basic.QN)
                # Find the definition of this symbol and annotate it with the given
                # data type. That in turn will cause future uses of the symbol
                # to receive the same type annotation.
                definition = self.scope.getval(target_symbol)
                anno.setanno(node, 'element_type', element_type)
                anno.setanno(node, 'element_shape', element_shape)
                anno.setanno(definition, 'element_type', element_type)
                anno.setanno(definition, 'element_shape', element_shape)
                # TODO(mdan): Should we update references between definition and here?
        return self.generic_visit(node)
Ejemplo n.º 17
0
    def test_copyanno(self):
        """copyanno copies an annotation that exists and skips one that doesn't."""
        source = ast.Name()
        anno.setanno(source, 'foo', 3)

        target = ast.Name()
        # 'bar' was never set on `source`.
        for key in ('foo', 'bar'):
            anno.copyanno(source, target, key)

        self.assertTrue(anno.hasanno(target, 'foo'))
        self.assertFalse(anno.hasanno(target, 'bar'))
Ejemplo n.º 18
0
  def test_copy(self):
    """Copying annotations transfers a present key and ignores an absent one."""
    src = ast.Name()
    anno.setanno(src, 'foo', 3)
    dst = ast.Name()

    anno.copyanno(src, dst, 'foo')
    anno.copyanno(src, dst, 'bar')  # 'bar' is absent on src.

    self.assertTrue(anno.hasanno(dst, 'foo'))
    self.assertFalse(anno.hasanno(dst, 'bar'))
Ejemplo n.º 19
0
 def _try_resolve_target(self, node):
   """Works for methods of objects of known type.

   Args:
     node: The expression node to resolve.

   Returns:
     The node's 'live_val' annotation if it has one. Otherwise, for an
     attribute node carrying a 'type' annotation, the member named `node.attr`
     looked up on that type. Otherwise None.

   Raises:
     ValueError: If the annotated type has no member named `node.attr` (for
       example because the attribute is set dynamically).
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')
   if not (isinstance(node, gast.Attribute) and anno.hasanno(node, 'type')):
     return None
   # NOTE(review): the 'type' annotation is read from the attribute node
   # itself but used as the type that owns `node.attr`; confirm this is the
   # intended annotation rather than the one on `node.value`.
   owner_type = anno.getanno(node, 'type')
   if not hasattr(owner_type, node.attr):
     raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                      (owner_type, node.attr))
   return getattr(owner_type, node.attr)
Ejemplo n.º 20
0
 def _try_resolve_target(self, node):
   """Works for methods of objects of known type.

   Args:
     node: The expression node to resolve.

   Returns:
     The node's 'live_val' annotation if it has one. Otherwise, for an
     attribute node carrying a 'type' annotation, the member named `node.attr`
     looked up on that type. Otherwise None.

   Raises:
     ValueError: If the annotated type has no member named `node.attr` (for
       example because the attribute is set dynamically).
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')
   if not (isinstance(node, gast.Attribute) and anno.hasanno(node, 'type')):
     return None
   # NOTE(review): the 'type' annotation is read from the attribute node
   # itself but used as the type that owns `node.attr`; confirm this is the
   # intended annotation rather than the one on `node.value`.
   owner_type = anno.getanno(node, 'type')
   if not hasattr(owner_type, node.attr):
     raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                      (owner_type, node.attr))
   return getattr(owner_type, node.attr)
Ejemplo n.º 21
0
  def visit_Call(self, node):
    """Converts a call according to what its callee resolves to.

    Calls to one of `self.ctx.program.autograph_decorators` get their first
    argument annotated 'graph_ready'. Once the children are visited:
      * a callee with a 'live_val' is renamed to its compiled version when
        `self._function_is_compilable` accepts it, or wrapped in a py_func
        when its 'fqn' is in KNOWN_NUMPY_FUNCTIONS;
      * an unresolved `dict(...)` or `super(...)` call is kept unchanged;
      * any other unresolved call gets a dynamic conversion when
        `self.ctx.program.recursive` is set.

    Args:
      node: The call node being visited.

    Returns:
      The (possibly replaced) call node.

    Raises:
      ValueError: If a decorator call has no positional arguments.
      NotImplementedError: If a resolved callee is neither compilable nor a
        known NumPy function.
    """
    # A call wrapped by one of the marker decorators is considered graph ready.
    if anno.hasanno(node.func, 'live_val'):
      decorator = anno.getanno(node.func, 'live_val')
      if decorator in self.ctx.program.autograph_decorators:
        if not node.args:
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              decorator)
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)

    if not anno.hasanno(node.func, 'live_val'):
      if anno.hasanno(node.func, anno.Basic.QN):
        # Special-case a few builtins that otherwise go undetected. This
        # normally doesn't pose a problem, but the dict built-in doesn't
        # work with inspect.getargspec which is required for dynamic functions.
        # Note: expecting this is resilient to aliasing (e.g.
        # dict = an_evil_dict), because in those cases the regular mechanisms
        # process a simple user function.
        qn = anno.getanno(node.func, anno.Basic.QN)
        # Add items to this list as needed.
        if str(qn) in ('dict',):
          return node
      if ast_util.matches(node, 'super(_)'):
        # super() calls are preserved. The class conversion mechanism will
        # ensure that they return the correct value.
        return node
      if self.ctx.program.recursive:
        node = self._insert_dynamic_conversion(node)
      return node

    callee = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      callee_fqn = anno.getanno(node.func, 'fqn')
    else:
      callee_fqn = None
    if self._function_is_compilable(callee):
      return self._rename_compilable_function(node)
    if callee_fqn and callee_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[callee_fqn].dtype)
    raise NotImplementedError('py_func with return values (unknown function)')
Ejemplo n.º 22
0
  def visit_Call(self, node):
    """Rewrites a call expression based on how its callee was resolved.

    The first argument of a call to any of
    `self.ctx.program.autograph_decorators` is annotated 'graph_ready'. After
    the children are visited, resolved callees (with 'live_val') are either
    renamed to their compiled version or, if their 'fqn' is a known NumPy
    function, wrapped in a py_func. Unresolved callees are kept as-is for
    `dict(...)` and `super(...)`, and are otherwise dynamically converted when
    `self.ctx.program.recursive` is set.

    Args:
      node: The call node being visited.

    Returns:
      The (possibly replaced) call node.

    Raises:
      ValueError: If a decorator call has no positional arguments.
      NotImplementedError: If a resolved callee is neither compilable nor a
        known NumPy function.
    """
    decorators = self.ctx.program.autograph_decorators
    if (anno.hasanno(node.func, 'live_val') and
        anno.getanno(node.func, 'live_val') in decorators):
      marker = anno.getanno(node.func, 'live_val')
      if len(node.args) < 1:
        raise ValueError(
            'Found call to decorator function "%s", but it had no arguments. '
            'A decorator needs at least one positional argument.' %
            marker)
      # Calls wrapped by a marker decorator are considered graph ready.
      anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)

    if anno.hasanno(node.func, 'live_val'):
      target = anno.getanno(node.func, 'live_val')
      fqn = (anno.getanno(node.func, 'fqn')
             if anno.hasanno(node.func, 'fqn') else None)
      if self._function_is_compilable(target):
        node = self._rename_compilable_function(node)
      elif fqn and fqn in KNOWN_NUMPY_FUNCTIONS:
        # TODO(mdan): Should we replace these with equivalent TF ops instead?
        node = self._wrap_to_py_func_single_return(
            node, KNOWN_NUMPY_FUNCTIONS[fqn].dtype)
      else:
        raise NotImplementedError(
            'py_func with return values (unknown function)')
      return node

    # The dict built-in goes undetected otherwise and doesn't work with
    # inspect.getargspec, which dynamic functions require. Aliasing (e.g.
    # dict = an_evil_dict) is expected to be handled by the regular mechanisms,
    # which then see a simple user function. Add items to the tuple as needed.
    is_special_builtin = (
        anno.hasanno(node.func, anno.Basic.QN) and
        str(anno.getanno(node.func, anno.Basic.QN)) in ('dict',))
    # super() calls are preserved; the class conversion mechanism ensures they
    # return the correct value.
    if is_special_builtin or ast_util.matches(node, 'super(_)'):
      return node

    if self.ctx.program.recursive:
      node = self._insert_dynamic_conversion(node)
    return node
Ejemplo n.º 23
0
    def test_nested_assignment(self):
        """Names bound by nested tuple unpacking currently get no live value."""

        def test_fn(foo):
            a, (b, c) = foo
            return a, b, c

        node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)})
        # Statement 1 of test_fn is `return a, b, c`, so these are the names
        # read in the returned tuple, not the assignment targets (the old
        # variable name `lhs` suggested otherwise).
        returned_names = node.body[0].body[1].value.elts
        a = returned_names[0]
        b = returned_names[1]
        c = returned_names[2]
        # TODO(mdan): change these once we have the live values propagating
        # correctly
        self.assertFalse(anno.hasanno(a, 'live_val'))
        self.assertFalse(anno.hasanno(b, 'live_val'))
        self.assertFalse(anno.hasanno(c, 'live_val'))
Ejemplo n.º 24
0
  def test_basic(self):
    """Round-trips one annotation through set, get, has and delete."""
    node = ast.Name()

    def assert_missing():
      # An absent annotation is reported by hasanno and raises on getanno.
      self.assertFalse(anno.hasanno(node, 'foo'))
      with self.assertRaises(AttributeError):
        anno.getanno(node, 'foo')

    assert_missing()

    anno.setanno(node, 'foo', 3)
    self.assertTrue(anno.hasanno(node, 'foo'))
    self.assertEqual(3, anno.getanno(node, 'foo'))

    anno.delanno(node, 'foo')
    assert_missing()
Ejemplo n.º 25
0
  def test_duplicate(self):
    """dup() adds a copy under a new key while keeping the original keys."""
    # NOTE(review): ast.Num is deprecated since Python 3.8 and removed in
    # newer releases; ast.Constant is the replacement if this must run there.
    condition = ast.Num(1)
    statement = ast.Expr(ast.Name('bar', ast.Load()))
    node = ast.If(test=condition, body=[statement], orelse=[])
    for key in ('spam', 'ham'):
      anno.setanno(node, key, 1)
    anno.setanno(statement, 'ham', 1)

    anno.dup(node, {'spam': 'eggs'})

    for key in ('spam', 'ham', 'eggs'):
      self.assertTrue(anno.hasanno(node, key))
    # The child never had 'spam', so it must not gain 'eggs'.
    self.assertFalse(anno.hasanno(statement, 'eggs'))
Ejemplo n.º 26
0
    def test_basic(self):
        """setanno/getanno/hasanno/delanno round trip on a single node."""
        key, value = 'foo', 3
        node = ast.Name()

        # Initially absent: hasanno is False and getanno raises.
        self.assertFalse(anno.hasanno(node, key))
        with self.assertRaises(AttributeError):
            anno.getanno(node, key)

        anno.setanno(node, key, value)
        self.assertTrue(anno.hasanno(node, key))
        self.assertEqual(value, anno.getanno(node, key))

        # After deletion the annotation is absent again.
        anno.delanno(node, key)
        self.assertFalse(anno.hasanno(node, key))
        with self.assertRaises(AttributeError):
            anno.getanno(node, key)
Ejemplo n.º 27
0
    def visit_Name(self, node):
        """Annotates loaded names with 'live_val' and, when possible, 'fqn'.

        Only names in a `gast.Load` context are processed; they must already
        carry the IS_LOCAL, IS_MODIFIED_SINCE_ENTRY and IS_PARAM annotations.
        For names that are neither local nor parameters, the value is looked
        up first in `self.literals`, then in `self.context.namespace`. A
        namespace hit also records 'fqn' when the value has a `__name__` or a
        `__class__`. After that, a name not modified since entry that appears
        in `self.context.arg_values` gets 'live_val' and 'fqn' set from that
        argument value.

        Returns:
          The same node, possibly annotated.
        """
        self.generic_visit(node)
        if not isinstance(node.ctx, gast.Load):
            return node

        assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
        is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
        assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
        is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
        assert anno.hasanno(node, NodeAnno.IS_PARAM), node
        is_param = anno.getanno(node, NodeAnno.IS_PARAM)

        name = node.id
        if is_local or is_param:
            # TODO(mdan): Attempt to trace its value through the local chain.
            # TODO(mdan): Use type annotations as fallback.
            pass
        elif name in self.literals:
            anno.setanno(node, 'live_val', self.literals[name])
        elif name in self.context.namespace:
            value = self.context.namespace[name]
            anno.setanno(node, 'live_val', value)
            if hasattr(value, '__name__'):
                anno.setanno(node, 'fqn', (value.__name__,))
            elif hasattr(value, '__class__'):
                value_cls = value.__class__
                anno.setanno(node, 'fqn',
                             (value_cls.__module__, value_cls.__name__))
            # Otherwise (e.g. an unnamed primitive) no 'fqn' is recorded.
        else:
            # TODO(mdan): Should we raise an error here? This happens when a
            # symbol truly lacks a reference, or is new (e.g. the new name of
            # a function that was just renamed).
            pass

        if not is_modified and name in self.context.arg_values:
            arg_value = self.context.arg_values[name]
            anno.setanno(node, 'live_val', arg_value)
            anno.setanno(node, 'fqn', (arg_value.__class__.__name__,))
        return node
Ejemplo n.º 28
0
  def visit(self, node):
    """Visit `node`, maintaining the enclosing-entity stack and error context.

    FunctionDef, ClassDef and Lambda nodes are pushed onto
    `self._enclosing_entities` while they are visited. They are popped
    afterwards, even if the visit raises. When source code is available, the
    position of each node that has a `lineno` is stored in `self._lineno` /
    `self._col_offset`. Nodes annotated with anno.Basic.SKIP_PROCESSING are
    returned unvisited.

    Raises:
      AutographParseError: re-raised, with the original traceback, from any
        ValueError, AttributeError, KeyError, NotImplementedError or
        AssertionError raised during the visit.
    """
    src = self.context.source_code
    src_file = self.context.source_file
    pushed_entity = False

    try:
      if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
        self._enclosing_entities.append(node)
        pushed_entity = True

      # Nodes without a `lineno` leave the last recorded position in place,
      # so errors raised at them report that position.
      if src and hasattr(node, 'lineno'):
        self._lineno = node.lineno
        self._col_offset = node.col_offset

      skip = anno.hasanno(node, anno.Basic.SKIP_PROCESSING)
      return node if skip else super(Base, self).visit(node)

    except (ValueError, AttributeError, KeyError, NotImplementedError,
            AssertionError) as e:
      msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
          e.__class__.__name__, str(e), try_ast_to_source(node),
          pretty_printer.fmt(node, color=False))
      offending_line = (src.splitlines()[self._lineno - 1]
                        if src else '<no source available>')
      location = (src_file, self._lineno, self._col_offset + 1, offending_line)
      six.reraise(AutographParseError, AutographParseError(msg, location),
                  sys.exc_info()[2])
    finally:
      if pushed_entity:
        self._enclosing_entities.pop()
Ejemplo n.º 29
0
 def _visit_and_reindent(self, nodes):
   """Visits statements in order, honoring INDENT_BLOCK_REMAINDER requests.

   Visited results are added to a destination list, which starts as the
   returned list. When a visited node carries an
   anno.Basic.INDENT_BLOCK_REMAINDER annotation of the form
   (new_dest, new_alias_map), the annotation is removed. Every following
   node is then added to `new_dest` instead, after being renamed with
   `ast_util.rename_symbols` using the accumulated alias map.

   Args:
     nodes: iterable of statement nodes to visit.

   Returns:
     The list of visited nodes collected before any redirection. Nodes
     visited after a reindent request go into `new_dest` instead. Presumably
     `new_dest` is a body list of an already-emitted node; TODO confirm.

   Raises:
     ValueError: if a reindent was requested but the final destination list
       is empty (e.g. no statement followed the requesting one).
   """
   new_nodes = []
   current_dest = new_nodes
   alias_map = {}
   reindent_requested = False
   for n in nodes:
     # The visitor may return a single node or a list/tuple of nodes.
     n = self.visit(n)
     # NOTE: the order in which these statements execute is important; in
     # particular, watch out for ending up with cycles in the AST.
     if alias_map:
       n = ast_util.rename_symbols(n, alias_map)
     if isinstance(n, (list, tuple)):
       current_dest.extend(n)
     else:
       current_dest.append(n)
     if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
       reindent_requested = True
       new_dest, new_alias_map = anno.getanno(
           n, anno.Basic.INDENT_BLOCK_REMAINDER)
       anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
       # On conflicting keys, mappings already in effect override the new
       # ones. This mutates the dict taken from the annotation.
       new_alias_map.update(alias_map)
       alias_map = new_alias_map
       # Redirect all subsequent statements into the requested block.
       current_dest = new_dest
   if reindent_requested and not current_dest:
     # TODO(mdan): There may still be something that could be done.
     raise ValueError('Unable to insert statement into the computation flow: '
                      'it is not followed by any computation which '
                      'the statement could gate.')
   return new_nodes
Ejemplo n.º 30
0
 def visit_Call(self, node):
   """Converts calls to supported builtins; returns other calls unchanged.

   The call's children are visited first. If the called function has a
   'live_val' annotation found in `py_builtins.SUPPORTED_BUILTINS`, the
   result of `self._convert_builtin(..., as_expression=True)` is returned
   in place of the call.
   """
   node = self.generic_visit(node)
   if not anno.hasanno(node.func, 'live_val'):
     return node
   builtin = anno.getanno(node.func, 'live_val')
   if builtin not in py_builtins.SUPPORTED_BUILTINS:
     return node
   return self._convert_builtin(builtin, node.args, as_expression=True)
Ejemplo n.º 31
0
 def _node_sets_self_attribute(self, node):
   """Returns True if `node` has a QN that is an attribute of `self`.

   Nodes without an anno.Basic.QN annotation yield False.
   """
   if not anno.hasanno(node, anno.Basic.QN):
     return False
   qn = anno.getanno(node, anno.Basic.QN)
   # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
   return bool(qn.has_attr and qn.parent.qn == ('self',))
Ejemplo n.º 32
0
  def visit_node(self, node):
    """Recomputes the live-in and live-out sets of CFG node `node`.

    For nodes with a scope annotation:
      live_out = union of the successors' live-in sets
      live_in  = gen | (live_out - kill)
    Here gen is the scope's `used` set plus any `self.extra_gen` entry for the
    AST node, and kill is the scope's `modified` set. Nodes without a scope
    annotation keep their previous live-in set and reuse it as live-out.

    Returns:
      True if the node's live-in set changed.
    """
    old_in = self.in_[node]
    ast_node = node.ast_node

    if not anno.hasanno(ast_node, anno.Static.SCOPE):
      # Assumed not to touch any symbols; the Name case is a literal name
      # such as `False`.
      assert isinstance(ast_node,
                        (gast.Name, gast.Continue, gast.Break)), type(ast_node)
      new_in = new_out = old_in
    else:
      scope = anno.getanno(ast_node, anno.Static.SCOPE)
      gen = scope.used | self.extra_gen.get(ast_node, frozenset())
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. if x.y is live whether x needs to be added. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = scope.modified
      new_out = set()
      for succ in node.next:
        new_out |= self.in_[succ]
      new_in = gen | (new_out - kill)

    self.in_[node] = new_in
    self.out[node] = new_out
    # TODO(mdan): Move this to the superclass?
    return old_in != new_in
Ejemplo n.º 33
0
 def visit_Attribute(self, node):
     """Gives `node` a QN extending its base expression's QN by `node.attr`.

     Attributes whose base has no anno.Basic.QN annotation get no QN.
     """
     node = self.generic_visit(node)
     base = node.value
     if not anno.hasanno(base, anno.Basic.QN):
         return node
     base_qn = anno.getanno(base, anno.Basic.QN)
     anno.setanno(node, anno.Basic.QN, QN(base_qn, attr=node.attr))
     return node
    def visit_node(self, node):
        """Recomputes the live-in and live-out sets of CFG node `node`.

        For nodes with a scope annotation:
          live_out = union of the successors' live-in sets
          live_in  = gen | (live_out - kill)
        Here gen is the scope's `used` set plus any `self.extra_gen` entry for
        the AST node, and kill is the scope's `modified` set. Nodes without a
        scope annotation keep their previous live-in set and reuse it as
        live-out.

        Returns:
          True if the node's live-in set changed.
        """
        old_in = self.in_[node]
        ast_node = node.ast_node

        if not anno.hasanno(ast_node, anno.Static.SCOPE):
            # Assumed not to touch any symbols; the Name case is a literal
            # name such as `False`.
            assert isinstance(
                ast_node, (gast.Name, gast.Continue, gast.Break)), type(
                    ast_node)
            new_in = new_out = old_in
        else:
            scope = anno.getanno(ast_node, anno.Static.SCOPE)
            gen = scope.used | self.extra_gen.get(ast_node, frozenset())
            # TODO(mdan): verify whether composites' parents need to be added.
            # E.g. if x.y is live whether x needs to be added. Theoretically
            # the activity analysis should have both so that wouldn't be
            # needed.
            kill = scope.modified
            new_out = set()
            for succ in node.next:
                new_out |= self.in_[succ]
            new_in = gen | (new_out - kill)

        self.in_[node] = new_in
        self.out[node] = new_out
        # TODO(mdan): Move this to the superclass?
        return old_in != new_in
Ejemplo n.º 35
0
    def test_parameter_class_members(self):
        """A method called on a function parameter gets no 'live_val'."""

        def test_fn(opt):
            opt.minimize(0)

        fn_def = self._parse_and_analyze(test_fn, {}).body[0]
        call_expr = fn_def.body[0].value
        self.assertFalse(anno.hasanno(call_expr.func, 'live_val'))
Ejemplo n.º 36
0
 def visit(self, node):
     """Visit `node`, converting selected errors into AutographParseError.

     When source code is available, the position of each node that has a
     `lineno` is stored in `self._lineno` / `self._col_offset` and is used to
     locate errors. Nodes annotated with anno.Basic.SKIP_PROCESSING are
     returned unvisited.

     Raises:
       AutographParseError: re-raised, with the original traceback, from any
         ValueError, AttributeError, KeyError, NotImplementedError or
         AssertionError raised during the visit.
     """
     src = self.context.source_code
     src_file = self.context.source_file
     try:
         # Nodes without a `lineno` leave the last recorded position in
         # place, so errors raised at them report that position.
         if src and hasattr(node, 'lineno'):
             self._lineno = node.lineno
             self._col_offset = node.col_offset
         skip = anno.hasanno(node, anno.Basic.SKIP_PROCESSING)
         return node if skip else super(Base, self).visit(node)
     except (ValueError, AttributeError, KeyError, NotImplementedError,
             AssertionError) as e:
         msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
             e.__class__.__name__, str(e), try_ast_to_source(node),
             pretty_printer.fmt(node, color=False))
         offending_line = (src.splitlines()[self._lineno - 1]
                           if src else '<no source available>')
         location = (src_file, self._lineno, self._col_offset + 1,
                     offending_line)
         six.reraise(AutographParseError, AutographParseError(msg, location),
                     sys.exc_info()[2])
Ejemplo n.º 37
0
  def _track_symbol(self, node):
    """Records the use of `node`'s qualified name in the current scope.

    A Param context marks the name as created and as a parameter, a Load
    marks it read, and a Store marks it written. The node is then annotated
    with NodeAnno.IS_LOCAL, IS_MODIFIED_SINCE_ENTRY and IS_PARAM as reported
    by the scope. The name is marked returned while inside a return
    statement. Nodes without an anno.Basic.QN annotation are ignored, e.g.
    an attribute or subscript on a call result such as `a().b`.

    Raises:
      ValueError: if `node.ctx` is not Store, Load or Param.
    """
    if not anno.hasanno(node, anno.Basic.QN):
      return
    qn = anno.getanno(node, anno.Basic.QN)
    ctx = node.ctx

    if isinstance(ctx, gast.Param):
      # Param contexts appear in function defs and define a variable.
      # TODO(mdan): This may be incorrect with nested functions. For those,
      # args should hide the parent scope's symbols rather than write them.
      self.scope.mark_creation(qn)
      self.scope.mark_param(qn)
    elif isinstance(ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(ctx, gast.Store):
      self.scope.mark_write(qn)
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(ctx), qn))

    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
    anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
                 self.scope.is_modified_since_entry(qn))
    anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn))

    if self._in_return_statement:
      self.scope.mark_returned(qn)
  def _track_symbol(self,
                    node,
                    composite_writes_alter_parent=False,
                    writes_create_symbol=False):
    """Records the use of `node`'s qualified name in the current scope.

    Args:
      node: AST node with a `ctx`. It is ignored when it has no
        anno.Basic.QN annotation, e.g. an attribute or subscript on a call
        result such as `a().b`.
      composite_writes_alter_parent: when true, a store to a composite QN also
        marks the QN's parent as written.
      writes_create_symbol: when true, a store also marks the QN as created.

    Raises:
      ValueError: if `node.ctx` is not Store, Load or Param.
    """
    if not anno.hasanno(node, anno.Basic.QN):
      return
    qn = anno.getanno(node, anno.Basic.QN)
    ctx = node.ctx

    if isinstance(ctx, gast.Param):
      # A parameter defines a variable. It is recorded as a write and as a
      # parameter of the last entry in self.enclosing_entities.
      self.scope.mark_write(qn)
      self.scope.mark_param(qn, self.enclosing_entities[-1])
    elif isinstance(ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(ctx, gast.Store):
      self.scope.mark_write(qn)
      if qn.is_composite and composite_writes_alter_parent:
        self.scope.mark_write(qn.parent)
      if writes_create_symbol:
        self.scope.mark_creation(qn, writes_create_symbol=True)
      # Presumably set while visiting an augmented assignment (`x += 1`),
      # whose target is read as well as written.
      if self._in_aug_assign:
        self.scope.mark_read(qn)
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(ctx), qn))

    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))

    if self._in_return_statement:
      self.scope.mark_returned(qn)
 def _node_sets_self_attribute(self, node):
   """Returns True if `node` has a QN that is an attribute of `self`.

   Nodes without an anno.Basic.QN annotation yield False.
   """
   if not anno.hasanno(node, anno.Basic.QN):
     return False
   qn = anno.getanno(node, anno.Basic.QN)
   # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
   return bool(qn.has_attr and qn.parent.qn == ('self',))
Ejemplo n.º 40
0
 def _node_sets_self_attribute(self, node):
     """Returns whether `node` has a QN that is an attribute of `self`.

     Returns:
       True if the node's anno.Basic.QN annotation is an attribute whose
       parent QN is exactly `('self',)`. Otherwise False, including when the
       node has no QN annotation. The explicit False replaces the previous
       implicit None.
     """
     if anno.hasanno(node, anno.Basic.QN):
         qn = anno.getanno(node, anno.Basic.QN)
         # TODO (mdan): The 'self' argument is not guaranteed to be called
         # 'self'.
         if qn.has_attr and qn.parent.qn == ('self', ):
             return True
     return False
Ejemplo n.º 41
0
  def _process_variable_assignment(self, source, targets):
    """Records the value assigned to each target in the current scope.

    If `source` is a call whose function's 'live_val' annotation is a class,
    the call is annotated as a constructor ('is_constructor', 'type',
    'type_fqn'). Each Name/Attribute target then has `source` recorded via
    `self.scope.setval`. Tuple targets are unpacked element by element,
    recursively.

    Args:
      source: AST node of the assigned value.
      targets: sequence of assignment target nodes. More than one target
        means a multiple assignment such as `a = b = c`.

    Raises:
      ValueError: if a target is not a Tuple, Name or Attribute.
    """
    # Special case: constructors.
    if isinstance(source, gast.Call):
      func = source.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if tf_inspect.isclass(func_obj):
          anno.setanno(source, 'is_constructor', True)
          anno.setanno(source, 'type', func_obj)
          anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    # Multiple targets mean multiple assignment.
    for target in targets:
      # Tuple target means unpacking.
      if isinstance(target, gast.Tuple):
        for i, target_item in enumerate(target.elts):
          # Two cases here:
          #   1. Static unpacking, e.g. a, b = c, d
          #   2. Dynamic unpacking, e.g. a, b = c
          # The former case is optimized away.
          if isinstance(source, (gast.Tuple, gast.List)):
            source_item = source.elts[i]
          else:
            # NOTE(review): gast.Index wraps the raw int `i`, not an AST
            # literal node. That is fine for bookkeeping via setval, but the
            # node is not a well-formed AST; confirm before emitting it.
            source_item = gast.Subscript(source, gast.Index(i), ctx=None)
          self._process_variable_assignment(source_item, (target_item,))
      elif isinstance(target, (gast.Name, gast.Attribute)):
        target_symbol = anno.getanno(target, anno.Basic.QN)
        self.scope.setval(target_symbol, source)
      else:
        # Report the offending target itself. `target_item` is only bound
        # inside the Tuple branch, so using it here raised
        # UnboundLocalError or named a stale element.
        raise ValueError(
            'assignment target has unknown type: %s' % target)
Ejemplo n.º 42
0
  def test_local_scope_info_stack(self):
    """Each loop collects string constants in its own local scope.

    Strings inside the nested `while` are not added to the enclosing `for`
    loop's result. The temporary 'string' local is not stored as a node
    annotation.
    """

    class TestTransformer(transformer.Base):

      def visit_Str(self, node):
        # Append this constant to the string gathered in the current scope.
        gathered = self.get_local('string', default='')
        self.set_local('string', gathered + node.s)
        return self.generic_visit(node)

      def _collect_in_own_scope(self, node):
        # Visit the loop in a fresh local scope and store what it gathered
        # as the 'test' annotation.
        self.enter_local_scope()
        node = self.generic_visit(node)
        anno.setanno(node, 'test', self.get_local('string'))
        self.exit_local_scope()
        return node

      def visit_While(self, node):
        return self._collect_in_own_scope(node)

      def visit_For(self, node):
        return self._collect_in_own_scope(node)

    visitor = TestTransformer(self._context_for_testing())

    def test_function(a):
      """Docstring."""
      assert a == 'This should not be counted'
      for i in range(3):
        _ = 'a'
        if i > 2:
          return 'b'
        else:
          _ = 'c'
          while True:
            raise '1'
      return 'nor this'

    fn_node, _ = parser.parse_entity(test_function)
    fn_node = visitor.visit(fn_node)

    for_node = fn_node.body[0].body[2]
    while_node = for_node.body[1].orelse[1]

    for loop_node, expected in ((for_node, 'abc'), (while_node, '1')):
      self.assertFalse(anno.hasanno(loop_node, 'string'))
      self.assertEqual(expected, anno.getanno(loop_node, 'test'))
Ejemplo n.º 43
0
    def test_local_scope_info_stack(self):
        """Each loop collects string constants in its own local scope.

        Strings inside the nested `while` are not added to the enclosing
        `for` loop's result. The temporary 'string' local is not stored as a
        node annotation.
        """

        class TestTransformer(transformer.Base):

            def visit_Str(self, node):
                # Append this constant to the string gathered in the current
                # scope.
                gathered = self.get_local('string', default='')
                self.set_local('string', gathered + node.s)
                return self.generic_visit(node)

            def _collect_in_own_scope(self, node):
                # Visit the loop in a fresh local scope and store what it
                # gathered as the 'test' annotation.
                self.enter_local_scope()
                node = self.generic_visit(node)
                anno.setanno(node, 'test', self.get_local('string'))
                self.exit_local_scope()
                return node

            def visit_While(self, node):
                return self._collect_in_own_scope(node)

            def visit_For(self, node):
                return self._collect_in_own_scope(node)

        visitor = TestTransformer(self._simple_source_info())

        def test_function(a):
            """Docstring."""
            assert a == 'This should not be counted'
            for i in range(3):
                _ = 'a'
                if i > 2:
                    return 'b'
                else:
                    _ = 'c'
                    while True:
                        raise '1'
            return 'nor this'

        fn_node, _ = parser.parse_entity(test_function)
        fn_node = visitor.visit(fn_node)

        for_node = fn_node.body[0].body[2]
        while_node = for_node.body[1].orelse[1]

        for loop_node, expected in ((for_node, 'abc'), (while_node, '1')):
            self.assertFalse(anno.hasanno(loop_node, 'string'))
            self.assertEqual(expected, anno.getanno(loop_node, 'test'))
Ejemplo n.º 44
0
    def test_nested_members(self):
        """A method reached through a nested attribute gets no 'live_val'."""

        def test_fn():
            foo = training.GradientDescentOptimizer(0.1)
            foo.bar.baz()

        analyzed = self._parse_and_analyze(test_fn, {'training': training})
        call_stmt = analyzed.body[0].body[1]
        self.assertFalse(anno.hasanno(call_stmt.value.func, 'live_val'))
Ejemplo n.º 45
0
    def test_basic(self):
        """Exercises hasanno/getanno/setanno/delanno, including defaults."""
        name_node = ast.Name()

        def check_absent():
            # An absent annotation is reported by hasanno and makes getanno
            # raise when no default is given.
            self.assertFalse(anno.hasanno(name_node, 'foo'))
            with self.assertRaises(AttributeError):
                anno.getanno(name_node, 'foo')

        check_absent()

        anno.setanno(name_node, 'foo', 3)
        self.assertTrue(anno.hasanno(name_node, 'foo'))
        self.assertEqual(anno.getanno(name_node, 'foo'), 3)
        # A missing key falls back to the supplied default.
        self.assertEqual(anno.getanno(name_node, 'bar', default=7), 7)

        anno.delanno(name_node, 'foo')
        check_absent()
        self.assertIsNone(anno.getanno(name_node, 'foo', default=None))
Ejemplo n.º 46
0
  def test_basic(self):
    """Exercises hasanno/getanno/setanno/delanno, including defaults."""
    name_node = ast.Name()

    def check_absent():
      # An absent annotation is reported by hasanno and makes getanno raise
      # when no default is given.
      self.assertFalse(anno.hasanno(name_node, 'foo'))
      with self.assertRaises(AttributeError):
        anno.getanno(name_node, 'foo')

    check_absent()

    anno.setanno(name_node, 'foo', 3)
    self.assertTrue(anno.hasanno(name_node, 'foo'))
    self.assertEqual(anno.getanno(name_node, 'foo'), 3)
    # A missing key falls back to the supplied default.
    self.assertEqual(anno.getanno(name_node, 'bar', default=7), 7)

    anno.delanno(name_node, 'foo')
    check_absent()
    self.assertIsNone(anno.getanno(name_node, 'foo', default=None))
 def visit_Expr(self, node):
     """Visits an expression statement, special-casing bare calls.

     Consider a call whose function has 'live_val' and 'fqn' annotations and
     for which `_function_is_compilable` is false. If `_should_compile` is
     false, the statement is returned unchanged. Otherwise it is replaced by
     the result of `_wrap_to_py_func_no_return`. Any other call is visited
     and the visitor's result is stored back into `node.value`. Non-call
     expressions are visited generically.

     Returns:
       The original statement node or its replacement.
     """
     if isinstance(node.value, gast.Call):
         call = node.value
         if anno.hasanno(call.func, 'live_val'):
             target_entity = anno.getanno(call.func, 'live_val')
             if not self._function_is_compilable(target_entity):
                 if anno.hasanno(call.func, 'fqn'):
                     target_fqn = anno.getanno(call.func, 'fqn')
                     if not self._should_compile(call, target_fqn):
                         return node
                     return self._wrap_to_py_func_no_return(call)
         # Only the case of py_func with no return value is special.
         # Everything else is processed by visit_Call. Store its result
         # back, as generic_visit does in the branch below. Otherwise a call
         # replaced by visit_Call would be silently dropped.
         node.value = self.visit(call)
     else:
         self.generic_visit(node)
     return node
Ejemplo n.º 48
0
 def visit_Expr(self, node):
   """Visits an expression statement, special-casing bare calls.

   Consider a call whose function has 'live_val' and 'fqn' annotations and
   for which `_function_is_compilable` is false. If `_should_compile` is
   false, the statement is returned unchanged. Otherwise it is replaced by
   the result of `_wrap_to_py_func_no_return`. Any other call is visited and
   the visitor's result is stored back into `node.value`. Non-call
   expressions are visited generically.

   Returns:
     The original statement node or its replacement.
   """
   if isinstance(node.value, gast.Call):
     call = node.value
     if anno.hasanno(call.func, 'live_val'):
       target_entity = anno.getanno(call.func, 'live_val')
       if not self._function_is_compilable(target_entity):
         if anno.hasanno(call.func, 'fqn'):
           target_fqn = anno.getanno(call.func, 'fqn')
           if not self._should_compile(call, target_fqn):
             return node
           return self._wrap_to_py_func_no_return(call)
     # Only the case of py_func with no return value is special.
     # Everything else is processed by visit_Call. Store its result back, as
     # generic_visit does in the branch below. Otherwise a call replaced by
     # visit_Call would be silently dropped.
     node.value = self.visit(call)
   else:
     self.generic_visit(node)
   return node
Ejemplo n.º 49
0
 def visit_Name(self, node):
   """Handles parameters and copies type annotations from known definitions.

   Parameter names are passed to `_process_function_arg`. For a loaded name
   whose QN has a value recorded in `self.scope`, the 'type' / 'type_fqn'
   and 'element_type' annotations of that value are copied onto the node.
   """
   self.generic_visit(node)
   qn = anno.getanno(node, anno.Basic.QN)
   if isinstance(node.ctx, gast.Param):
     self._process_function_arg(qn)
     return node
   if not (isinstance(node.ctx, gast.Load) and self.scope.hasval(qn)):
     return node

   # E.g. after `a = b`, later loads of `a` take their annotations from `b`.
   definition = self.scope.getval(qn)
   if anno.hasanno(definition, 'type'):
     for key in ('type', 'type_fqn'):
       anno.setanno(node, key, anno.getanno(definition, key))
   if anno.hasanno(definition, 'element_type'):
     anno.setanno(node, 'element_type',
                  anno.getanno(definition, 'element_type'))
   return node
Ejemplo n.º 50
0
  def test_parameter_class_members(self):
    """A method called on a function parameter gets no 'live_val'."""

    def test_fn(opt):
      opt.minimize(0)

    fn_def = self._parse_and_analyze(test_fn, {}).body[0]
    call_expr = fn_def.body[0].value
    self.assertFalse(anno.hasanno(call_expr.func, 'live_val'))
Ejemplo n.º 51
0
    def test_inner_scope(self):
        """Element types set in the outer and loop scopes reach the return.

        The returned names carry their 'element_type' annotations and have no
        'type' or 'live_val' annotation.
        """
        def test_fn():
            a = []
            utils.set_element_type(a, 1)
            for _ in a:
                b = []
                utils.set_element_type(b, 2)
                return a, b

        node = self._parse_and_analyze(test_fn, {'utils': utils})
        a, b = node.body[0].body[2].body[2].value.elts
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(1, anno.getanno(a, 'element_type'))
        self.assertEqual(2, anno.getanno(b, 'element_type'))
        self.assertFalse(anno.hasanno(a, 'type'))
        self.assertFalse(anno.hasanno(b, 'type'))
        self.assertFalse(anno.hasanno(a, 'live_val'))
        self.assertFalse(anno.hasanno(b, 'live_val'))
Ejemplo n.º 52
0
    def visit_For(self, node):
        """Replaces a `for` loop with a call to `ag__.for_stmt`.

        Symbols that the loop body modifies but does not create become the
        loop state. The body, and the optional 'extra_test' annotation
        (default: the constant `True`), are renamed to use fresh state names.
        They are then wrapped in two generated functions, which are passed to
        `ag__.for_stmt` together with `node.iter` and the initial state.

        Args:
          node: the gast.For node. It must carry a NodeAnno.BODY_SCOPE
            annotation.

        Returns:
          The result of `templates.replace` on the template below (presumably
          a list of replacement statements; TODO confirm).
        """
        self.generic_visit(node)

        # Defined elsewhere; presumably rejects loops that create variables
        # used after the loop. TODO confirm.
        self._validate_no_live_vars_created(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        # Loop state: symbols the body modifies but did not create.
        body_closure = body_scope.modified - body_scope.created
        all_referenced = body_scope.referenced

        state = list(body_closure)

        # One fresh name per state symbol, derived from its ssf() form.
        # NOTE(review): new_symbol is called here and twice more below in a
        # fixed order. If the namer is stateful, reordering these calls
        # changes the generated names.
        state_ssf = [
            self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
        ]
        # Only symbols whose fresh name differs need renaming.
        ssf_map = {
            name: ssf
            for name, ssf in zip(state, state_ssf) if str(name) != ssf
        }

        # A single state symbol is passed bare rather than as a 1-tuple.
        # NOTE(review): with no state symbols, an empty gast.Tuple is used;
        # confirm ag__.for_stmt handles that.
        if len(state) == 1:
            state = state[0]
            state_ssf = state_ssf[0]
            state_ast_tuple = state
        else:
            state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

        node_body = ast_util.rename_symbols(node.body, ssf_map)
        if anno.hasanno(node, 'extra_test'):
            extra_test = anno.getanno(node, 'extra_test')
            extra_test = ast_util.rename_symbols(extra_test, ssf_map)
        else:
            extra_test = parser.parse_expression('True')

        # The generated body function receives the loop target as
        # `loop_vars` and unpacks it with an assignment, because tuple
        # parameters were removed by PEP 3113.
        template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
        node = templates.replace(
            template,
            state=state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=self.ctx.namer.new_symbol('extra_test',
                                                      all_referenced),
            extra_test_expr=extra_test,
            body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
            body=node_body)

        return node
Ejemplo n.º 53
0
  def visit_For(self, node):
    """Replaces a `for` loop with a call to `ag__.for_stmt`.

    Symbols that the loop body modifies but does not create become the loop
    state. The body, and the optional 'extra_test' annotation (default: the
    constant `True`), are renamed to use fresh state names. They are then
    wrapped in two generated functions, which are passed to `ag__.for_stmt`
    together with `node.iter` and the initial state.

    Args:
      node: the gast.For node. It must carry a NodeAnno.BODY_SCOPE
        annotation.

    Returns:
      The result of `templates.replace` on the template below (presumably a
      list of replacement statements; TODO confirm).
    """
    self.generic_visit(node)

    # Defined elsewhere; presumably rejects loops that create variables used
    # after the loop. TODO confirm.
    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    # Loop state: symbols the body modifies but did not create.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    state = list(body_closure)

    # One fresh name per state symbol, derived from its ssf() form.
    # NOTE(review): new_symbol is called here and twice more below in a fixed
    # order. If the namer is stateful, reordering these calls changes the
    # generated names.
    state_ssf = [
        self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Only symbols whose fresh name differs need renaming.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # A single state symbol is passed bare rather than as a 1-tuple.
    # NOTE(review): with no state symbols, an empty gast.Tuple is used;
    # confirm ag__.for_stmt handles that.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_test'):
      extra_test = anno.getanno(node, 'extra_test')
      extra_test = ast_util.rename_symbols(extra_test, ssf_map)
    else:
      extra_test = parser.parse_expression('True')

    # The generated body function receives the loop target as `loop_vars`
    # and unpacks it with an assignment, because tuple parameters were
    # removed by PEP 3113.
    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
Ejemplo n.º 54
0
  def test_inner_scope(self):
    """Element types set in the outer and loop scopes reach the return.

    The returned names carry their 'element_type' annotations and have no
    'type' or 'live_val' annotation.
    """

    def test_fn():
      a = []
      utils.set_element_type(a, 1)
      for _ in a:
        b = []
        utils.set_element_type(b, 2)
        return a, b

    node = self._parse_and_analyze(test_fn, {'utils': utils})
    a, b = node.body[0].body[2].body[2].value.elts
    # assertEqual: the assertEquals alias was removed in Python 3.12.
    self.assertEqual(1, anno.getanno(a, 'element_type'))
    self.assertEqual(2, anno.getanno(b, 'element_type'))
    self.assertFalse(anno.hasanno(a, 'type'))
    self.assertFalse(anno.hasanno(b, 'type'))
    self.assertFalse(anno.hasanno(a, 'live_val'))
    self.assertFalse(anno.hasanno(b, 'live_val'))
Ejemplo n.º 55
0
  def test_nested_members(self):
    """A method reached through a nested attribute gets no 'live_val'."""

    def test_fn():
      foo = training.GradientDescentOptimizer(0.1)
      foo.bar.baz()

    analyzed = self._parse_and_analyze(test_fn, {'training': training})
    call_stmt = analyzed.body[0].body[1]
    self.assertFalse(anno.hasanno(call_stmt.value.func, 'live_val'))
 def _expect_simple_symbol(self, operand):
   """Accepts plain names or operands marked SAFE_BOOLEAN_OPERAND.

   Raises:
     NotImplementedError: for any other operand, such as an attribute access.
   """
   is_plain_name = isinstance(operand, gast.Name)
   if is_plain_name or anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
     return
   raise NotImplementedError(
       'only simple local variables are supported in logical and compound '
       'comparison expressions; for example, we support "a or b" but not '
       '"a.x or b"; for a workaround, assign the expression to a local '
       'variable and use that instead, for example "tmp = a.x", "tmp or b"')