def visit_For(self, node):
    """Converts a `for` statement into its functional loop form.

    The loop is visited recursively first. Its loop-carried state is then
    collected, and the body symbols are renamed through the SSF map returned
    by `_state_constructs`. One of three builders produces the replacement:

      * stateful loop with an `extra_test` annotation (early stopping, e.g.
        break or return) -> `_create_for_loop_early_stopping`;
      * stateful loop without `extra_test` -> `_create_for_loop_with_state`;
      * stateless loop -> `_create_for_loop_without_state`.

    Args:
      node: the `for` AST node being visited.

    Returns:
      `undefined_assigns + node`: the statements built by
      `_create_undefined_assigns` for the symbols `_get_loop_state` reported
      as possibly undefined, followed by the converted loop. Both are
      presumably lists of statements; confirm against the builders.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    has_extra_test = anno.hasanno(node, 'extra_test')
    if loop_state:
        if has_extra_test:
            # Loop with early stopping (e.g. break or return). The extra test
            # is renamed with the same SSF map as the body so that both refer
            # to the same renamed symbols.
            extra_test = anno.getanno(node, 'extra_test')
            extra_test = ast_util.rename_symbols(extra_test, ssf_map)
            extra_test_name = self.ctx.namer.new_symbol('extra_test',
                                                        reserved_symbols)
            node = self._create_for_loop_early_stopping(
                loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
                extra_test, body_name, node_body)
        else:
            # Loop with loop-carried state and no early stopping.
            node = self._create_for_loop_with_state(
                loop_state, state_ssf, state_ast_tuple, node, body_name,
                node_body)
    else:
        # Loop with no loop-carried state and no early stopping. This is an
        # internal invariant: early stopping is expected to introduce state.
        assert not has_extra_test, ('Early stopping (e.g. break and/or return) '
                                    'should create state variables.')
        node = self._create_for_loop_without_state(node, body_name, node_body)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node
Beispiel #2
0
  def visit_For(self, node):
    """Converts a `for` statement into its functional loop form.

    The loop is visited recursively first. Its loop-carried state is then
    collected, and the body symbols are renamed through the SSF map returned
    by `_state_constructs`. One of three builders produces the replacement
    nodes:

      * stateful loop with an `extra_test` annotation (early stopping, e.g.
        break or return) -> `_for_loop_with_extra_test`;
      * stateful loop without `extra_test` -> `_for_loop_with_state`;
      * stateless loop -> `_for_loop_without_state`.

    Args:
      node: the `for` AST node being visited.

    Returns:
      `undefined_assigns + loop_nodes`: the statements built by
      `_create_undefined_assigns` for the symbols `_get_loop_state` reported
      as possibly undefined, followed by the converted loop. Both are
      presumably lists of statements; confirm against the builders.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    has_extra_test = anno.hasanno(node, 'extra_test')
    if loop_state:
      if has_extra_test:
        # Loop with early stopping (e.g. break or return). The extra test is
        # renamed with the same SSF map as the body so both agree on names.
        extra_test = anno.getanno(node, 'extra_test')
        extra_test = ast_util.rename_symbols(extra_test, ssf_map)
        extra_test_name = self.ctx.namer.new_symbol('extra_test',
                                                    reserved_symbols)
        loop_nodes = self._for_loop_with_extra_test(
            loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
            extra_test, body_name, node_body)
      else:
        # Loop with loop-carried state and no early stopping.
        loop_nodes = self._for_loop_with_state(
            loop_state, state_ssf, state_ast_tuple, node, body_name, node_body)
    else:
      # Loop with no loop-carried state and no early stopping. This is an
      # internal invariant: early stopping is expected to introduce state.
      assert not has_extra_test, ('Early stopping (e.g. break and/or return) '
                                  'should create state variables.')
      loop_nodes = self._for_loop_without_state(node, body_name, node_body)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + loop_nodes
Beispiel #3
0
    def visit_For(self, node):
        """Lowers a `for` statement into an `ag__.for_stmt` call.

        The body and the optional `extra_test` annotation are renamed with the
        SSF map from `_state_constructs`. When the node has no `extra_test`
        annotation, the constant expression `True` is used instead. If there
        is loop-carried state, it is threaded through the generated functions.
        Otherwise an empty tuple is passed to `for_stmt`.

        Args:
          node: the `for` AST node being visited.

        Returns:
          The nodes produced by `templates.replace` for the chosen template.
        """
        self.generic_visit(node)

        loop_state, reserved_symbols = self._get_loop_state(node)
        loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
            loop_state, reserved_symbols)
        renamed_body = ast_util.rename_symbols(node.body, ssf_map)

        if not anno.hasanno(node, 'extra_test'):
            test_expr = parser.parse_expression('True')
        else:
            test_expr = ast_util.rename_symbols(
                anno.getanno(node, 'extra_test'), ssf_map)

        if loop_state:
            template = """
        def extra_test_name(state_ssf):
          return extra_test_expr
        def body_name(loop_vars, state_ssf):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
      """
            replacements = {
                'state': loop_state,
                'state_ssf': state_ssf,
                'state_ast_tuple': state_ast_tuple,
            }
        else:
            template = """
        def extra_test_name():
          return extra_test_expr
        def body_name(loop_vars):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return ()
        ag__.for_stmt(iter_, extra_test_name, body_name, ())
      """
            replacements = {}

        # Both variants request the extra-test name before the body name.
        replacements['iter_'] = node.iter
        replacements['iterate'] = node.target
        replacements['extra_test_name'] = self.ctx.namer.new_symbol(
            'extra_test', reserved_symbols)
        replacements['extra_test_expr'] = test_expr
        replacements['body_name'] = self.ctx.namer.new_symbol(
            'loop_body', reserved_symbols)
        replacements['body'] = renamed_body
        return templates.replace(template, **replacements)
  def visit_For(self, node):
    """Replaces a `for` loop with an equivalent `ag__.for_stmt` call.

    Symbols in the body and in the optional `extra_test` annotation are
    renamed using the SSF map from `_state_constructs`. A missing annotation
    becomes the constant expression `True`.

    Args:
      node: the `for` AST node being visited.

    Returns:
      The nodes produced by `templates.replace`.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    new_body = ast_util.rename_symbols(node.body, ssf_map)
    extra_test = (
        ast_util.rename_symbols(anno.getanno(node, 'extra_test'), ssf_map)
        if anno.hasanno(node, 'extra_test') else
        parser.parse_expression('True'))

    if not loop_state:
      # Stateless loop: the generated functions take no state argument.
      template = """
        def extra_test_name():
          return extra_test_expr
        def body_name(loop_vars):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return ()
        ag__.for_stmt(iter_, extra_test_name, body_name, ())
      """
      kwargs = {}
    else:
      template = """
        def extra_test_name(state_ssf):
          return extra_test_expr
        def body_name(loop_vars, state_ssf):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
      """
      kwargs = dict(
          state=loop_state,
          state_ssf=state_ssf,
          state_ast_tuple=state_ast_tuple)

    kwargs.update(
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol('extra_test',
                                                  reserved_symbols),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        body=new_body)
    return templates.replace(template, **kwargs)
Beispiel #5
0
  def visit_For(self, node):
    """Rewrites a `for` loop as an `ag__.for_stmt` call with explicit state.

    Every symbol that is modified but not created in the loop body is treated
    as loop-carried state. Each such symbol gets a fresh SSF name, and the
    body and the optional `extra_test` annotation are renamed accordingly. A
    single state symbol is passed as-is. Several are packed into a
    `gast.Tuple`.

    Args:
      node: the `for` AST node being visited.

    Returns:
      The nodes produced by `templates.replace`.
    """
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    carried = list(scope.modified - scope.created)
    referenced = scope.referenced

    fresh_names = [
        self.ctx.namer.new_symbol(sym.ssf(), referenced) for sym in carried
    ]
    # Only record renames that actually change the name.
    ssf_map = {}
    for sym, fresh in zip(carried, fresh_names):
      if str(sym) != fresh:
        ssf_map[sym] = fresh

    if len(carried) == 1:
      state = carried[0]
      state_ssf = fresh_names[0]
      state_ast_tuple = state
    else:
      state = carried
      state_ssf = fresh_names
      state_ast_tuple = gast.Tuple([sym.ast() for sym in carried], None)

    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_test'):
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), ssf_map)
    else:
      extra_test = parser.parse_expression('True')

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    test_fn_name = self.ctx.namer.new_symbol('extra_test', referenced)
    body_fn_name = self.ctx.namer.new_symbol('loop_body', referenced)
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=test_fn_name,
        extra_test_expr=extra_test,
        body_name=body_fn_name,
        body=renamed_body)
Beispiel #6
0
    def visit_For(self, node):
        """Rewrites a `for` loop as an `ag__.for_stmt` call with explicit state.

        Loop-carried state is the set of body symbols that are modified but
        not created inside the body. Each is given a fresh SSF name, and the
        body and the optional `extra_test` annotation are renamed to match.

        Args:
          node: the `for` AST node being visited.

        Returns:
          The nodes produced by `templates.replace`.
        """
        self.generic_visit(node)

        self._validate_no_live_vars_created(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        referenced_names = body_scope.referenced
        loop_carried = list(body_scope.modified - body_scope.created)

        fresh = [
            self.ctx.namer.new_symbol(sym.ssf(), referenced_names)
            for sym in loop_carried
        ]
        ssf_map = dict((sym, new_name)
                       for sym, new_name in zip(loop_carried, fresh)
                       if str(sym) != new_name)

        if len(loop_carried) == 1:
            (state,) = loop_carried
            (state_ssf,) = fresh
            state_ast_tuple = state
        else:
            state, state_ssf = loop_carried, fresh
            state_ast_tuple = gast.Tuple([sym.ast() for sym in state], None)

        node_body = ast_util.rename_symbols(node.body, ssf_map)
        extra_test = (
            ast_util.rename_symbols(anno.getanno(node, 'extra_test'), ssf_map)
            if anno.hasanno(node, 'extra_test')
            else parser.parse_expression('True'))

        template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
        return templates.replace(
            template,
            state=state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=self.ctx.namer.new_symbol('extra_test',
                                                      referenced_names),
            extra_test_expr=extra_test,
            body_name=self.ctx.namer.new_symbol('loop_body',
                                                referenced_names),
            body=node_body)
Beispiel #7
0
  def visit_While(self, node):
    """Rewrites a `while` loop as an `ag__.while_stmt` call.

    The support sets of the symbols used by the loop condition, minus the
    loop state, are passed to `while_stmt` as extra dependencies.

    Args:
      node: the `while` AST node being visited.

    Returns:
      The nodes produced by `templates.replace`.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)

    # Note: one might expect we can dispatch based on the loop condition.
    # But because that is dependent on the state, it cannot be evaluated ahead
    # of time - doing that would risk duplicating any effects the condition has.
    # Furthermore, we cannot evaluate slices and attributes, because they might
    # trigger __getitem__ or __getattribute__.
    #
    # A case where this fails includes ops with side effects on a stateful
    # resource captured in an object:
    #
    #   while self.v.read() > 0:
    #     self.v.assign(1)
    #
    # TODO(mdan): Handle the case above.
    cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
    cond_closure = set().union(*(sym.support_set for sym in cond_scope.used))
    cond_closure -= loop_state

    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    renamed_test = ast_util.rename_symbols(node.test, ssf_map)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_stmt(
          test_name, body_name, (state,), (extra_deps,))
    """
    test_fn = self.ctx.namer.new_symbol('loop_test', reserved_symbols)
    body_fn = self.ctx.namer.new_symbol('loop_body', reserved_symbols)
    extra_deps = tuple(sym.ast() for sym in cond_closure)
    return templates.replace(
        template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=test_fn,
        test=renamed_test,
        body_name=body_fn,
        body=renamed_body,
        extra_deps=extra_deps,
    )
    def visit_While(self, node):
        """Rewrites a `while` loop as an `ag__.while_stmt` call.

        The support sets of the symbols read by the loop condition, minus the
        loop state, are passed to `while_stmt` as extra dependencies.

        Args:
          node: the `while` AST node being visited.

        Returns:
          The nodes produced by `templates.replace`.
        """
        self.generic_visit(node)

        loop_state, reserved_symbols = self._get_loop_state(node)

        # Note: one might expect we can dispatch based on the loop condition.
        # But because that is dependent on the state, it cannot be evaluated ahead
        # of time - doing that would risk duplicating any effects the condition has.
        # Furthermore, we cannot evaluate slices and attributes, because they might
        # trigger __getitem__ or __getattribute__.
        #
        # A case where this fails includes ops with side effects on a stateful
        # resource captured in an object:
        #
        #   while self.v.read() > 0:
        #     self.v.assign(1)
        #
        # TODO(mdan): Handle the case above.
        cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
        cond_closure = {
            dep for sym in cond_scope.read for dep in sym.support_set
        }
        cond_closure -= loop_state

        loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
            loop_state, reserved_symbols)
        body_nodes = ast_util.rename_symbols(node.body, ssf_map)
        test_node = ast_util.rename_symbols(node.test, ssf_map)

        template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_stmt(
          test_name, body_name, (state,), (extra_deps,))
    """
        names = self.ctx.namer
        return templates.replace(
            template,
            state=loop_state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            test_name=names.new_symbol('loop_test', reserved_symbols),
            test=test_node,
            body_name=names.new_symbol('loop_body', reserved_symbols),
            body=body_nodes,
            extra_deps=tuple(dep.ast() for dep in cond_closure),
        )
    def visit_While(self, node):
        """Converts a `while` loop into an `ag__.while_stmt` call.

        Loop state is derived by `_get_loop_state` from the symbols modified
        in the loop body. Loops without state use generated functions that
        take no arguments and pass an empty tuple to `while_stmt`.

        Args:
          node: the `while` AST node being visited.

        Returns:
          The assignments from `_create_undefined_assigns` for possibly
          undefined symbols, concatenated with the converted loop.
        """
        self.generic_visit(node)

        body_modified = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
        loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(
            node, body_modified)
        loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
            loop_state, reserved_symbols)
        renamed_body = ast_util.rename_symbols(node.body, ssf_map)
        renamed_test = ast_util.rename_symbols(node.test, ssf_map)

        if loop_state:
            template = """
        def test_name(state_ssf):
          return test
        def body_name(state_ssf):
          body
          return state_ssf,
        state_ast_tuple = ag__.while_stmt(test_name, body_name, (state,))
      """
            replacements = dict(state=loop_state,
                                state_ssf=state_ssf,
                                state_ast_tuple=state_ast_tuple)
        else:
            template = """
        def test_name():
          return test
        def body_name():
          body
          return ()
        ag__.while_stmt(test_name, body_name, ())
      """
            replacements = {}

        replacements.update(
            test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
            test=renamed_test,
            body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
            body=renamed_body)
        converted = templates.replace(template, **replacements)

        undefined_assigns = self._create_undefined_assigns(possibly_undefs)
        return undefined_assigns + converted
 def _for_loop_with_extra_test(self, loop_state, state_ssf, state_ast_tuple,
                               original_node, extra_test_name, extra_test,
                               body_name, loop_body, ssf_map):
   """Builds the `ag__.for_stmt` form of a stateful loop with an extra test.

   The loop target is renamed with `ssf_map` before it is substituted into
   the template. All other pieces are passed through unchanged.

   Returns:
     The nodes produced by `templates.replace`.
   """
   template = """
   def extra_test_name(state_ssf):
     return extra_test_expr
   def body_name(loop_vars, state_ssf):
     # Workaround for PEP-3113
     target = loop_vars
     body
     return state_ssf,
   state_ast_tuple = ag__.for_stmt(
       iter_, extra_test_name, body_name, (state,))
 """
   renamed_target = ast_util.rename_symbols(original_node.target, ssf_map)
   substitutions = {
       'state': loop_state,
       'state_ssf': state_ssf,
       'state_ast_tuple': state_ast_tuple,
       'iter_': original_node.iter,
       'target': renamed_target,
       'extra_test_name': extra_test_name,
       'extra_test_expr': extra_test,
       'body_name': body_name,
       'body': loop_body,
   }
   return templates.replace(template, **substitutions)
Beispiel #11
0
    def test_rename_symbols_function(self):
        """Renaming `f` to `f1` rewrites the function definition's name."""
        tree = parser.parse('def f():\n  pass')
        renames = {qual_names.QN('f'): qual_names.QN('f1')}
        tree = ast_util.rename_symbols(tree, renames)

        unparsed = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(unparsed.strip(), 'def f1():\n    pass')
 def _visit_and_reindent(self, nodes):
     """Visits a statement list, redirecting later statements when requested.

     Each statement is visited with `self.visit`. The result is appended (or
     extended, if the visitor returned a list/tuple) to the current
     destination list, which starts as the returned `new_nodes`. If a result
     carries the `INDENT_BLOCK_REMAINDER` annotation, the annotation is
     removed. Its `(new_dest, new_alias_map)` payload then becomes the
     destination for all following statements, and its aliases are applied to
     them via `ast_util.rename_symbols`.

     Args:
       nodes: the statements to visit, in order.

     Returns:
       The list of top-level statements after visiting.

     Raises:
       ValueError: if a reindent was requested but no statement ended up in
         the final destination list.
     """
     new_nodes = []
     current_dest = new_nodes
     alias_map = {}
     reindent_requested = False
     for n in nodes:
         n = self.visit(n)
         # NOTE: the order in which these statements execute is important; in
         # particular, watch out for ending up with cycles in the AST.
         if alias_map:
             n = ast_util.rename_symbols(n, alias_map)
         if isinstance(n, (list, tuple)):
             current_dest.extend(n)
         else:
             current_dest.append(n)
         if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
             reindent_requested = True
             new_dest, new_alias_map = anno.getanno(
                 n, anno.Basic.INDENT_BLOCK_REMAINDER)
             anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
             # Aliases already in effect override same-keyed entries from the
             # new map. Note that new_alias_map is mutated in place.
             new_alias_map.update(alias_map)
             alias_map = new_alias_map
             # Subsequent statements go into new_dest (presumably a nested
             # block body of `n`; confirm with the producer of the annotation).
             current_dest = new_dest
     if reindent_requested and not current_dest:
         # TODO(mdan): There may still be something that could be done.
         raise ValueError(
             'Unable to insert statement into the computation flow: '
             'it is not followed by any computation which '
             'the statement could gate.')
     return new_nodes
Beispiel #13
0
 def _for_loop_with_extra_test(self, loop_state, state_ssf, state_ast_tuple,
                               original_node, extra_test_name, extra_test,
                               body_name, loop_body, ssf_map):
   """Emits an `ag__.for_stmt` call for a stateful loop with early stopping.

   Only the loop target is renamed here, using `ssf_map`. The body, the extra
   test and the generated names are expected to be prepared by the caller.

   Returns:
     The nodes produced by `templates.replace`.
   """
   template = """
     def extra_test_name(state_ssf):
       return extra_test_expr
     def body_name(loop_vars, state_ssf):
       # Workaround for PEP-3113
       target = loop_vars
       body
       return state_ssf,
     state_ast_tuple = ag__.for_stmt(
         iter_, extra_test_name, body_name, (state,))
   """
   return templates.replace(
       template,
       state=loop_state,
       state_ssf=state_ssf,
       state_ast_tuple=state_ast_tuple,
       iter_=original_node.iter,
       target=ast_util.rename_symbols(original_node.target, ssf_map),
       extra_test_name=extra_test_name,
       extra_test_expr=extra_test,
       body_name=body_name,
       body=loop_body)
 def _visit_and_reindent(self, nodes):
   """Visits a statement list, redirecting later statements when requested.

   Each statement is visited with `self.visit`. The result is appended (or
   extended, if the visitor returned a list/tuple) to the current destination
   list, which starts as the returned `new_nodes`. If a result carries the
   `INDENT_BLOCK_REMAINDER` annotation, the annotation is removed. Its
   `(new_dest, new_alias_map)` payload then becomes the destination for all
   following statements, and its aliases are applied to them via
   `ast_util.rename_symbols`.

   Args:
     nodes: the statements to visit, in order.

   Returns:
     The list of top-level statements after visiting.

   Raises:
     ValueError: if a reindent was requested but no statement ended up in the
       final destination list.
   """
   new_nodes = []
   current_dest = new_nodes
   alias_map = {}
   reindent_requested = False
   for n in nodes:
     n = self.visit(n)
     # NOTE: the order in which these statements execute is important; in
     # particular, watch out for ending up with cycles in the AST.
     if alias_map:
       n = ast_util.rename_symbols(n, alias_map)
     if isinstance(n, (list, tuple)):
       current_dest.extend(n)
     else:
       current_dest.append(n)
     if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
       reindent_requested = True
       new_dest, new_alias_map = anno.getanno(
           n, anno.Basic.INDENT_BLOCK_REMAINDER)
       anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
       # Aliases already in effect override same-keyed entries from the new
       # map. Note that new_alias_map is mutated in place.
       new_alias_map.update(alias_map)
       alias_map = new_alias_map
       # Subsequent statements go into new_dest (presumably a nested block
       # body of `n`; confirm with the producer of the annotation).
       current_dest = new_dest
   if reindent_requested and not current_dest:
     # TODO(mdan): There may still be something that could be done.
     raise ValueError('Unable to insert statement into the computation flow: '
                      'it is not followed by any computation which '
                      'the statement could gate.')
   return new_nodes
Beispiel #15
0
  def visit_While(self, node):
    """Converts a `while` loop into an `ag__.while_stmt` call.

    Loop state comes from `_get_loop_state`, given the symbols modified in the
    loop body. Stateless loops use argument-less generated functions and pass
    an empty tuple to `while_stmt`.

    Args:
      node: the `while` AST node being visited.

    Returns:
      The assignments from `_create_undefined_assigns` for possibly undefined
      symbols, concatenated with the converted loop.
    """
    self.generic_visit(node)

    modified_in_body = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(
        node, modified_in_body)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    new_body = ast_util.rename_symbols(node.body, ssf_map)
    new_test = ast_util.rename_symbols(node.test, ssf_map)

    # The test function name is always requested before the body name.
    test_fn_name = self.ctx.namer.new_symbol('loop_test', reserved_symbols)
    body_fn_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    if not loop_state:
      template = """
        def test_name():
          return test
        def body_name():
          body
          return ()
        ag__.while_stmt(test_name, body_name, ())
      """
      converted = templates.replace(
          template,
          test_name=test_fn_name,
          test=new_test,
          body_name=body_fn_name,
          body=new_body)
    else:
      template = """
        def test_name(state_ssf):
          return test
        def body_name(state_ssf):
          body
          return state_ssf,
        state_ast_tuple = ag__.while_stmt(test_name, body_name, (state,))
      """
      converted = templates.replace(
          template,
          state=loop_state,
          state_ssf=state_ssf,
          state_ast_tuple=state_ast_tuple,
          test_name=test_fn_name,
          test=new_test,
          body_name=body_fn_name,
          body=new_body)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + converted
    def test_rename_symbols_attributes(self):
        """Renaming the qualified name `b.c` rewrites both of its uses."""
        tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
        rename_map = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
        tree = ast_util.rename_symbols(tree, rename_map)

        generated = compiler.ast_to_source(tree)
        self.assertEqual(generated.strip(), 'renamed_b_c = renamed_b_c.d')
Beispiel #17
0
  def test_rename_symbols_attributes(self):
    """Renaming the qualified name `b.c` rewrites both of its uses."""
    tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
    rename_map = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
    tree = ast_util.rename_symbols(tree, rename_map)

    generated = compiler.ast_to_source(tree)
    self.assertEqual(generated.strip(), 'renamed_b_c = renamed_b_c.d')
Beispiel #18
0
    def test_rename_symbols_attributes(self):
        """Renaming `b.c` is reflected in the unparsed source."""
        tree = qual_names.resolve(parser.parse('b.c = b.c.d'))
        rename_map = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
        tree = ast_util.rename_symbols(tree, rename_map)

        unparsed = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(unparsed.strip(), 'renamed_b_c = renamed_b_c.d')
Beispiel #19
0
    def test_rename_symbols_global(self):
        """Names listed in a `global` statement are renamed too."""
        tree = qual_names.resolve(parser.parse('global a, b, c'))
        rename_map = {qual_names.from_str('b'): qual_names.QN('renamed_b')}
        tree = ast_util.rename_symbols(tree, rename_map)

        unparsed = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(unparsed.strip(), 'global a, renamed_b, c')
    def test_rename_symbols_annotations(self):
        """Annotations on the root node survive renaming unchanged."""
        tree = qual_names.resolve(parser.parse_str('a[i]'))
        anno.setanno(tree, 'foo', 'bar')
        expected = anno.getanno(tree, 'foo')

        renamed = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('b')})

        self.assertIs(anno.getanno(renamed, 'foo'), expected)
    def test_rename_symbols_basic(self):
        """Renaming `a` yields a plain string id and updated source."""
        tree = qual_names.resolve(parser.parse_str('a + b'))
        rename_map = {qual_names.QN('a'): qual_names.QN('renamed_a')}
        tree = ast_util.rename_symbols(tree, rename_map)

        self.assertIsInstance(tree.body[0].value.left.id, str)
        generated = compiler.ast_to_source(tree)
        self.assertEqual(generated.strip(), 'renamed_a + b')
    def test_rename_symbols_basic(self):
        """Renaming `a` yields a plain string id and updated source."""
        tree = qual_names.resolve(parser.parse('a + b'))
        rename_map = {qual_names.QN('a'): qual_names.QN('renamed_a')}
        tree = ast_util.rename_symbols(tree, rename_map)

        self.assertIsInstance(tree.value.left.id, str)
        unparsed = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(unparsed.strip(), 'renamed_a + b')
Beispiel #23
0
  def test_rename_symbols_annotations(self):
    """Annotations on the root node survive renaming unchanged."""
    tree = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(tree, 'foo', 'bar')
    expected = anno.getanno(tree, 'foo')

    renamed = ast_util.rename_symbols(
        tree, {qual_names.QN('a'): qual_names.QN('b')})

    self.assertIs(anno.getanno(renamed, 'foo'), expected)
Beispiel #24
0
  def test_rename_symbols_basic(self):
    """Renaming `a` yields a plain string id and updated source."""
    tree = qual_names.resolve(parser.parse_str('a + b'))
    rename_map = {qual_names.QN('a'): qual_names.QN('renamed_a')}
    tree = ast_util.rename_symbols(tree, rename_map)

    self.assertIsInstance(tree.body[0].value.left.id, str)
    generated = compiler.ast_to_source(tree)
    self.assertEqual(generated.strip(), 'renamed_a + b')
Beispiel #25
0
    def test_rename_symbols_basic(self):
        """Renaming `a` matches both the unparsed and the expected source."""
        tree = qual_names.resolve(parser.parse('a + b'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})
        unparsed = parser.unparse(tree, include_encoding_marker=False)

        self.assertIsInstance(tree.value.left.id, str)
        self.assertAstMatches(tree, unparsed)
        self.assertAstMatches(tree, 'renamed_a + b')
    def visit_If(self, node):
        """Converts an `if` statement into a functional conditional.

        Each branch becomes a function definition (via `_create_cond_branch`).
        The test is assigned to a fresh `cond` variable, and
        `_create_cond_expr` builds the expression that dispatches between the
        two branch functions.

        Non-composite symbols that are modified in either branch and live
        after the statement become the results of the conditional. Composite
        symbols modified in either branch are instead passed to
        `_create_state_functions`, which builds getter/setter functions. Basic
        result symbols that are not defined before the statement and are
        created in exactly one branch are passed to
        `_create_undefined_assigns`.

        Args:
          node: the `if` AST node being visited. It must carry the
            BODY_SCOPE, ORELSE_SCOPE, DEFINED_VARS_IN and LIVE_VARS_OUT
            annotations read below.

        Returns:
          `undefined_assigns + composite_defs + body_def + orelse_def +
          cond_assign + cond_expr` (presumably a list of statements; confirm
          against the helper methods).
        """
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
        defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
        live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

        # Note: this information needs to be extracted before the body conversion
        # that happens in the call to generic_visit below, because the conversion
        # generates nodes that lack static analysis annotations.
        need_alias_in_body = self._determine_aliased_symbols(
            body_scope, defined_in, node.body)
        need_alias_in_orelse = self._determine_aliased_symbols(
            orelse_scope, defined_in, node.orelse)

        node = self.generic_visit(node)

        # Partition the symbols modified in either branch: live non-composite
        # symbols are returned from the cond; composites are always tracked.
        modified_in_cond = body_scope.modified | orelse_scope.modified
        returned_from_cond = set()
        composites = set()
        for s in modified_in_cond:
            if s in live_out and not s.is_composite():
                returned_from_cond.add(s)
            if s.is_composite():
                # Special treatment for compound objects, always return them.
                # This allows special handling within the if_stmt itself.
                # For example, in TensorFlow we need to restore the state of composite
                # symbols to ensure that only effects from the executed branch are seen.
                composites.add(s)

        # `-` binds tighter than `&`, so these read as
        # modified & (returned_from_cond - defined_in). By set algebra that
        # equals (modified & returned_from_cond) - defined_in.
        created_in_body = body_scope.modified & returned_from_cond - defined_in
        created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

        basic_created_in_body = tuple(s for s in created_in_body
                                      if not s.is_composite())
        basic_created_in_orelse = tuple(s for s in created_in_orelse
                                        if not s.is_composite())

        # These variables are defined only in a single branch. This is fine in
        # Python so we pass them through. Another backend, e.g. Tensorflow, may need
        # to handle these cases specially or throw an Error.
        possibly_undefined = (set(basic_created_in_body)
                              ^ set(basic_created_in_orelse))

        # Alias the closure variables inside the conditional functions, to allow
        # the functions access to the respective variables.
        # We will alias variables independently for body and orelse scope,
        # because different branches might write different variables.
        aliased_body_orig_names = tuple(need_alias_in_body)
        aliased_orelse_orig_names = tuple(need_alias_in_orelse)
        aliased_body_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
            for s in aliased_body_orig_names)
        aliased_orelse_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
            for s in aliased_orelse_orig_names)

        alias_body_map = dict(
            zip(aliased_body_orig_names, aliased_body_new_names))
        alias_orelse_map = dict(
            zip(aliased_orelse_orig_names, aliased_orelse_new_names))

        node_body = ast_util.rename_symbols(node.body, alias_body_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

        cond_var_name = self.ctx.namer.new_symbol('cond',
                                                  body_scope.referenced)
        body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
        orelse_name = self.ctx.namer.new_symbol('if_false',
                                                orelse_scope.referenced)
        all_referenced = body_scope.referenced | orelse_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      all_referenced)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      all_referenced)

        # Freeze into a tuple so cond_results and both branches' return values
        # below are built from the same symbol order.
        returned_from_cond = tuple(returned_from_cond)
        if returned_from_cond:
            if len(returned_from_cond) == 1:
                cond_results = returned_from_cond[0]
            else:
                cond_results = gast.Tuple(
                    [s.ast() for s in returned_from_cond], None)

            # Each branch returns its aliased name where one was assigned.
            returned_from_body = tuple(
                alias_body_map[s] if s in need_alias_in_body else s
                for s in returned_from_cond)
            returned_from_orelse = tuple(
                alias_orelse_map[s] if s in need_alias_in_orelse else s
                for s in returned_from_cond)

        else:
            # When the cond would return no value, we leave the cond called without
            # results. That in turn should trigger the side effect guards. The
            # branch functions will return a dummy value that ensures cond
            # actually has some return value as well.
            cond_results = None
            # TODO(mdan): Replace with None once side_effect_guards is retired.
            returned_from_body = (templates.replace_as_expression(
                'ag__.match_staging_level(1, cond_var_name)',
                cond_var_name=cond_var_name), )
            returned_from_orelse = (templates.replace_as_expression(
                'ag__.match_staging_level(1, cond_var_name)',
                cond_var_name=cond_var_name), )

        cond_assign = self.create_assignment(cond_var_name, node.test)
        body_def = self._create_cond_branch(
            body_name,
            aliased_orig_names=aliased_body_orig_names,
            aliased_new_names=aliased_body_new_names,
            body=node_body,
            returns=returned_from_body)
        orelse_def = self._create_cond_branch(
            orelse_name,
            aliased_orig_names=aliased_orelse_orig_names,
            aliased_new_names=aliased_orelse_new_names,
            body=node_orelse,
            returns=returned_from_orelse)
        undefined_assigns = self._create_undefined_assigns(possibly_undefined)
        composite_defs = self._create_state_functions(composites,
                                                      state_getter_name,
                                                      state_setter_name)

        cond_expr = self._create_cond_expr(cond_results, cond_var_name,
                                           body_name, orelse_name,
                                           state_getter_name,
                                           state_setter_name)

        # Order matters: undefined-symbol initializers and helper defs must
        # precede the cond evaluation and dispatch that use them.
        if_ast = (undefined_assigns + composite_defs + body_def + orelse_def +
                  cond_assign + cond_expr)
        return if_ast
Beispiel #27
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Every method defined directly on `c` (inherited ones are skipped) is
    converted with `function_to_graph`. The converted methods are assembled
    into a new `ClassDef` whose bases are either `object` or aliased imports
    of whitelisted classes.

    Args:
      c: the class to convert.
      program_ctx: conversion context, forwarded to `function_to_graph`.

    Returns:
      A tuple `(output_nodes, class_name, class_namespace)`:
        output_nodes: `ImportFrom` nodes generated for whitelisted base
          classes, followed by the converted `ClassDef`.
        class_name: the name chosen for the converted class.
        class_namespace: the namespaces returned for all converted methods,
          merged into one dict.

    Raises:
      ValueError: if `c` has no member methods.
      NotImplementedError: if a direct base of `c` is neither `object` nor
        whitelisted for graph conversion.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}

    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    # All converted methods share a single namespace; when names clash, the
    # method converted last wins.
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        class_namespace.update(namespace)
        converted_members[m] = nodes[0]
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Compare by identity. `isinstance(object, base)` is also true when
        # `base` is `type` or an ABC such as `collections.abc.Hashable` or
        # `collections.abc.Callable` (the `object` class is both hashable and
        # callable), which would silently replace those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
  def visit_If(self, node):
    """Rewrites an `if` statement as two branch functions plus a cond call.

    The body and the orelse are each wrapped into a function by
    `_create_cond_branch`. The test expression is assigned to a fresh `cond`
    variable, and `_create_cond_expr` builds the final construct from the
    result target, the cond variable and the two branch function names.

    Args:
      node: the `If` node. It must carry the BODY_SCOPE, ORELSE_SCOPE,
        DEFINED_VARS_IN and LIVE_VARS_OUT annotations read below.

    Returns:
      The concatenation, in order, of: assignments for possibly-undefined
      symbols, the `cond` assignment, the true-branch definition, the
      false-branch definition and the cond expression. These are presumably
      lists of statement nodes; TODO confirm against the helpers.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    # The outputs of the cond are the symbols modified in either branch that
    # are live after the `if`.
    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    for s in modified_in_cond:
      if s in live_out:
        returned_from_cond.add(s)
      elif s.is_composite():
        # Special treatment for compound objects: if any of their owner entities
        # are live, then they are outputs as well.
        if live_out & s.owner_set:
          returned_from_cond.add(s)

    # Outputs that each branch introduces and that were not defined before the
    # `if`. (`-` binds tighter than `&`; for sets both groupings are equal.)
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    # Only non-composite symbols are considered for possibly-undefined handling.
    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    # Inside the branch functions, aliased symbols are referenced by their new
    # names.
    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # Fresh names for the test variable and the two branch functions.
    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

    # Freeze the ordering once, so the result target and both branch return
    # tuples list the symbols in the same order.
    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
      # A single output is used directly; several are packed into a Tuple.
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Each branch returns the aliased name where one exists.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    # Emitted first so that symbols created in only one branch have a binding
    # before the cond runs (see _create_undefined_assigns for the exact form).
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name)

    return (undefined_assigns
            + cond_assign
            + body_def
            + orelse_def
            + cond_expr)
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Every method defined directly on `c` (inherited ones are skipped) is
  converted with `function_to_graph`. The converted methods are assembled
  into a new `ClassDef`. Whitelisted base classes are imported under an
  alias; other bases are given compiled class names, which marks them for
  conversion once `program_ctx.update_name_map(namer)` runs.

  Args:
    c: the class to convert.
    program_ctx: conversion context; used to create the namer, to record the
      name map, and forwarded to `function_to_graph`.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`:
      output_nodes: `ImportFrom` nodes generated for whitelisted base
        classes, followed by the converted `ClassDef`.
      class_name: the compiled name chosen for the converted class.
      class_namespace: the namespaces returned for all converted methods,
        merged into one dict.

  Raises:
    ValueError: if `c` has no member methods.
  """
  converted_members = {}

  def method_filter(m):
    return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  # All converted methods share a single namespace; when names clash, the
  # method converted last wins.
  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity. `isinstance(object, base)` is also true when `base`
    # is `type` or an ABC such as `collections.abc.Hashable` or
    # `collections.abc.Callable` (the `object` class is both hashable and
    # callable), which would silently replace those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
Beispiel #30
0
  def visit_If(self, node):
    """Rewrites an `if` statement as two branch functions plus a cond call.

    Unlike a Python `if`, the rewritten form requires every new symbol that
    is live after the statement to be created by both branches.

    Args:
      node: the `If` node. It must carry the BODY_SCOPE, ORELSE_SCOPE,
        DEFINED_VARS_IN and LIVE_VARS_OUT annotations read below.

    Returns:
      The concatenation, in order, of the true-branch definition, the
      false-branch definition and the cond expression built by
      `_create_cond_expr`. These are presumably lists of statement nodes;
      TODO confirm against the helpers.

    Raises:
      ValueError: if the two branches create different sets of live-out
        symbols that were not defined before the `if`.
    """
    # The children are converted first; the annotations below are then read
    # from the node returned by generic_visit.
    node = self.generic_visit(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # The outputs of the cond are the symbols modified in either branch that
    # are live after the `if`.
    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    for s in modified_in_cond:
      if s in live_out:
        returned_from_cond.add(s)
      elif s.is_composite():
        # Special treatment for compound objects: if any of their owner entities
        # are live, then they are outputs as well.
        if live_out & s.owner_set:
          returned_from_cond.add(s)

    # Symbols that already existed before the `if` and are modified in a
    # branch get aliased inside that branch's function.
    need_alias_in_body = body_scope.modified & defined_in
    need_alias_in_orelse = orelse_scope.modified & defined_in

    # Outputs that each branch introduces and that were not defined before the
    # `if`. (`-` binds tighter than `&`; for sets both groupings are equal.)
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    if created_in_body != created_in_orelse:
      raise ValueError(
          'if statement may not initialize all variables: the true branch'
          ' creates %s, while the false branch creates %s. Make sure all'
          ' these variables are initialized either in both'
          ' branches or before the if statement.' %
          (self._fmt_symbols(created_in_body),
           self._fmt_symbols(created_in_orelse)))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    # Inside the branch functions, aliased symbols are referenced by their new
    # names.
    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # Freeze the ordering once, so the result target and both branch return
    # tuples list the symbols in the same order.
    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
      # A single output is used directly; several are packed into a Tuple.
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Each branch returns the aliased name where one exists.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): This doesn't belong here; it's specific to the operator.
      returned_from_body = (templates.replace_as_expression('tf.constant(1)'),)
      returned_from_orelse = (
          templates.replace_as_expression('tf.constant(1)'),)

    # Fresh names for the two branch functions.
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    # In this variant the test expression is passed directly rather than being
    # stored in a separate cond variable first.
    cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
    def visit_If(self, node):
        """Rewrites an `if` statement as two branch functions plus a cond call.

        The test is first assigned to a fresh `cond` variable, and the
        branches become functions built by `_create_cond_branch`. Every new
        symbol that is live after the statement must be created by both
        branches.

        Args:
          node: the `If` node. It must carry the BODY_SCOPE, ORELSE_SCOPE,
            DEFINED_VARS_IN and LIVE_VARS_OUT annotations read below.

        Returns:
          The concatenation, in order, of the `cond` assignment, the
          true-branch definition, the false-branch definition and the cond
          expression. These are presumably lists of statement nodes; TODO
          confirm against the helpers.

        Raises:
          ValueError: if the two branches create different sets of live-out
            symbols that were not defined before the `if`.
        """
        # The children are converted first; the annotations below are then
        # read from the node returned by generic_visit.
        node = self.generic_visit(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
        defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
        live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

        # The outputs of the cond are the symbols modified in either branch
        # that are live after the `if`.
        modified_in_cond = body_scope.modified | orelse_scope.modified
        returned_from_cond = set()
        for s in modified_in_cond:
            if s in live_out:
                returned_from_cond.add(s)
            elif s.is_composite():
                # Special treatment for compound objects: if any of their owner entities
                # are live, then they are outputs as well.
                if live_out & s.owner_set:
                    returned_from_cond.add(s)

        # Symbols that already existed before the `if` and are modified in a
        # branch get aliased inside that branch's function.
        need_alias_in_body = body_scope.modified & defined_in
        need_alias_in_orelse = orelse_scope.modified & defined_in

        # Outputs that each branch introduces and that were not defined before
        # the `if`. (`-` binds tighter than `&`; for sets both groupings are
        # equal.)
        created_in_body = body_scope.modified & returned_from_cond - defined_in
        created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

        if created_in_body != created_in_orelse:
            raise ValueError(
                'if statement may not initialize all variables: the true branch'
                ' creates %s, while the false branch creates %s. Make sure all'
                ' these variables are initialized either in both'
                ' branches or before the if statement.' %
                (self._fmt_symbols(created_in_body),
                 self._fmt_symbols(created_in_orelse)))

        # Alias the closure variables inside the conditional functions, to allow
        # the functions access to the respective variables.
        # We will alias variables independently for body and orelse scope,
        # because different branches might write different variables.
        aliased_body_orig_names = tuple(need_alias_in_body)
        aliased_orelse_orig_names = tuple(need_alias_in_orelse)
        aliased_body_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
            for s in aliased_body_orig_names)
        aliased_orelse_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
            for s in aliased_orelse_orig_names)

        alias_body_map = dict(
            zip(aliased_body_orig_names, aliased_body_new_names))
        alias_orelse_map = dict(
            zip(aliased_orelse_orig_names, aliased_orelse_new_names))

        # Inside the branch functions, aliased symbols are referenced by their
        # new names.
        node_body = ast_util.rename_symbols(node.body, alias_body_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

        # Fresh names for the test variable and the two branch functions.
        cond_var_name = self.ctx.namer.new_symbol('cond',
                                                  body_scope.referenced)
        body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
        orelse_name = self.ctx.namer.new_symbol('if_false',
                                                orelse_scope.referenced)

        # Freeze the ordering once, so the result target and both branch
        # return tuples list the symbols in the same order.
        returned_from_cond = tuple(returned_from_cond)
        if returned_from_cond:
            # A single output is used directly; several are packed into a
            # Tuple.
            if len(returned_from_cond) == 1:
                cond_results = returned_from_cond[0]
            else:
                cond_results = gast.Tuple(
                    [s.ast() for s in returned_from_cond], None)

            # Each branch returns the aliased name where one exists.
            returned_from_body = tuple(
                alias_body_map[s] if s in need_alias_in_body else s
                for s in returned_from_cond)
            returned_from_orelse = tuple(
                alias_orelse_map[s] if s in need_alias_in_orelse else s
                for s in returned_from_cond)

        else:
            # When the cond would return no value, we leave the cond called without
            # results. That in turn should trigger the side effect guards. The
            # branch functions will return a dummy value that ensures cond
            # actually has some return value as well.
            cond_results = None
            # TODO(mdan): This doesn't belong here; it's specific to the operator.
            returned_from_body = (templates.replace_as_expression(
                'ag__.match_staging_level(1, cond_var_name)',
                cond_var_name=cond_var_name), )
            returned_from_orelse = (templates.replace_as_expression(
                'ag__.match_staging_level(1, cond_var_name)',
                cond_var_name=cond_var_name), )

        cond_assign = self.create_assignment(cond_var_name, node.test)
        body_def = self._create_cond_branch(
            body_name,
            aliased_orig_names=aliased_body_orig_names,
            aliased_new_names=aliased_body_new_names,
            body=node_body,
            returns=returned_from_body)
        orelse_def = self._create_cond_branch(
            orelse_name,
            aliased_orig_names=aliased_orelse_orig_names,
            aliased_new_names=aliased_orelse_new_names,
            body=node_orelse,
            returns=returned_from_orelse)
        cond_expr = self._create_cond_expr(cond_results, cond_var_name,
                                           body_name, orelse_name)

        return cond_assign + body_def + orelse_def + cond_expr
Beispiel #32
0
  def visit_While(self, node):
    """Rewrites a `while` loop as test/body functions driven by `ag__.while_stmt`.

    The loop state is made up of the symbols that the body modifies but does
    not create. Each state symbol gets a fresh name (derived from `s.ssf()`),
    which is used for it inside the generated test and body functions.

    Args:
      node: the `While` node. It must carry the BODY_SCOPE and COND_SCOPE
        annotations read below.

    Returns:
      The nodes produced by `templates.replace` for the template below.

    Raises:
      ValueError: if the loop has no state (the body modifies no symbol that
        existed before it). `_validate_no_live_vars_created` may also raise;
        see its definition.
    """
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    # Symbols written by the body that existed before the loop.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    # Roots of the symbols read by the loop test that are not created in the
    # body; these are passed to while_stmt as extra_deps.
    cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
    cond_closure = set()
    for s in cond_scope.used:
      for root in s.support_set:
        if root not in body_scope.created:
          cond_closure.add(root)

    state = list(body_closure)
    if not state:
      # TODO(mdan): Implement this properly.
      # To complete this statement, we need to check whether any variable
      # created inside the body scope is used before being modified outside the
      # scope. This should be done during activity analysis, and in general
      # should cover the case where variables may not be initialized.
      raise ValueError('cannot convert while loop: no outputs')

    state_ssf = [
        self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Only symbols whose fresh name differs from the original need renaming.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # With a single state symbol the template uses the bare symbol rather than
    # a tuple; otherwise the assignment target is a Tuple of all state symbols.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_stmt(
          test_name, body_name, (state,), (extra_deps,))
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced),
        test=test,
        body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced),
        body=node_body,
        extra_deps=tuple(s.ast() for s in cond_closure),
    )

    return node
Beispiel #33
0
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes.

    Every method defined directly on `c` (inherited ones are skipped) is
    converted with `convert_func_to_ast`. The converted methods are assembled
    into a new `ClassDef` whose bases are either `object` or aliased imports
    of whitelisted classes.

    Args:
      c: the class to convert.
      program_ctx: conversion context, forwarded to `convert_func_to_ast`.

    Returns:
      A tuple `(output_nodes, class_name, entity_info)`:
        output_nodes: `ImportFrom` nodes generated for whitelisted base
          classes, followed by the converted `ClassDef`.
        class_name: the name chosen for the converted class.
        entity_info: a synthesized `transformer.EntityInfo` with no source
          code or file, the methods' shared future features, and the merged
          namespace of all converted methods.

    Raises:
      ValueError: if `c` has no member methods, or if its methods were built
        with different `__future__` features.
      NotImplementedError: if a direct base of `c` is neither `object` nor
        whitelisted for graph conversion.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}

    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
    # The assumption that one namespace suffices for all methods only holds if
    # all methods were defined in the same module.
    # If, instead, functions are imported from multiple modules and then spliced
    # into the class, then each function has its own globals and __future__
    # imports that need to stay separate.

    # For example, C's methods could both have `global x` statements referring to
    # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
    # from mod1 import f1
    # from mod2 import f2
    # class C(object):
    #   method1 = f1
    #   method2 = f2

    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        (node, ), _, entity_info = convert_func_to_ast(m,
                                                       program_ctx=program_ctx,
                                                       do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        if future_features is None:
            future_features = entity_info.future_features
        elif frozenset(future_features) ^ frozenset(
                entity_info.future_features):
            # Note: we can support this case if ever needed.
            raise ValueError(
                'cannot convert {}: it has methods built with mismatched future'
                ' features: {} and {}'.format(c, future_features,
                                              entity_info.future_features))
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Compare by identity. `isinstance(object, base)` is also true when
        # `base` is `type` or an ABC such as `collections.abc.Hashable` or
        # `collections.abc.Callable` (the `object` class is both hashable and
        # callable), which would silently replace those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(source_code=None,
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=class_namespace)

    return output_nodes, class_name, entity_info
Beispiel #34
0
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Every method defined directly on `c` (inherited ones are skipped) is
  converted with `convert_func_to_ast`. The converted methods are assembled
  into a new `ClassDef` whose bases are either `object` or aliased imports of
  whitelisted classes.

  Args:
    c: the class to convert.
    program_ctx: conversion context, forwarded to `convert_func_to_ast`.

  Returns:
    A tuple `(output_nodes, class_name, entity_info)`:
      output_nodes: `ImportFrom` nodes generated for whitelisted base
        classes, followed by the converted `ClassDef`.
      class_name: the name chosen for the converted class.
      entity_info: a synthesized `transformer.EntityInfo` with no source code
        or file, the methods' shared future features, and the merged
        namespace of all converted methods.

  Raises:
    ValueError: if `c` has no member methods, or if its methods were built
      with different `__future__` features.
    NotImplementedError: if a direct base of `c` is neither `object` nor
      whitelisted for graph conversion.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}

  def method_filter(m):
    return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  class_namespace = {}
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m,
        program_ctx=program_ctx,
        do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) ^ frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: it has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        entity_info.future_features))
  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Compare by identity. `isinstance(object, base)` is also true when `base`
    # is `type` or an ABC such as `collections.abc.Hashable` or
    # `collections.abc.Callable` (the `object` class is both hashable and
    # callable), which would silently replace those bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if not is_whitelisted_for_graph(base):
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    alias = namer.new_symbol(base.__name__, ())
    output_nodes.append(
        gast.ImportFrom(
            module=base.__module__,
            names=[gast.alias(name=base.__name__, asname=alias)],
            level=0))
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info
Beispiel #35
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method defined directly on `c` (inherited methods are
    skipped) via `function_to_graph` and wraps the converted methods in a new
    class definition. Each base class is either imported (when whitelisted)
    or referenced under the compiled name it is marked for conversion with.

    Args:
        c: The class to convert.
        program_ctx: The program context. It supplies the namer and receives
            the updated name map via `update_name_map`.

    Returns:
        A tuple `(output_nodes, class_name, class_namespace)`:
          output_nodes: list of AST nodes - one `gast.ImportFrom` per
            whitelisted base class, followed by the converted `gast.ClassDef`.
          class_name: the name given to the converted class.
          class_namespace: the namespaces returned by `function_to_graph` for
            the converted methods, merged into one dict.

    Raises:
        ValueError: if `c` has no member methods.
    """

    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    converted_members = {}
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        # Merge each method's namespace into the class-wide namespace.
        class_namespace.update(namespace)
        converted_members[m] = node[0]
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Identity check on purpose: `isinstance(object, base)` is also true
        # for `type` and for ABCs such as `collections.abc.Callable` or
        # `collections.abc.Hashable`, which would then be silently replaced
        # by `object` in the converted class.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
Beispiel #36
0
    def visit_While(self, node):
        """Rewrites a `while` loop as a call to `ag__.while_stmt`.

        The loop test and the loop body are extracted into two local functions
        that receive the loop-carried state as arguments. The body function
        returns the updated state, which is assigned back to the original
        variables.

        Args:
            node: The `While` AST node to convert.

        Returns:
            The replacement AST produced by `templates.replace`.

        Raises:
            ValueError: if the body modifies no symbol that outlives it, i.e.
                the loop has no state to carry.
        """
        self.generic_visit(node)
        self._validate_no_live_vars_created(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        # Loop-carried state: symbols the body modifies but does not create.
        carried = body_scope.modified - body_scope.created
        referenced = body_scope.referenced

        cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
        # Roots of the symbols read by the test that are not local to the body.
        cond_deps = {
            root
            for used_sym in cond_scope.used
            for root in used_sym.support_set
            if root not in body_scope.created
        }

        loop_vars = list(carried)
        if not loop_vars:
            # TODO(mdan): Implement this properly.
            # To complete this statement, we need to check whether any
            # variable created inside the body scope is used before being
            # modified outside the scope. This should be done during activity
            # analysis, and in general should cover the case where variables
            # may not be initialized.
            raise ValueError('cannot convert while loop: no outputs')

        loop_vars_ssf = [
            self.ctx.namer.new_symbol(var.ssf(), referenced) for var in loop_vars
        ]
        renames = {}
        for var, ssf_name in zip(loop_vars, loop_vars_ssf):
            # Only symbols whose SSF name differs from their own need renaming.
            if str(var) != ssf_name:
                renames[var] = ssf_name

        if len(loop_vars) != 1:
            state = loop_vars
            state_ssf = loop_vars_ssf
            state_ast_tuple = gast.Tuple([var.ast() for var in loop_vars], None)
        else:
            state = loop_vars[0]
            state_ssf = loop_vars_ssf[0]
            state_ast_tuple = state

        renamed_body = ast_util.rename_symbols(node.body, renames)
        renamed_test = ast_util.rename_symbols(node.test, renames)

        # The test function's name is allocated before the body function's.
        test_fn_name = self.ctx.namer.new_symbol('loop_test', referenced)
        body_fn_name = self.ctx.namer.new_symbol('loop_body', referenced)
        cond_dep_nodes = tuple(dep.ast() for dep in cond_deps)

        # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
        template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_stmt(
          test_name, body_name, (state,), (extra_deps,))
    """
        return templates.replace(
            template,
            state=state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            test_name=test_fn_name,
            test=renamed_test,
            body_name=body_fn_name,
            body=renamed_body,
            extra_deps=cond_dep_nodes,
        )