Exemplo n.º 1
0
  def _rename_compilable_function(self, node):
    """Points a call at the compiled counterpart of its callee.

    Args:
      node: a call node whose `func` is annotated with 'live_val' and 'fqn'.

    Returns:
      `node`. When the namer requests a rename, `node.func` is replaced by the
      new name. For bound methods, the receiver is also prepended to the
      positional arguments.
    """
    func = node.func
    assert anno.hasanno(func, 'live_val')
    assert anno.hasanno(func, 'fqn')
    live_entity = anno.getanno(func, 'live_val')
    fqn = anno.getanno(func, 'fqn')

    if not anno.hasanno(node, 'is_constructor'):
      # Prefer the explicit 'parent_type' annotation; getmethodclass is only a
      # fallback and is not reliable.
      owner = (anno.getanno(func, 'parent_type')
               if anno.hasanno(func, 'parent_type')
               else inspect_utils.getmethodclass(live_entity))
      new_name, do_rename = self.ctx.namer.compiled_function_name(
          fqn, live_entity=live_entity, owner_type=owner)
    else:
      new_name = self.ctx.namer.compiled_class_name(
          fqn, live_entity=live_entity)
      do_rename = True

    if not do_rename:
      return node

    if live_entity is not None and tf_inspect.ismethod(live_entity):
      # Renaming turns the method into a regular function, so the receiver
      # becomes its first positional argument.
      # TODO(mdan): Is this complete? How does it work with nested members?
      node.args = [func.value] + node.args
    node.func = templates.replace_as_expression(
        'func_name', func_name=new_name)
    return node
Exemplo n.º 2
0
    def test_static_attribute_of_ambiguous_type(self):
        """Reading an attribute of a value with two candidate types."""
        outer = self

        class TestClass1:

            a = 1

        class TestClass2:

            a = 2

        tc = TestClass1()

        class Resolver(type_inference.Resolver):
            def res_name(self, ns, types_ns, name):
                outer.assertEqual(name, qual_names.QN('tc'))
                return {TestClass1, TestClass2}, None

            def res_value(self, ns, value):
                outer.assertIn(value, (1, 2))
                return {str}

        def test_fn():
            return tc.a

        node, _ = TestTranspiler(Resolver).transform(test_fn, None)
        attr_node = node.body[0].value
        receiver = attr_node.value

        # The receiver keeps both candidate types; the attribute gets neither a
        # type nor a static value.
        self.assertTypes(receiver, (TestClass1, TestClass2))
        self.assertFalse(anno.hasanno(attr_node, anno.Static.TYPES))
        self.assertFalse(anno.hasanno(receiver, anno.Static.VALUE))
        self.assertFalse(anno.hasanno(attr_node, anno.Static.VALUE))
Exemplo n.º 3
0
    def test_dynamic_attribute_of_typed_value(self):
        """An attribute assigned in __init__ gets no static type or value."""
        outer = self

        class TestClass:
            def __init__(self):
                self.a = 1

        tc = TestClass()

        class Resolver(type_inference.Resolver):
            def res_name(self, ns, types_ns, name):
                outer.assertEqual(name, qual_names.QN('tc'))
                return {TestClass}, None

        def test_fn():
            return tc.a

        node, _ = TestTranspiler(Resolver).transform(test_fn, None)
        attr_node = node.body[0].value
        receiver = attr_node.value

        self.assertTypes(receiver, TestClass)
        self.assertFalse(anno.hasanno(attr_node, anno.Static.TYPES))
        self.assertFalse(anno.hasanno(receiver, anno.Static.VALUE))
        self.assertFalse(anno.hasanno(attr_node, anno.Static.VALUE))
Exemplo n.º 4
0
    def visit(self, node):
        """Visits `node` while maintaining the transformer's bookkeeping.

        Tracks the enclosing function/class/lambda entities and the current
        source origin. Also unwraps `Expr` nodes whose value was replaced by an
        assignment or by a list/tuple of statements.

        Args:
            node: the gast.AST node to visit. Lists of nodes must be visited
                with `visit_block` instead.

        Returns:
            The result of the underlying visitor. Nodes annotated with
            `anno.Basic.SKIP_PROCESSING` are returned unchanged.

        Raises:
            ValueError: if `node` is not a gast.AST.
            AssertionError: if the local scope stack changed size while
                visiting `node`, meaning enter_local_scope and
                exit_local_scope were not paired.
        """
        if not isinstance(node, gast.AST):
            # This is not that uncommon a mistake: various node bodies are lists, for
            # example, posing a land mine for transformers that need to recursively
            # call `visit`.  The error needs to be raised before the exception handler
            # below is installed, because said handler will mess up if `node` is not,
            # in fact, a node.
            msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
                   ' visit lists of nodes, use "visit_block" instead').format(
                       type(node))
            raise ValueError(msg)

        did_enter_function = False
        local_scope_size_at_entry = len(self._local_scope_state)
        processing_expr_node = False

        parent_origin = self.ctx.current_origin
        if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
            did_enter_function = True
        elif isinstance(node, gast.Expr):
            processing_expr_node = True

        if did_enter_function:
            self._enclosing_entities.append(node)

        if anno.hasanno(node, anno.Basic.ORIGIN):
            self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

        if processing_expr_node:
            entry_expr_value = node.value

        # Default to the node itself so that nodes marked SKIP_PROCESSING are
        # passed through unchanged instead of leaving `result` unbound.
        result = node
        try:
            if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
                result = super(Base, self).visit(node)
        finally:
            # Restore the origin even if the visitor raised.
            self.ctx.current_origin = parent_origin

        # Adjust for consistency: replacing the value of an Expr with
        # an Assign node removes the need for the Expr node.
        # The value counts as replaced only if it is a different object.
        if (processing_expr_node and isinstance(result, gast.Expr) and
                result.value is not entry_expr_value):
            # When the replacement is a list, it is assumed that the list came
            # from a template that contained a number of statements, which
            # themselves are standalone and don't require an enclosing Expr.
            if isinstance(result.value,
                          (list, tuple, gast.Assign, gast.AugAssign)):
                result = result.value

        # On exception, the local scope integrity is not guaranteed.
        if did_enter_function:
            self._enclosing_entities.pop()

        if local_scope_size_at_entry != len(self._local_scope_state):
            raise AssertionError(
                'Inconsistent local scope stack. Before entering node %s, the'
                ' stack had length %d, after exit it has length %d. This'
                ' indicates enter_local_scope and exit_local_scope are not'
                ' well paired.' % (node, local_scope_size_at_entry,
                                   len(self._local_scope_state)))
        return result
Exemplo n.º 5
0
  def visit(self, node):
    """Visits `node` while maintaining the transformer's bookkeeping.

    Tracks the enclosing function/class/lambda entities and the current
    source origin. Also unwraps `Expr` nodes whose value was replaced by an
    assignment or by a list/tuple of statements.

    Args:
      node: the gast.AST node to visit. Lists of nodes must be visited with
        `visit_block` instead.

    Returns:
      The result of the underlying visitor. Nodes annotated with
      `anno.Basic.SKIP_PROCESSING` are returned unchanged.

    Raises:
      ValueError: if `node` is not a gast.AST.
      AssertionError: if the local scope stack changed size while visiting
        `node`, meaning enter_local_scope and exit_local_scope were not
        paired.
    """
    if not isinstance(node, gast.AST):
      # This is not that uncommon a mistake: various node bodies are lists, for
      # example, posing a land mine for transformers that need to recursively
      # call `visit`.  The error needs to be raised before the exception handler
      # below is installed, because said handler will mess up if `node` is not,
      # in fact, a node.
      msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
             ' visit lists of nodes, use "visit_block" instead').format(
                 type(node))
      raise ValueError(msg)

    did_enter_function = False
    local_scope_size_at_entry = len(self._local_scope_state)
    processing_expr_node = False

    parent_origin = self.ctx.current_origin
    if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
      did_enter_function = True
    elif isinstance(node, gast.Expr):
      processing_expr_node = True

    if did_enter_function:
      self._enclosing_entities.append(node)

    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    if processing_expr_node:
      entry_expr_value = node.value

    # Default to the node itself so that nodes marked SKIP_PROCESSING are
    # passed through unchanged instead of leaving `result` unbound.
    result = node
    try:
      if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
        result = super(Base, self).visit(node)
    finally:
      # Restore the origin even if the visitor raised.
      self.ctx.current_origin = parent_origin

    # Adjust for consistency: replacing the value of an Expr with
    # an Assign node removes the need for the Expr node.
    # The value counts as replaced only if it is a different object.
    if (processing_expr_node and isinstance(result, gast.Expr) and
        result.value is not entry_expr_value):
      # When the replacement is a list, it is assumed that the list came
      # from a template that contained a number of statements, which
      # themselves are standalone and don't require an enclosing Expr.
      if isinstance(result.value,
                    (list, tuple, gast.Assign, gast.AugAssign)):
        result = result.value

    # On exception, the local scope integrity is not guaranteed.
    if did_enter_function:
      self._enclosing_entities.pop()

    if local_scope_size_at_entry != len(self._local_scope_state):
      raise AssertionError(
          'Inconsistent local scope stack. Before entering node %s, the'
          ' stack had length %d, after exit it has length %d. This'
          ' indicates enter_local_scope and exit_local_scope are not'
          ' well paired.' % (node, local_scope_size_at_entry,
                             len(self._local_scope_state)))
    return result
Exemplo n.º 6
0
    def visit(self, node):
        """Visits `node`, propagating origin information to its replacements.

        Args:
            node: the gast.AST node to visit. Lists of nodes must be visited
                with `visit_block` instead.

        Returns:
            `node` itself if it is annotated with SKIP_PROCESSING. Otherwise,
            the result of the underlying visitor. An `Expr` whose value was
            replaced by an assignment or by a list/tuple of statements is
            unwrapped. Replacement nodes that lack an ORIGIN annotation inherit
            the origin of `node`, or the enclosing origin if `node` has none.

        Raises:
            ValueError: if `node` is not a gast.AST.
        """
        if not isinstance(node, gast.AST):
            # This is not that uncommon a mistake: various node bodies are lists, for
            # example, posing a land mine for transformers that need to recursively
            # call `visit`.  The error needs to be raised before the exception handler
            # below is installed, because said handler will mess up if `node` is not,
            # in fact, a node.
            msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
                   ' visit lists of nodes, use "visit_block" instead').format(
                       type(node))
            raise ValueError(msg)

        if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
            return node

        parent_origin = self.ctx.current_origin
        if anno.hasanno(node, anno.Basic.ORIGIN):
            self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

        try:
            processing_expr_node = isinstance(node, gast.Expr)
            if processing_expr_node:
                entry_expr_value = node.value

            result = super(Base, self).visit(node)

            # Adjust for consistency: replacing the value of an Expr with
            # an Assign node removes the need for the Expr node.
            if (processing_expr_node and isinstance(result, gast.Expr)
                    and (result.value is not entry_expr_value)):
                # When the replacement is a list, it is assumed that the list came
                # from a template that contained a number of statements, which
                # themselves are standalone and don't require an enclosing Expr.
                if isinstance(result.value,
                              (list, tuple, gast.Assign, gast.AugAssign)):
                    result = result.value

            # By default, all replacements receive the origin info of the replaced
            # node.
            if result is not node and result is not None:
                inherited_origin = anno.getanno(node,
                                                anno.Basic.ORIGIN,
                                                default=parent_origin)
                if inherited_origin is not None:
                    # Treat a single replacement node like a one-element list.
                    if isinstance(result, (list, tuple)):
                        nodes_to_adjust = result
                    else:
                        nodes_to_adjust = (result, )
                    for n in nodes_to_adjust:
                        if not anno.hasanno(n, anno.Basic.ORIGIN):
                            anno.setanno(n, anno.Basic.ORIGIN,
                                         inherited_origin)
        finally:
            self.ctx.current_origin = parent_origin

        return result
Exemplo n.º 7
0
  def test_copy(self):
    """copyanno copies an annotation that exists and skips one that doesn't."""
    src = ast.Name()
    anno.setanno(src, 'foo', 3)

    dst = ast.Name()
    for key in ('foo', 'bar'):
      anno.copyanno(src, dst, key)

    self.assertTrue(anno.hasanno(dst, 'foo'))
    self.assertFalse(anno.hasanno(dst, 'bar'))
Exemplo n.º 8
0
 def _try_resolve_target(self, node):
   """Works for methods of objects of known type.

   Args:
     node: the AST node to resolve, e.g. the `func` of a call.

   Returns:
     The node's 'live_val' annotation if present. Otherwise, for a
     gast.Attribute annotated with 'type', the attribute `node.attr` of that
     type. Otherwise None.

   Raises:
     ValueError: if the annotated type has no attribute `node.attr`.
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')
   if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
     owner_type = anno.getanno(node, 'type')
     if not hasattr(owner_type, node.attr):
       raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                        (owner_type, node.attr))
     return getattr(owner_type, node.attr)
   return None
Exemplo n.º 9
0
 def _try_resolve_target(self, node):
   """Works for methods of objects of known type.

   Args:
     node: the AST node to resolve, e.g. the `func` of a call.

   Returns:
     The node's 'live_val' annotation if present. Otherwise, for a
     gast.Attribute annotated with 'type', the attribute `node.attr` of that
     type. Otherwise None.

   Raises:
     ValueError: if the annotated type has no attribute `node.attr`.
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')
   if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
     owner_type = anno.getanno(node, 'type')
     if not hasattr(owner_type, node.attr):
       raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                        (owner_type, node.attr))
     return getattr(owner_type, node.attr)
   return None
Exemplo n.º 10
0
    def visit_Call(self, node):
        """Converts a call according to what is known about its target.

        For a target with a known live value ('live_val'):
        - compilable and selected for compilation: renamed to its compiled
          version;
        - compilable but not selected: only its children are visited;
        - a known numpy function: wrapped in a single-return py_func;
        - a builtin: left untouched;
        - a namedtuple: only its children are visited.

        Calls with an unknown target get a dynamic conversion, except super()
        and super().method(...) calls, which are preserved.

        Raises:
            NotImplementedError: if the target is known but fits none of the
                cases above.
        """
        func = node.func
        if not anno.hasanno(func, 'live_val'):
            # Special cases.
            # TODO(mdan): These need a systematic review - there may be more.
            # super() calls are preserved; the class conversion mechanism
            # ensures they return the correct value.
            if ast_util.matches(node, parser.parse_expression('super(_)')):
                return node
            # super().method calls are preserved as well, when the conversion
            # processes the entire class.
            super_method_call = parser.parse_expression('super(_)._(_)')
            if (ast_util.matches(node, super_method_call)
                    and self.ctx.info.owner_type is not None):
                return node
            return self._insert_dynamic_conversion(node)

        target_entity = anno.getanno(func, 'live_val')
        target_fqn = (anno.getanno(func, 'fqn')
                      if anno.hasanno(func, 'fqn') else None)

        if self._function_is_compilable(target_entity):
            if not self._should_compile(node, target_fqn):
                return self.generic_visit(node)
            return self._rename_compilable_function(node)

        if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
            # TODO(mdan): Should we replace these with equivalent TF ops instead?
            return self._wrap_to_py_func_single_return(
                node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

        if inspect_utils.isbuiltin(target_entity):
            # Any builtin that passed the builtins converter is assumed to be
            # safe for graph mode.
            return node

        if inspect_utils.isnamedtuple(target_entity):
            # Not compilable, but assumed to be safe for graph mode.
            return self.generic_visit(node)

        # TODO(mdan): Insert dynamic conversion here instead.
        raise NotImplementedError(
            'py_func with return values (unknown function)')
Exemplo n.º 11
0
  def visit_Call(self, node):
    """Converts a call according to what is known about its target.

    For a target with a known live value ('live_val'):
    - compilable and selected for compilation: renamed to its compiled
      version;
    - compilable but not selected: only its children are visited;
    - a known numpy function: wrapped in a single-return py_func;
    - a builtin: left untouched;
    - a namedtuple: only its children are visited.

    Calls with an unknown target get a dynamic conversion, except super() and
    super().method(...) calls, which are preserved.

    Raises:
      NotImplementedError: if the target is known but fits none of the cases
        above.
    """
    func = node.func
    if not anno.hasanno(func, 'live_val'):
      # Special cases.
      # TODO(mdan): These need a systematic review - there may be more.
      # super() calls are preserved; the class conversion mechanism ensures
      # they return the correct value.
      if ast_util.matches(node, 'super(_)'):
        return node
      # super().method calls are preserved as well, when the conversion
      # processes the entire class.
      if (ast_util.matches(node, 'super(_)._(_)') and
          self.ctx.info.owner_type is not None):
        return node
      return self._insert_dynamic_conversion(node)

    target_entity = anno.getanno(func, 'live_val')
    target_fqn = (anno.getanno(func, 'fqn')
                  if anno.hasanno(func, 'fqn') else None)

    if self._function_is_compilable(target_entity):
      if not self._should_compile(node, target_fqn):
        return self.generic_visit(node)
      return self._rename_compilable_function(node)

    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

    if inspect_utils.isbuiltin(target_entity):
      # Any builtin that passed the builtins converter is assumed to be safe
      # for graph mode.
      return node

    if inspect_utils.isnamedtuple(target_entity):
      # Not compilable, but assumed to be safe for graph mode.
      return self.generic_visit(node)

    # TODO(mdan): Insert dynamic conversion here instead.
    raise NotImplementedError(
        'py_func with return values (unknown function)')
Exemplo n.º 12
0
    def test_duplicate(self):
        """anno.dup copies an annotation to a new key, keeping the original."""
        child = ast.Expr(ast.Name('bar', ast.Load()))
        node = ast.If(test=ast.Num(1), body=[child], orelse=[])
        for key in ('spam', 'ham'):
            anno.setanno(node, key, 1)
        anno.setanno(child, 'ham', 1)

        anno.dup(node, {'spam': 'eggs'})

        for key in ('spam', 'ham', 'eggs'):
            self.assertTrue(anno.hasanno(node, key))
        self.assertFalse(anno.hasanno(node.body[0], 'eggs'))
Exemplo n.º 13
0
    def visit_Call(self, node):
        """Marks decorated arguments as graph ready, then converts the call.

        Raises:
            ValueError: if a call to one of
                `self.ctx.program.autograph_decorators` has no positional
                arguments.
            NotImplementedError: if the target has a known live value but is
                neither compilable nor a known numpy function.
        """
        # A call wrapped by one of the marker decorators is considered graph
        # ready: its first argument is annotated accordingly.
        if anno.hasanno(node.func, 'live_val'):
            decorator = anno.getanno(node.func, 'live_val')
            if decorator in self.ctx.program.autograph_decorators:
                if not node.args:
                    raise ValueError(
                        'Found call to decorator function "%s", but it had no arguments. '
                        'A decorator needs at least one positional argument.' %
                        decorator)
                anno.setanno(node.args[0], 'graph_ready', True)

        self.generic_visit(node)

        # node.func is read again here because generic_visit may rewrite
        # children.
        if not anno.hasanno(node.func, 'live_val'):
            if anno.hasanno(node.func, anno.Basic.QN):
                # Special-case a few builtins that otherwise go undetected. The
                # dict built-in doesn't work with inspect.getargspec, which
                # dynamic functions require. This is expected to be resilient
                # to aliasing (e.g. dict = an_evil_dict): in that case the
                # regular mechanisms process a simple user function.
                qn = anno.getanno(node.func, anno.Basic.QN)
                # Add items to this tuple as needed.
                if str(qn) in ('dict', ):
                    return node

            if ast_util.matches(node, 'super(_)'):
                # super() calls are preserved. The class conversion mechanism
                # will ensure that they return the correct value.
                return node

            if self.ctx.program.recursive:
                return self._insert_dynamic_conversion(node)
            return node

        target_entity = anno.getanno(node.func, 'live_val')
        target_fqn = (anno.getanno(node.func, 'fqn')
                      if anno.hasanno(node.func, 'fqn') else None)
        if self._function_is_compilable(target_entity):
            return self._rename_compilable_function(node)
        if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
            # TODO(mdan): Should we replace these with equivalent TF ops instead?
            return self._wrap_to_py_func_single_return(
                node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
        raise NotImplementedError(
            'py_func with return values (unknown function)')
Exemplo n.º 14
0
  def visit_Call(self, node):
    """Marks decorated arguments as graph ready, then converts the call.

    Raises:
      ValueError: if a call to one of
        `self.ctx.program.options.strip_decorators` has no positional
        arguments.
      NotImplementedError: if the target has a known live value but is
        neither compilable nor a known numpy function.
    """
    # A call wrapped by one of the marker decorators is considered graph
    # ready: its first argument is annotated accordingly.
    if anno.hasanno(node.func, 'live_val'):
      decorator = anno.getanno(node.func, 'live_val')
      if decorator in self.ctx.program.options.strip_decorators:
        if not node.args:
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              decorator)
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)

    # node.func is read again here because generic_visit may rewrite children.
    if not anno.hasanno(node.func, 'live_val'):
      if anno.hasanno(node.func, anno.Basic.QN):
        # Special-case a few builtins that otherwise go undetected. The dict
        # built-in doesn't work with inspect.getargspec, which dynamic
        # functions require. This is expected to be resilient to aliasing
        # (e.g. dict = an_evil_dict): in that case the regular mechanisms
        # process a simple user function.
        qn = anno.getanno(node.func, anno.Basic.QN)
        # Add items to this tuple as needed.
        if str(qn) in ('dict',):
          return node

      if ast_util.matches(node, 'super(_)'):
        # super() calls are preserved. The class conversion mechanism will
        # ensure that they return the correct value.
        return node

      if self.ctx.program.options.recursive:
        return self._insert_dynamic_conversion(node)
      return node

    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = (anno.getanno(node.func, 'fqn')
                  if anno.hasanno(node.func, 'fqn') else None)
    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)
    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    raise NotImplementedError(
        'py_func with return values (unknown function)')
    def test_no_inference_on_unknown_operand_types(self):
        """Operators applied to untyped operands get no inferred types."""

        # No information on types of a and b, see TestResolver.
        def magic_no_types(a, b):
            return a < b, a - b

        node, _ = TestTranspiler().transform(magic_no_types, None)
        returned = node.body[0].value
        comparison = returned.elts[0]
        difference = returned.elts[1]

        # With no information on operand types, the operators infer nothing.
        self.assertFalse(anno.hasanno(comparison, anno.Static.TYPES))
        self.assertFalse(anno.hasanno(difference, anno.Static.TYPES))
Exemplo n.º 16
0
  def test_duplicate(self):
    """anno.dup copies an annotation to a new key, keeping the original."""
    child = ast.Expr(ast.Name('bar', ast.Load()))
    node = ast.If(test=ast.Num(1), body=[child], orelse=[])
    for key in ('spam', 'ham'):
      anno.setanno(node, key, 1)
    anno.setanno(child, 'ham', 1)

    anno.dup(node, {'spam': 'eggs'})

    for key in ('spam', 'ham', 'eggs'):
      self.assertTrue(anno.hasanno(node, key))
    self.assertFalse(anno.hasanno(node.body[0], 'eggs'))
Exemplo n.º 17
0
    def test_local_scope_info_stack(self):
        """Values set in a nested local scope do not leak into the outer one."""
        class TestTransformer(transformer.Base):

            # Concatenate the values of all constants visited in the block.
            def visit_Constant(self, node):
                collected = self.get_local('string', default='')
                self.set_local('string', collected + str(node.value))
                return self.generic_visit(node)

            def _scoped_visit(self, node):
                # Visit inside a fresh local scope and record what was
                # collected there on the node itself.
                self.enter_local_scope()
                node = self.generic_visit(node)
                anno.setanno(node, 'test', self.get_local('string'))
                self.exit_local_scope()
                return node

            def visit_While(self, node):
                return self._scoped_visit(node)

            def visit_For(self, node):
                return self._scoped_visit(node)

        visitor = TestTransformer(self._simple_context())

        def test_function(a):
            """Docstring."""
            assert a == 'This should not be counted'
            for i in range(3):
                _ = 'a'
                if i > 2:
                    return 'b'
                else:
                    _ = 'c'
                    while 4:
                        raise '1'
            return 'nor this'

        node, _ = parser.parse_entity(test_function, future_features=())
        node = visitor.visit(node)

        for_node = node.body[2]
        while_node = for_node.body[1].orelse[1]

        for loop_node, expected in ((for_node, '3a2bc'), (while_node, '41')):
            self.assertFalse(anno.hasanno(loop_node, 'string'))
            self.assertEqual(expected, anno.getanno(loop_node, 'test'))
Exemplo n.º 18
0
  def visit_For(self, node):
    """Lowers a `for` statement into a function-based loop construct.

    The node's children are visited first. Then the loop-carried state is
    computed and the body's state symbols are renamed via `ssf_map`. One of
    three lowerings is chosen: with an extra test (early stopping), with
    loop-carried state only, or without state.

    Args:
      node: the gast.For node to convert.

    Returns:
      The assignments created for possibly-undefined symbols, concatenated
      with the nodes that implement the loop.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    has_extra_test = anno.hasanno(node, 'extra_test')
    if loop_state:
      if has_extra_test:
        # Loop with early stopping (e.g. break or return)
        extra_test = anno.getanno(node, 'extra_test')
        extra_test = ast_util.rename_symbols(extra_test, ssf_map)
        extra_test_name = self.ctx.namer.new_symbol('extra_test',
                                                    reserved_symbols)
        loop_nodes = self._for_loop_with_extra_test(
            loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
            extra_test, body_name, node_body)

      else:
        # Loop with loop-carried state and no early stopping
        loop_nodes = self._for_loop_with_state(
            loop_state, state_ssf, state_ast_tuple, node, body_name, node_body)

    else:
      # Loop with no loop-carried state and no early stopping
      assert not has_extra_test, ('Early stopping (e.g. break and/or return) '
                                  'should create state variables.')
      loop_nodes = self._for_loop_without_state(node, body_name, node_body)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + loop_nodes
Exemplo n.º 19
0
 def can_ignore(self, node):
     """Returns True if the node can safely be assumed not to touch variables.

     This holds for nodes annotated with SKIP_PROCESSING, and for break,
     continue, raise and pass statements.
     """
     ast_node = node.ast_node
     inert_statements = (gast.Break, gast.Continue, gast.Raise, gast.Pass)
     return (bool(anno.hasanno(ast_node, anno.Basic.SKIP_PROCESSING)) or
             isinstance(ast_node, inert_statements))
Exemplo n.º 20
0
    def visit_node(self, node):
        """Recomputes the live-in and live-out sets of one CFG node.

        Returns:
            True if the node's live-in set changed, False otherwise.
        """
        prev_live_in = self.in_[node]
        ast_node = node.ast_node

        if not anno.hasanno(ast_node, anno.Static.SCOPE):
            # Nodes without a scope annotation are assumed not to touch any
            # symbols. A Name reaching this point is a literal name, such as
            # False.
            assert isinstance(ast_node,
                              (gast.Name, gast.Continue, gast.Break)), type(
                                  ast_node)
            live_in = prev_live_in
            live_out = live_in
        else:
            scope = anno.getanno(ast_node, anno.Static.SCOPE)
            # Symbols used by the node (plus any extra ones registered for it)
            # are generated; symbols it modifies are killed.
            gen = scope.used | self.extra_gen.get(ast_node, frozenset())
            # TODO(mdan): verify whether composites' parents need to be added.
            # E.g. if x.y is live whether x needs to be added. Theoretically the
            # activity analysis should have both so that wouldn't be needed.
            kill = scope.modified

            live_out = set()
            for successor in node.next:
                live_out |= self.in_[successor]
            live_in = gen | (live_out - kill)

        self.in_[node] = live_in
        self.out[node] = live_out

        # TODO(mdan): Move this to the superclass?
        return prev_live_in != live_in
Exemplo n.º 21
0
 def _visit_and_reindent(self, nodes):
   """Visits a list of statements, honoring INDENT_BLOCK_REMAINDER requests.

   Each statement is visited and appended (or, for list/tuple results,
   extended) to the current destination list, which starts as the returned
   list. A visited statement may carry an `anno.Basic.INDENT_BLOCK_REMAINDER`
   annotation holding `(new_dest, new_alias_map)`. In that case the
   annotation is removed, and all following statements are appended to
   `new_dest` instead, after being renamed with the accumulated alias map.

   Args:
     nodes: the statements to visit.

   Returns:
     The list of visited statements at the top level. Statements redirected
     into a `new_dest` list appear there, not directly in this list.

   Raises:
     ValueError: if a reindent was requested and the final destination list
       is still empty at the end.
   """
   new_nodes = []
   current_dest = new_nodes
   alias_map = {}
   reindent_requested = False
   for n in nodes:
     n = self.visit(n)
     # NOTE: the order in which these statements execute is important; in
     # particular, watch out for ending up with cycles in the AST.
     # The rename uses aliases from earlier requests only; a request made by
     # `n` itself affects the statements after it.
     if alias_map:
       n = ast_util.rename_symbols(n, alias_map)
     if isinstance(n, (list, tuple)):
       current_dest.extend(n)
     else:
       current_dest.append(n)
     # NOTE(review): when `n` is a list/tuple, this queries the annotation on
     # the container itself rather than on its elements — confirm intended.
     if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
       reindent_requested = True
       new_dest, new_alias_map = anno.getanno(
           n, anno.Basic.INDENT_BLOCK_REMAINDER)
       anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
       # Existing aliases take precedence over the new ones. Note that this
       # mutates the map taken from the annotation.
       new_alias_map.update(alias_map)
       alias_map = new_alias_map
       current_dest = new_dest
   if reindent_requested and not current_dest:
     # TODO(mdan): There may still be something that could be done.
     raise ValueError('Unable to insert statement into the computation flow: '
                      'it is not followed by any computation which '
                      'the statement could gate.')
   return new_nodes
Exemplo n.º 22
0
    def test_no_inference_on_unknown_operand_types(self):
        """Operators whose operand types are unresolved get no types."""
        class Resolver(type_inference.Resolver):
            def res_arg(self, ns, types_ns, f_name, name, type_anno):
                return None

        def test_fn(a, b):
            return a < b, a - b

        node, _ = TestTranspiler(Resolver).transform(test_fn, None)
        returned = node.body[0].value
        comparison = returned.elts[0]
        difference = returned.elts[1]

        # With no information on operand types, the operators infer nothing.
        self.assertFalse(anno.hasanno(comparison, anno.Static.TYPES))
        self.assertFalse(anno.hasanno(difference, anno.Static.TYPES))
Exemplo n.º 23
0
 def visit_Attribute(self, node):
     """Derives a QN annotation for `node` from the QN of its value, if any."""
     node = self.generic_visit(node)
     if not anno.hasanno(node.value, anno.Basic.QN):
         return node
     owner_qn = anno.getanno(node.value, anno.Basic.QN)
     anno.setanno(node, anno.Basic.QN, QN(owner_qn, attr=node.attr))
     return node
Exemplo n.º 24
0
    def test_static_attribute_of_typed_value(self):
        """A class attribute read via a typed value uses the resolver's type."""
        outer = self

        class TestClass:

            a = 1

        tc = TestClass()

        class Resolver(type_inference.Resolver):
            def res_name(self, ns, types_ns, name):
                outer.assertEqual(name, qual_names.QN('tc'))
                return {TestClass}, None

            def res_value(self, ns, value):
                outer.assertIs(value, tc.a)
                return {str}

        def test_fn():
            return tc.a

        node, _ = TestTranspiler(Resolver).transform(test_fn, None)
        attr_node = node.body[0].value
        receiver = attr_node.value

        self.assertTypes(receiver, TestClass)
        self.assertTypes(attr_node, str)  # Resolver is SOT
        self.assertFalse(anno.hasanno(receiver, anno.Static.VALUE))
        self.assertEqual(anno.getanno(attr_node, anno.Static.VALUE), 1)
Exemplo n.º 25
0
    def visit_For(self, node):
        """Adds the control-flow graph fragment for a `for` statement.

        Inside a section opened for `node`, the iterable is visited first.
        The loop section is then entered, and the optional extra loop test and
        the body statements are processed. After the loop section and the
        lexical scope are exited, the `orelse` statements are visited, still
        within the node's section.

        Args:
            node: the gast.For node being visited.
        """
        self.builder.begin_statement(node)
        self._enter_lexical_scope(node)

        self.builder.enter_section(node)

        # Note: Strictly speaking, this should be node.target + node.iter.
        # However, the activity analysis accounts for this inconsistency,
        # so dataflow analysis produces the correct values.
        self.generic_visit(node.iter)
        self.builder.enter_loop_section(node, node.iter)
        # Also include the "extra loop test" annotation, to capture things like the
        # control variable for return and break in for loops.
        if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
            self._process_basic_statement(
                anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
        for stmt in node.body:
            self.visit(stmt)
        self.builder.exit_loop_section(node)

        # Note: although the orelse is technically part of the loop node,
        # they don't count as loop bodies.  For example, a break in the loop's
        # orelse will affect the parent loop, not the current one.
        self._exit_lexical_scope(node)

        for stmt in node.orelse:
            self.visit(stmt)

        self.builder.exit_section(node)
        self.builder.end_statement(node)
Exemplo n.º 26
0
    def test_property_of_typed_value(self):
        """A property of a typed value resolves to the property object itself."""
        outer = self

        class TestClass:
            @property
            def a(self):
                return 1

        tc = TestClass()

        class Resolver(type_inference.Resolver):
            def res_name(self, ns, types_ns, name):
                outer.assertEqual(name, qual_names.QN('tc'))
                return {TestClass}, None

            def res_value(self, ns, value):
                outer.assertIs(value, TestClass.a)
                # A property cannot be evaluated on the class, so the value seen
                # here must not be the property's result.
                outer.assertNotEqual(value, 1)
                return {property}

        def test_fn():
            return tc.a

        node, _ = TestTranspiler(Resolver).transform(test_fn, None)
        attr_node = node.body[0].value
        owner_node = attr_node.value

        self.assertTypes(owner_node, TestClass)
        self.assertTypes(attr_node, property)
        self.assertFalse(anno.hasanno(owner_node, anno.Static.VALUE))
        self.assertEqual(
            anno.getanno(attr_node, anno.Static.VALUE), TestClass.a)
Exemplo n.º 27
0
    def test_parameter_class_members(self):
        """A method called on a function parameter gets no `live_val`."""
        def test_fn(opt):
            opt.minimize(0)

        fn_def = self._parse_and_analyze(test_fn, {}).body[0]
        call_expr = fn_def.body[0].value
        self.assertFalse(anno.hasanno(call_expr.func, 'live_val'))
Exemplo n.º 28
0
 def _node_sets_self_attribute(self, node):
   """Returns True if `node`'s qualified name is an attribute of `self`.

   Nodes without a QN annotation always yield False.
   """
   if not anno.hasanno(node, anno.Basic.QN):
     return False
   qn = anno.getanno(node, anno.Basic.QN)
   # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
   return bool(qn.has_attr and qn.parent.qn == ('self',))
Exemplo n.º 29
0
    def _track_symbol(self,
                      node,
                      composite_writes_alter_parent=False,
                      writes_create_symbol=False):
        """Records how the symbol named by `node` is used in the current scope.

        Depending on `node.ctx`:
          * Store: the symbol is marked written. It is also optionally marked
            created, and marked read inside an augmented assignment.
          * Load: the symbol is marked read.
          * Param: the symbol is marked written and registered as a parameter
            of the innermost enclosing entity.
        Afterwards, `node` is annotated with NodeAnno.IS_LOCAL. Inside a return
        statement, the symbol is also marked as returned.

        Args:
          node: an AST node with a `ctx` attribute. Nodes without a QN
            annotation are ignored.
          composite_writes_alter_parent: if True, a store to a composite name
            also marks its parent as written.
          writes_create_symbol: if True, stores also mark the symbol as created.

        Raises:
          ValueError: if `node.ctx` is not Store, Load or Param.
        """
        # A QN may be missing when we have an attribute (or subscript) on a function
        # call. Example: a().b
        if not anno.hasanno(node, anno.Basic.QN):
            return
        qn = anno.getanno(node, anno.Basic.QN)

        if isinstance(node.ctx, gast.Store):
            self.scope.mark_write(qn)
            if qn.is_composite and composite_writes_alter_parent:
                self.scope.mark_write(qn.parent)
            if writes_create_symbol:
                self.scope.mark_creation(qn, writes_create_symbol=True)
            if self._in_aug_assign:
                # `x += ...` both reads and writes x.
                self.scope.mark_read(qn)
        elif isinstance(node.ctx, gast.Load):
            self.scope.mark_read(qn)
        elif isinstance(node.ctx, gast.Param):
            # Param contexts appear in function defs, so they have the meaning of
            # defining a variable.
            self.scope.mark_write(qn)
            self.scope.mark_param(qn, self.enclosing_entities[-1])
        else:
            raise ValueError('Unknown context %s for node %s.' %
                             (type(node.ctx), qn))

        anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))

        if self._in_return_statement:
            self.scope.mark_returned(qn)
Exemplo n.º 30
0
 def visit_Name(self, node):
     """Wraps loads of names from the original code in `ag__.ld(...)`.

     Only names carrying the ORIG_DEFINITIONS annotation and used in a Load
     context are replaced; every other name is returned unchanged.
     """
     from_original_code = anno.hasanno(node, anno.Static.ORIG_DEFINITIONS)
     if from_original_code and isinstance(node.ctx, gast.Load):
         return templates.replace_as_expression('ag__.ld(var_)', var_=node)
     return node
Exemplo n.º 31
0
 def _visit_and_reindent(self, nodes):
     """Visits statements, redirecting later ones when a reindent is requested.

     Each statement is visited. The result is appended to the current
     destination list, or the destination is extended if the visit returned a
     list or tuple. The destination starts as the returned list. When a visited
     statement carries INDENT_BLOCK_REMAINDER, that annotation supplies a new
     destination list and an alias map, and the annotation is removed. Every
     following statement is then appended to the new destination, and its
     symbols are renamed through the accumulated alias map.

     Args:
       nodes: iterable of statement nodes to visit.

     Returns:
       The list of statements that were not redirected into a nested
       destination.

     Raises:
       ValueError: if a reindent was requested and the final destination list
         is empty after all statements are processed.
     """
     new_nodes = []
     current_dest = new_nodes
     alias_map = {}
     reindent_requested = False
     for n in nodes:
         n = self.visit(n)
         # NOTE: the order in which these statements execute is important; in
         # particular, watch out for ending up with cycles in the AST.
         if alias_map:
             n = ast_util.rename_symbols(n, alias_map)
         if isinstance(n, (list, tuple)):
             current_dest.extend(n)
         else:
             current_dest.append(n)
         if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
             reindent_requested = True
             new_dest, new_alias_map = anno.getanno(
                 n, anno.Basic.INDENT_BLOCK_REMAINDER)
             anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
             # Aliases accumulated so far take precedence over the new ones.
             # Note this mutates the map stored in the annotation.
             new_alias_map.update(alias_map)
             alias_map = new_alias_map
             current_dest = new_dest
     if reindent_requested and not current_dest:
         # TODO(mdan): There may still be something that could be done.
         raise ValueError(
             'Unable to insert statement into the computation flow: '
             'it is not followed by any computation which '
             'the statement could gate.')
     return new_nodes
Exemplo n.º 32
0
  def _should_compile(self, node, fqn):
    """Determines whether an entity should be compiled in the context.

    Args:
      node: the call node whose target is being considered; its `func` is
        resolved via `_try_resolve_target`.
      fqn: sequence of name parts identifying the call target.

    Returns:
      False when the target is excluded from compilation. This happens when:
        * the target lies in one of `uncompiled_modules`;
        * the call is annotated 'graph_ready';
        * the target is the compile decorator itself;
        * the target is in `strip_decorators`, or is decorated with one.
      True otherwise, including when the target cannot be resolved or its
      source cannot be parsed (TypeError).
    """
    # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
    module_name = fqn[0]
    for mod in self.ctx.program.uncompiled_modules:
      if module_name.startswith(mod[0] + '.'):
        return False

    for i in range(1, len(fqn)):
      if fqn[:i] in self.ctx.program.uncompiled_modules:
        return False

    # Check for local decorations
    if anno.hasanno(node, 'graph_ready'):
      return False

    # The decorators themselves are not to be converted.
    # If present, the decorators should appear as static functions.
    target_entity = self._try_resolve_target(node.func)

    if target_entity is not None:

      # This may be reached when "calling" a callable attribute of an object.
      # For example:
      #
      #   self.fc = tf.keras.layers.Dense()
      #   self.fc()
      #
      # __module__ may be absent or None; then there is nothing to match.
      entity_module = getattr(target_entity, '__module__', None)
      if entity_module:
        for mod in self.ctx.program.uncompiled_modules:
          # Match the uncompiled module itself as well as its submodules.
          # Entities defined directly in the module have __module__ == mod[0],
          # which a prefix check on mod[0] + '.' alone would miss.
          if (entity_module == mod[0] or
              entity_module.startswith(mod[0] + '.')):
            return False

      # This attribute is set by the decorator itself.
      # TODO(mdan): This may not play nicely with other wrapping decorators.
      if hasattr(target_entity, '__pyct_is_compile_decorator'):
        return False

      if target_entity in self.ctx.program.options.strip_decorators:
        return False

      # Inspect the target function decorators. If any include a @convert
      # or @graph_ready annotation, then they must be called as they are.
      # TODO(mdan): This may be quite heavy.
      # To parse and re-analyze each function for every call site could be quite
      # wasteful. Maybe we could cache the parsed AST?
      try:
        target_node, _ = parser.parse_entity(target_entity)
        target_node = target_node.body[0]
      except TypeError:
        # Functions whose source we cannot access are compilable (e.g. wrapped
        # to py_func).
        return True

      for dec in target_node.decorator_list:
        decorator_fn = self._resolve_name(dec)
        if (decorator_fn is not None and
            decorator_fn in self.ctx.program.options.strip_decorators):
          return False

    return True
Exemplo n.º 33
0
    def visit_FunctionDef(self, node):
        """Converts a function definition and adjusts its decorators.

        Pushes a _Function state entry and records the context name taken from
        the 'function_context_name' annotation. It then visits the arguments
        and the body. Functions at state level < 2 (top-level) have all
        decorators removed; nested functions get
        `ag__.do_not_convert_internal` appended to their decorators. The return
        annotation, if any, is visited last, and the state entry is popped.

        Returns:
          The modified FunctionDef node.
        """
        self.state[_Function].enter()
        # Note: if the conversion process ever creates helper functions, this
        # assumption will no longer hold.
        assert anno.hasanno(node, 'function_context_name'), (
            'The function_scopes converter always creates a scope for functions.'
        )
        self.state[_Function].context_name = anno.getanno(
            node, 'function_context_name')
        node.args = self.visit(node.args)
        node.body = self.visit_block(node.body)

        if self.state[_Function].level < 2:
            # Top-level functions lose their decorator because the conversion is
            # always just-in-time and by the time it happens the decorators are
            # already set to be applied.
            node.decorator_list = []
        else:
            # Inner functions are converted already, so we insert a decorator to
            # prevent double conversion. Double conversion would work too, but this
            # saves the overhead.
            node.decorator_list.append(
                parser.parse_expression('ag__.do_not_convert_internal'))

        if node.returns:
            node.returns = self.visit(node.returns)

        self.state[_Function].exit()
        return node
Exemplo n.º 34
0
  def visit_node(self, node):
    """Applies the backward liveness transfer function to one CFG node.

    live_out is the union of the successors' live-in sets. For nodes with a
    scope annotation, live_in = gen | (live_out - kill). Here gen is the
    symbols the node uses, plus any extra_gen entry for its AST node, and kill
    is the symbols it modifies. Nodes without a scope annotation are assumed
    not to touch any symbols, so liveness passes through them unchanged.

    Args:
      node: the CFG node to update.

    Returns:
      True if the node's live-in set changed since the previous visit.
    """
    prev_live_in = self.in_[node]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

      gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset())
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. if x.y is live whether x needs to be added. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = node_scope.modified

      live_out = set()
      for n in node.next:
        live_out |= self.in_[n]
      live_in = gen | (live_out - kill)

    else:
      # Nodes that don't have a scope annotation are assumed not to touch any
      # symbols.
      # This Name node below is a literal name, e.g. False
      assert isinstance(node.ast_node,
                        (gast.Name, gast.Continue, gast.Break)), type(
                            node.ast_node)
      # Liveness must still flow through such nodes from their successors.
      # Reusing prev_live_in would freeze this node at its initial value, and
      # predecessors would never see symbols live after it.
      live_out = set()
      for n in node.next:
        live_out |= self.in_[n]
      live_in = live_out

    self.in_[node] = live_in
    self.out[node] = live_out

    # TODO(mdan): Move this to the superclass?
    return prev_live_in != live_in
Exemplo n.º 35
0
 def _node_sets_self_attribute(self, node):
     """Tells whether `node` names an attribute of `self` (e.g. `self.x`).

     Returns False for nodes that have no QN annotation.
     """
     if not anno.hasanno(node, anno.Basic.QN):
         return False
     qual_name = anno.getanno(node, anno.Basic.QN)
     # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
     return bool(qual_name.has_attr and qual_name.parent.qn == ('self', ))
Exemplo n.º 36
0
  def _should_compile(self, node, fqn):
    """Determines whether an entity should be compiled in the context.

    Args:
      node: the call node whose target is being considered; its `func` is
        resolved via `_try_resolve_target`.
      fqn: sequence of name parts identifying the call target.

    Returns:
      False when the target is excluded from compilation. This happens when:
        * the target lies in one of `uncompiled_modules`;
        * the call is annotated 'graph_ready';
        * the target is the compile decorator itself;
        * the target is in `strip_decorators`, or is decorated with one.
      True otherwise, including when the target cannot be resolved or its
      source cannot be parsed (TypeError).
    """
    # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
    module_name = fqn[0]
    for mod in self.ctx.program.uncompiled_modules:
      if module_name.startswith(mod[0] + '.'):
        return False

    for i in range(1, len(fqn)):
      if fqn[:i] in self.ctx.program.uncompiled_modules:
        return False

    # Check for local decorations
    if anno.hasanno(node, 'graph_ready'):
      return False

    # The decorators themselves are not to be converted.
    # If present, the decorators should appear as static functions.
    target_entity = self._try_resolve_target(node.func)

    if target_entity is not None:

      # This may be reached when "calling" a callable attribute of an object.
      # For example:
      #
      #   self.fc = tf.keras.layers.Dense()
      #   self.fc()
      #
      # __module__ may be absent or None; then there is nothing to match.
      entity_module = getattr(target_entity, '__module__', None)
      if entity_module:
        for mod in self.ctx.program.uncompiled_modules:
          # Match the uncompiled module itself as well as its submodules.
          # Entities defined directly in the module have __module__ == mod[0],
          # which a prefix check on mod[0] + '.' alone would miss.
          if (entity_module == mod[0] or
              entity_module.startswith(mod[0] + '.')):
            return False

      # This attribute is set by the decorator itself.
      # TODO(mdan): This may not play nicely with other wrapping decorators.
      if hasattr(target_entity, '__pyct_is_compile_decorator'):
        return False

      if target_entity in self.ctx.program.options.strip_decorators:
        return False

      # Inspect the target function decorators. If any include a @convert
      # or @graph_ready annotation, then they must be called as they are.
      # TODO(mdan): This may be quite heavy.
      # To parse and re-analyze each function for every call site could be quite
      # wasteful. Maybe we could cache the parsed AST?
      try:
        target_node, _ = parser.parse_entity(target_entity)
        target_node = target_node.body[0]
      except TypeError:
        # Functions whose source we cannot access are compilable (e.g. wrapped
        # to py_func).
        return True

      for dec in target_node.decorator_list:
        decorator_fn = self._resolve_decorator_name(dec)
        if (decorator_fn is not None and
            decorator_fn in self.ctx.program.options.strip_decorators):
          return False

    return True
Exemplo n.º 37
0
    def visit_node(self, node):
        """Applies the backward liveness transfer function to one CFG node.

        live_out is the union of the successors' live-in sets. For nodes with a
        scope annotation, live_in = gen | (live_out - kill). Here gen is the
        symbols read, plus any extra_gen entry for the AST node, and kill is
        the symbols modified or deleted. Nodes without a scope annotation must
        satisfy `can_ignore` and pass liveness through unchanged.

        Returns:
          True if the node's live-in set changed since the previous visit.
        """
        previous_live_in = self.in_[node]
        ast_node = node.ast_node

        has_scope = anno.hasanno(ast_node, anno.Static.SCOPE)
        if has_scope:
            scope = anno.getanno(ast_node, anno.Static.SCOPE)
            gen = scope.read | self.extra_gen.get(ast_node, frozenset())
            # TODO(mdan): verify whether composites' parents need to be added.
            # E.g. whether x needs to be added if x.y is live. Theoretically the
            # activity analysis should have both so that wouldn't be needed.
            kill = scope.modified | scope.deleted
        else:
            assert self.can_ignore(node), (node.ast_node, node)

        live_out = set()
        for successor in node.next:
            live_out |= self.in_[successor]

        if has_scope:
            live_in = gen | (live_out - kill)
        else:
            live_in = live_out

        self.in_[node] = live_in
        self.out[node] = live_out

        # TODO(mdan): Move this to the superclass?
        return previous_live_in != live_in
Exemplo n.º 38
0
 def visit_Call(self, node):
   """Replaces a call to a supported builtin with its converted form.

   Children are visited first. The call is converted only when its callee
   has a 'live_val' annotation listed in py_builtins.SUPPORTED_BUILTINS.
   """
   node = self.generic_visit(node)
   callee = node.func
   if not anno.hasanno(callee, 'live_val'):
     return node
   builtin_fn = anno.getanno(callee, 'live_val')
   if builtin_fn not in py_builtins.SUPPORTED_BUILTINS:
     return node
   return self._convert_builtin(builtin_fn, node.args, as_expression=True)
Exemplo n.º 39
0
  def _track_symbol(self,
                    node,
                    composite_writes_alter_parent=False,
                    writes_create_symbol=False):
    """Records how the symbol named by `node` is used in the current scope.

    Depending on `node.ctx`:
      * Store: the symbol is marked written. It is also optionally marked
        created, and marked read inside an augmented assignment.
      * Load: the symbol is marked read.
      * Param: the symbol is marked written and registered as a parameter of
        the innermost enclosing entity.
    Afterwards, `node` is annotated with NodeAnno.IS_LOCAL. Inside a return
    statement, the symbol is also marked as returned.

    Args:
      node: an AST node with a `ctx` attribute. Nodes without a QN annotation
        are ignored.
      composite_writes_alter_parent: if True, a store to a composite name also
        marks its parent as written.
      writes_create_symbol: if True, stores also mark the symbol as created.

    Raises:
      ValueError: if `node.ctx` is not Store, Load or Param.
    """
    # A QN may be missing when we have an attribute (or subscript) on a function
    # call. Example: a().b
    if not anno.hasanno(node, anno.Basic.QN):
      return
    qn = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Store):
      self.scope.mark_write(qn)
      if qn.is_composite and composite_writes_alter_parent:
        self.scope.mark_write(qn.parent)
      if writes_create_symbol:
        self.scope.mark_creation(qn, writes_create_symbol=True)
      if self._in_aug_assign:
        # `x += ...` both reads and writes x.
        self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Param):
      # Param contexts appear in function defs, so they have the meaning of
      # defining a variable.
      self.scope.mark_write(qn)
      self.scope.mark_param(qn, self.enclosing_entities[-1])
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))

    if self._in_return_statement:
      self.scope.mark_returned(qn)
  def visit_For(self, node):
    """Converts a `for` statement into a functional loop construct.

    One of three forms is produced, depending on the loop:
      * with loop-carried state and an 'extra_test' annotation (early
        stopping, e.g. break or return);
      * with loop-carried state only;
      * without state.
    Assignments for symbols that may be undefined before the loop come from
    `_create_undefined_assigns` and are placed before the converted loop.

    Args:
      node: the gast.For node; its children are visited first.

    Returns:
      The undefined-symbol assignments concatenated with the converted loop.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    has_extra_test = anno.hasanno(node, 'extra_test')
    if loop_state:
      if has_extra_test:
        # Loop with early stopping (e.g. break or return)
        extra_test = anno.getanno(node, 'extra_test')
        extra_test = ast_util.rename_symbols(extra_test, ssf_map)
        extra_test_name = self.ctx.namer.new_symbol('extra_test',
                                                    reserved_symbols)
        node = self._create_for_loop_early_stopping(
            loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
            extra_test, body_name, node_body)
      else:
        # Loop with loop-carried state and no early stopping
        node = self._create_for_loop_with_state(
            loop_state, state_ssf, state_ast_tuple, node, body_name, node_body)
    else:
      # Loop with no loop-carried state and no early stopping
      assert not has_extra_test, ('Early stopping (e.g. break and/or return) '
                                  'should create state variables.')
      node = self._create_for_loop_without_state(node, body_name, node_body)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node
Exemplo n.º 41
0
  def test_local_scope_info_stack(self):
    """Values set in a local scope stay confined to that scope.

    TestTransformer concatenates the string constants found inside each
    for/while loop and stores the result in a 'test' annotation on that loop.
    The expectations check two things. First, strings from the nested while
    loop's scope do not leak into the enclosing for loop's value ('1' is not
    part of 'abc'). Second, the 'string' local is not left behind as an
    annotation on either loop.
    """

    class TestTransformer(transformer.Base):

      # Extract all string constants from the block.
      def visit_Str(self, node):
        self.set_local('string', self.get_local('string', default='') + node.s)
        return self.generic_visit(node)

      def _annotate_result(self, node):
        # Each loop gets a fresh local scope; its accumulated string is
        # recorded as the 'test' annotation before the scope is popped.
        self.enter_local_scope()
        node = self.generic_visit(node)
        anno.setanno(node, 'test', self.get_local('string'))
        self.exit_local_scope()
        return node

      def visit_While(self, node):
        return self._annotate_result(node)

      def visit_For(self, node):
        return self._annotate_result(node)

    tr = TestTransformer(self._simple_context())

    def test_function(a):
      """Docstring."""
      assert a == 'This should not be counted'
      for i in range(3):
        _ = 'a'
        if i > 2:
          return 'b'
        else:
          _ = 'c'
          while True:
            raise '1'
      return 'nor this'

    node, _ = parser.parse_entity(test_function, future_features=())
    node = tr.visit(node)

    for_node = node.body[2]
    while_node = for_node.body[1].orelse[1]

    self.assertFalse(anno.hasanno(for_node, 'string'))
    self.assertEqual('abc', anno.getanno(for_node, 'test'))
    self.assertFalse(anno.hasanno(while_node, 'string'))
    self.assertEqual('1', anno.getanno(while_node, 'test'))
Exemplo n.º 42
0
  def visit_For(self, node):
    """Rewrites a `for` statement as a call to `ag__.for_stmt`.

    The body is renamed to use the SSF versions of the loop-carried state. The
    'extra_test' annotation, or the literal `True` when absent, is renamed the
    same way. Both are then emitted as generated `extra_test` and `loop_body`
    functions. Without loop-carried state, the generated functions take no
    state argument and the call's result is not assigned.

    Args:
      node: the gast.For node; its children are visited first.

    Returns:
      The replacement statements produced by `templates.replace`.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)

    if not anno.hasanno(node, 'extra_test'):
      extra_test = parser.parse_expression('True')
    else:
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), ssf_map)

    if loop_state:
      template = """
        def extra_test_name(state_ssf):
          return extra_test_expr
        def body_name(loop_vars, state_ssf):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
      """
      replacements = {
          'state': loop_state,
          'state_ssf': state_ssf,
          'state_ast_tuple': state_ast_tuple,
      }
    else:
      template = """
        def extra_test_name():
          return extra_test_expr
        def body_name(loop_vars):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return ()
        ag__.for_stmt(iter_, extra_test_name, body_name, ())
      """
      replacements = {}

    # The helper names are allocated in a fixed order: the test function
    # first, then the body function.
    replacements.update(
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol('extra_test',
                                                  reserved_symbols),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        body=node_body)

    return templates.replace(template, **replacements)
Exemplo n.º 43
0
    def _track_symbol(self, node, composite_writes_alter_parent=False):
        """Records how the symbol named by `node` is used in the current scope.

        Arguments of an enclosing lambda are skipped, and so are targets of an
        enclosing comprehension (Python 3 only), together with their attributes
        and subscripts. Otherwise, depending on `node.ctx`:
          * Store: the symbol is marked modified, and also read inside an
            augmented assignment. Inside a Python 3 comprehension it is
            recorded as a comprehension target instead.
          * Load: the symbol is marked read.
          * Param: in function-def arguments, the symbol is marked modified and
            registered as a parameter of the innermost enclosing entity.
            Inside a lambda, it is recorded as a lambda argument.
          * Del: the symbol is marked read and deleted.

        Args:
          node: an AST node with a `ctx` attribute. Nodes without a QN
            annotation are ignored.
          composite_writes_alter_parent: if True, a store to a composite name
            also marks its parent as modified.

        Raises:
          NotImplementedError: for a Param context outside function-def
            arguments and lambdas.
          ValueError: for any other context type.
        """
        # A QN may be missing when we have an attribute (or subscript) on a function
        # call. Example: a().b
        if not anno.hasanno(node, anno.Basic.QN):
            return
        qn = anno.getanno(node, anno.Basic.QN)

        # When inside a lambda, ignore any of the lambda's arguments.
        # This includes attributes or slices of those arguments.
        for l in self.state[_Lambda]:
            if qn in l.args:
                return
            if qn.owner_set & set(l.args):
                return

        # When inside a comprehension, ignore any of the comprehensions's targets.
        # This includes attributes or slices of those arguments.
        # This is not true in Python2, which leaks symbols.
        if six.PY3:
            for l in self.state[_Comprehension]:
                if qn in l.targets:
                    return
                if qn.owner_set & set(l.targets):
                    return

        if isinstance(node.ctx, gast.Store):
            # In comprehensions, modified symbols are the comprehension targets.
            if six.PY3 and self.state[_Comprehension].level > 0:
                # Like a lambda's args, they are tracked separately in Python3.
                self.state[_Comprehension].targets.add(qn)
            else:
                self.scope.mark_modified(qn)
                if qn.is_composite and composite_writes_alter_parent:
                    self.scope.mark_modified(qn.parent)
                if self._in_aug_assign:
                    # `x += ...` both reads and writes x.
                    self.scope.mark_read(qn)
        elif isinstance(node.ctx, gast.Load):
            self.scope.mark_read(qn)
        elif isinstance(node.ctx, gast.Param):
            if self._in_function_def_args:
                # In function defs have the meaning of defining a variable.
                self.scope.mark_modified(qn)
                self.scope.mark_param(qn, self.enclosing_entities[-1])
            elif self.state[_Lambda].level:
                # In lambdas, they are tracked separately.
                self.state[_Lambda].args.add(qn)
            else:
                # TODO(mdan): Is this case possible at all?
                raise NotImplementedError(
                    'Param "{}" outside a function arguments or lambda.'.
                    format(qn))
        elif isinstance(node.ctx, gast.Del):
            # The read matches the Python semantics - attempting to delete an
            # undefined symbol is illegal.
            self.scope.mark_read(qn)
            self.scope.mark_deleted(qn)
        else:
            raise ValueError('Unknown context {} for node "{}".'.format(
                type(node.ctx), qn))
Exemplo n.º 44
0
    def visit_For(self, node):
        """Rewrites a `for` statement as a call to `ag__.for_stmt`.

        The body is renamed to use the SSF versions of the loop-carried state.
        The 'extra_test' annotation, or the literal `True` when absent, is
        renamed the same way. Both are then emitted as generated `extra_test`
        and `loop_body` functions. Without loop-carried state, the generated
        functions take no state argument and the call's result is not
        assigned.

        Args:
          node: the gast.For node; its children are visited first.

        Returns:
          The replacement statements produced by `templates.replace`.
        """
        self.generic_visit(node)

        loop_state, reserved_symbols = self._get_loop_state(node)
        loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
            loop_state, reserved_symbols)
        node_body = ast_util.rename_symbols(node.body, ssf_map)

        if not anno.hasanno(node, 'extra_test'):
            extra_test = parser.parse_expression('True')
        else:
            extra_test = ast_util.rename_symbols(
                anno.getanno(node, 'extra_test'), ssf_map)

        if loop_state:
            template = """
        def extra_test_name(state_ssf):
          return extra_test_expr
        def body_name(loop_vars, state_ssf):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
      """
            replacements = {
                'state': loop_state,
                'state_ssf': state_ssf,
                'state_ast_tuple': state_ast_tuple,
            }
        else:
            template = """
        def extra_test_name():
          return extra_test_expr
        def body_name(loop_vars):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return ()
        ag__.for_stmt(iter_, extra_test_name, body_name, ())
      """
            replacements = {}

        # The helper names are allocated in a fixed order: the test function
        # first, then the body function.
        replacements.update(
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=self.ctx.namer.new_symbol(
                'extra_test', reserved_symbols),
            extra_test_expr=extra_test,
            body_name=self.ctx.namer.new_symbol('loop_body',
                                                reserved_symbols),
            body=node_body)

        return templates.replace(template, **replacements)
Exemplo n.º 45
0
  def test_parameter_class_members(self):
    """A method called on a function parameter gets no `live_val`."""

    def test_fn(opt):
      opt.minimize(0)

    fn_def = self._parse_and_analyze(test_fn, {}).body[0]
    call_expr = fn_def.body[0].value
    self.assertFalse(anno.hasanno(call_expr.func, 'live_val'))
Exemplo n.º 46
0
 def visit_Expr(self, node):
   """Handles expression statements, special-casing uncompilable calls.

   Suppose the statement is a call whose target has a 'live_val' annotation,
   is not compilable, and also has an 'fqn' annotation. Then the statement is
   returned untouched when `_should_compile` rejects the call, and otherwise
   it is replaced by `_wrap_to_py_func_no_return(call)`. Any other call is
   handed to the regular visitor and the result is stored back into the
   statement. Non-call expressions are visited generically.

   Args:
     node: the gast.Expr statement node.

   Returns:
     The (possibly replaced) statement node.
   """
   if isinstance(node.value, gast.Call):
     if anno.hasanno(node.value.func, 'live_val'):
       target_entity = anno.getanno(node.value.func, 'live_val')
       if not self._function_is_compilable(target_entity):
         if anno.hasanno(node.value.func, 'fqn'):
           target_fqn = anno.getanno(node.value.func, 'fqn')
           if not self._should_compile(node.value, target_fqn):
             return node
           node = self._wrap_to_py_func_no_return(node.value)
           return node
     # Only the case of py_func with no return value is special.
     # Everything else is processed by visit_Call. Its result must be stored
     # back: the visitor may return a replacement call instead of mutating it
     # in place, and discarding that result would silently drop the rewrite.
     node.value = self.visit(node.value)
   else:
     self.generic_visit(node)
   return node
Exemplo n.º 47
0
  def _track_symbol(self, node, composite_writes_alter_parent=False):
    """Records how the symbol named by `node` is used in the current scope.

    Arguments of an enclosing lambda are skipped, and so are targets of an
    enclosing comprehension (Python 3 only), together with their attributes
    and subscripts. Otherwise, depending on `node.ctx`:
      * Store: the symbol is marked modified, and also read inside an
        augmented assignment. Inside a Python 3 comprehension it is recorded
        as a comprehension target instead.
      * Load: the symbol is marked read.
      * Param: in function-def arguments, the symbol is marked modified and
        registered as a parameter of the innermost enclosing entity. Inside a
        lambda, it is recorded as a lambda argument.
      * Del: the symbol is marked read and deleted.

    Args:
      node: an AST node with a `ctx` attribute. Nodes without a QN annotation
        are ignored.
      composite_writes_alter_parent: if True, a store to a composite name also
        marks its parent as modified.

    Raises:
      NotImplementedError: for a Param context outside function-def arguments
        and lambdas.
      ValueError: for any other context type.
    """
    # A QN may be missing when we have an attribute (or subscript) on a function
    # call. Example: a().b
    if not anno.hasanno(node, anno.Basic.QN):
      return
    qn = anno.getanno(node, anno.Basic.QN)

    # When inside a lambda, ignore any of the lambda's arguments.
    # This includes attributes or slices of those arguments.
    for l in self.state[_Lambda]:
      if qn in l.args:
        return
      if qn.owner_set & set(l.args):
        return

    # When inside a comprehension, ignore any of the comprehensions's targets.
    # This includes attributes or slices of those arguments.
    # This is not true in Python2, which leaks symbols.
    if six.PY3:
      for l in self.state[_Comprehension]:
        if qn in l.targets:
          return
        if qn.owner_set & set(l.targets):
          return

    if isinstance(node.ctx, gast.Store):
      # In comprehensions, modified symbols are the comprehension targets.
      if six.PY3 and self.state[_Comprehension].level > 0:
        # Like a lambda's args, they are tracked separately in Python3.
        self.state[_Comprehension].targets.add(qn)
      else:
        self.scope.mark_modified(qn)
        if qn.is_composite and composite_writes_alter_parent:
          self.scope.mark_modified(qn.parent)
        if self._in_aug_assign:
          # `x += ...` both reads and writes x.
          self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Param):
      if self._in_function_def_args:
        # In function defs have the meaning of defining a variable.
        self.scope.mark_modified(qn)
        self.scope.mark_param(qn, self.enclosing_entities[-1])
      elif self.state[_Lambda].level:
        # In lambdas, they are tracked separately.
        self.state[_Lambda].args.add(qn)
      else:
        # TODO(mdan): Is this case possible at all?
        raise NotImplementedError(
            'Param "{}" outside a function arguments or lambda.'.format(qn))
    elif isinstance(node.ctx, gast.Del):
      # The read matches the Python semantics - attempting to delete an
      # undefined symbol is illegal.
      self.scope.mark_read(qn)
      self.scope.mark_deleted(qn)
    else:
      raise ValueError('Unknown context {} for node "{}".'.format(
          type(node.ctx), qn))
Exemplo n.º 48
0
  def visit_For(self, node):
    """Rewrites a `for` statement as a call to `ag__.for_stmt`.

    The loop-carried state is the set of symbols the body modifies but does
    not create. Each state symbol gets a fresh SSF name, and the body and the
    'extra_test' annotation (the literal `True` when absent) are renamed to
    use those names. They are then emitted as generated `extra_test` and
    `loop_body` functions. A single state symbol is passed as-is rather than
    as a tuple.

    Args:
      node: the gast.For node; its children are visited first.

    Returns:
      The replacement statements produced by `templates.replace`.
    """
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    all_referenced = body_scope.referenced
    state = list(body_scope.modified - body_scope.created)

    # Allocate one SSF name per state symbol, in order. Only symbols whose
    # name actually changes are added to the rename map.
    state_ssf = []
    ssf_map = {}
    for symbol in state:
      ssf_name = self.ctx.namer.new_symbol(symbol.ssf(), all_referenced)
      state_ssf.append(ssf_name)
      if str(symbol) != ssf_name:
        ssf_map[symbol] = ssf_name

    if len(state) != 1:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
    else:
      state, state_ssf = state[0], state_ssf[0]
      state_ast_tuple = state

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    if not anno.hasanno(node, 'extra_test'):
      extra_test = parser.parse_expression('True')
    else:
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), ssf_map)

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    # The helper names are allocated in a fixed order: the test function
    # first, then the body function.
    test_fn_name = self.ctx.namer.new_symbol('extra_test', all_referenced)
    body_fn_name = self.ctx.namer.new_symbol('loop_body', all_referenced)
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=test_fn_name,
        extra_test_expr=extra_test,
        body_name=body_fn_name,
        body=node_body)
Exemplo n.º 49
0
  def test_constructor_detection_builtin_class(self):
    """A call to the builtin `zip` is not annotated as a constructor."""

    def test_fn(x):
      res = zip(x)
      return res

    fn_def = self._parse_and_analyze(test_fn, {}).body[0]
    assign_stmt = fn_def.body[0]
    self.assertFalse(anno.hasanno(assign_stmt.value, 'is_constructor'))
Exemplo n.º 50
0
  def test_nested_members(self):
    """A nested member call on a local object gets no `live_val`."""

    def test_fn():
      foo = training.GradientDescentOptimizer(0.1)
      foo.bar.baz()

    namespace = {'training': training}
    fn_def = self._parse_and_analyze(test_fn, namespace).body[0]
    call_expr = fn_def.body[1].value
    self.assertFalse(anno.hasanno(call_expr.func, 'live_val'))
 def _expect_simple_symbol(self, operand):
   """Rejects operands that are not plain names or known-safe expressions.

   An operand is accepted if it is a gast.Name or carries the
   SAFE_BOOLEAN_OPERAND annotation.

   Raises:
     NotImplementedError: for any other operand.
   """
   is_plain_name = isinstance(operand, gast.Name)
   if is_plain_name or anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
     return
   raise NotImplementedError(
       'only simple local variables are supported in logical and compound '
       'comparison expressions; for example, we support "a or b" but not '
       '"a.x or b"; for a workaround, assign the expression to a local '
       'variable and use that instead, for example "tmp = a.x", "tmp or b"')
Exemplo n.º 52
0
 def test_copy_clean_preserves_annotations(self):
   """copy_clean keeps only the annotations named in preserve_annos."""
   source = textwrap.dedent("""
     def f(a):
       return a + 1
   """)
   node = parser.parse_str(source)
   fn_def = node.body[0]
   anno.setanno(fn_def, 'foo', 'bar')
   anno.setanno(fn_def, 'baz', 1)

   new_node = ast_util.copy_clean(node, preserve_annos={'foo'})
   copied_fn_def = new_node.body[0]
   # 'foo' was requested and survives; 'baz' was not and is dropped.
   self.assertEqual(anno.getanno(copied_fn_def, 'foo'), 'bar')
   self.assertFalse(anno.hasanno(copied_fn_def, 'baz'))
Exemplo n.º 53
0
  def test_nested_unpacking(self):
    """Types flow through nested tuple unpacking, without live values."""

    class Foo(object):
      pass

    class Bar(object):
      pass

    def test_fn():
      a, (b, c) = (Foo(), (Bar(), Foo()))
      return a, b, c

    node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
    # Elements of the returned tuple `a, b, c` in test_fn's second statement.
    a, b, c = node.body[0].body[1].value.elts
    # Use assertEqual: the assertEquals alias was removed in Python 3.12.
    self.assertEqual(anno.getanno(a, 'type'), Foo)
    self.assertEqual(anno.getanno(b, 'type'), Bar)
    self.assertEqual(anno.getanno(c, 'type'), Foo)
    self.assertFalse(anno.hasanno(a, 'live_val'))
    self.assertFalse(anno.hasanno(b, 'live_val'))
    self.assertFalse(anno.hasanno(c, 'live_val'))
Exemplo n.º 54
0
 def _block_statement_live_in(self, node, entry_node):
   """Copies the live-in variables of `entry_node` onto `node`.

   When `entry_node` is indexed in the current analyzer's CFG, the live-in set
   is taken (as a frozenset) from the analyzer's `in_` map. Otherwise
   `entry_node` must already carry a LIVE_VARS_IN annotation, which is reused.

   Returns:
     `node`, annotated with anno.Static.LIVE_VARS_IN.
   """
   cfg_index = self.current_analyzer.graph.index
   if entry_node not in cfg_index:
     assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
         'If not matching a CFG node, must be a block statement:'
         ' {}'.format(entry_node))
     live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
   else:
     cfg_node = cfg_index[entry_node]
     live_in = frozenset(self.current_analyzer.in_[cfg_node])
   anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
   return node
Exemplo n.º 55
0
 def visit_Call(self, node):
   """Converts calls to supported Python builtins after visiting children.

   A call is converted only if its callee carries a 'live_val' annotation and
   that value is a member of py_builtins.SUPPORTED_BUILTINS. Returns the
   (possibly replaced) node.
   """
   node = self.generic_visit(node)
   if not anno.hasanno(node.func, 'live_val'):
     return node
   live_val = anno.getanno(node.func, 'live_val')
   # Keep the try narrow: it exists only to guard the membership test. Errors
   # raised by _convert_builtin itself must not be silently swallowed.
   try:
     is_supported = live_val in py_builtins.SUPPORTED_BUILTINS
   except TypeError:
     # Not everything in Python is hashable. If it isn't then it's definitely
     # not a supported built-in.
     return node
   if is_supported:
     node = self._convert_builtin(live_val, node.args, as_expression=True)
   return node
Exemplo n.º 56
0
  def test_constructor_data_dependent(self):
    """No live_val when the receiver is assigned on both branches of an `if`."""

    def test_fn(x):
      if x > 0:
        opt = training.GradientDescentOptimizer(0.1)
      else:
        opt = training.GradientDescentOptimizer(0.01)
      opt.minimize(0)

    parsed = self._parse_and_analyze(test_fn, {'training': training})
    # Second statement of test_fn is `opt.minimize(0)`; inspect its callee.
    minimize_callee = parsed.body[0].body[1].value.func
    self.assertFalse(anno.hasanno(minimize_callee, 'live_val'))
Exemplo n.º 57
0
  def test_function_variables(self):
    """Calling a function through a local alias (`foo = bar`) gives no live_val."""

    def bar():
      pass

    def test_fn():
      foo = bar
      foo()

    analyzed = self._parse_and_analyze(test_fn, {'bar': bar})
    # Second statement of test_fn is the call `foo()`; inspect its callee.
    alias_callee = analyzed.body[0].body[1].value.func
    self.assertFalse(anno.hasanno(alias_callee, 'live_val'))
Exemplo n.º 58
0
  def test_basic(self):
    """Exercises keys/hasanno/getanno/setanno/delanno on a fresh ast.Name."""
    node = ast.Name()

    # A freshly constructed node has no annotations.
    self.assertEqual(anno.keys(node), set())
    self.assertFalse(anno.hasanno(node, 'foo'))
    # getanno without a default raises AttributeError for a missing key.
    with self.assertRaises(AttributeError):
      anno.getanno(node, 'foo')

    anno.setanno(node, 'foo', 3)

    # After setanno the key is present and readable; a missing key with an
    # explicit default returns that default.
    self.assertEqual(anno.keys(node), {'foo'})
    self.assertTrue(anno.hasanno(node, 'foo'))
    self.assertEqual(anno.getanno(node, 'foo'), 3)
    self.assertEqual(anno.getanno(node, 'bar', default=7), 7)

    anno.delanno(node, 'foo')

    # After delanno the node is back to having no annotations.
    self.assertEqual(anno.keys(node), set())
    self.assertFalse(anno.hasanno(node, 'foo'))
    with self.assertRaises(AttributeError):
      anno.getanno(node, 'foo')
    # An explicit default=None is returned rather than raising.
    self.assertIsNone(anno.getanno(node, 'foo', default=None))