def visit_While(self, node):
    """Lowers a `while` statement to a `py2tf_utils.run_while` call.

    The loop test and body are moved into two generated functions that take
    the loop state (`state_ssf` in the template). The loop state is every
    symbol the body modifies but does not create; inside the generated
    functions such a symbol is referred to by a freshly generated name
    whenever that name differs from the original.

    Args:
      node: the `While` node to convert; it must carry a
        `NodeAnno.BODY_SCOPE` annotation.

    Returns:
      The result of `templates.replace` for the rewritten loop.

    Raises:
      ValueError: if the body modifies no symbol that it does not also create.
    """
    self.generic_visit(node)

    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    referenced = scope.referenced
    loop_vars = list(scope.modified - scope.created)

    if not loop_vars:
      # TODO(mdan): Implement this properly.
      # To complete this statement, we need to check whether any variable
      # created inside the body scope is used before being modified outside the
      # scope. This should be done during activity analysis, and in general
      # should cover the case where variables may not be initialized.
      raise ValueError('cannot convert while loop: no outputs')

    namer = self.context.namer
    fresh_names = [namer.new_symbol(v.ssf(), referenced) for v in loop_vars]

    # Only symbols whose fresh name actually differs need to be renamed.
    renames = {}
    for var, fresh in zip(loop_vars, fresh_names):
      if str(var) != fresh:
        renames[var] = fresh

    if len(loop_vars) == 1:
      # A lone loop variable is passed bare instead of inside a tuple.
      (state,) = loop_vars
      (state_ssf,) = fresh_names
      state_ast_tuple = state
    else:
      state, state_ssf = loop_vars, fresh_names
      state_ast_tuple = gast.Tuple([v.ast() for v in loop_vars], None)

    renamed_body = ast_util.rename_symbols(node.body, renames)
    renamed_test = ast_util.rename_symbols(node.test, renames)

    test_fn_name = namer.new_symbol('loop_test', referenced)
    body_fn_name = namer.new_symbol('loop_body', referenced)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state])
    """
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=test_fn_name,
        test=renamed_test,
        body_name=body_fn_name,
        body=renamed_body)
    def visit_While(self, node):
        """Lowers a `while` statement to a `py2tf_utils.run_while` call.

        The loop test and body are moved into two generated functions that
        take the loop state (`state_ssf` in the template). The loop state is
        every symbol the body modifies but does not create; inside the
        generated functions such a symbol is referred to by a freshly
        generated name whenever that name differs from the original.

        Args:
            node: the `While` node to convert; it must carry a
                `NodeAnno.BODY_SCOPE` annotation.

        Returns:
            The result of `templates.replace` for the rewritten loop.

        Raises:
            ValueError: if the body modifies no symbol that it does not also
                create.
        """
        self.generic_visit(node)

        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        referenced = scope.referenced
        loop_vars = list(scope.modified - scope.created)

        if not loop_vars:
            # TODO(mdan): Implement this properly.
            # To complete this statement, we need to check whether any variable
            # created inside the body scope is used before being modified outside the
            # scope. This should be done during activity analysis, and in general
            # should cover the case where variables may not be initialized.
            raise ValueError('cannot convert while loop: no outputs')

        namer = self.context.namer
        fresh_names = [
            namer.new_symbol(v.ssf(), referenced) for v in loop_vars
        ]

        # Only symbols whose fresh name actually differs need to be renamed.
        renames = {}
        for var, fresh in zip(loop_vars, fresh_names):
            if str(var) != fresh:
                renames[var] = fresh

        if len(loop_vars) == 1:
            # A lone loop variable is passed bare instead of inside a tuple.
            (state,) = loop_vars
            (state_ssf,) = fresh_names
            state_ast_tuple = state
        else:
            state, state_ssf = loop_vars, fresh_names
            state_ast_tuple = gast.Tuple([v.ast() for v in loop_vars], None)

        renamed_body = ast_util.rename_symbols(node.body, renames)
        renamed_test = ast_util.rename_symbols(node.test, renames)

        test_fn_name = namer.new_symbol('loop_test', referenced)
        body_fn_name = namer.new_symbol('loop_body', referenced)

        template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state])
    """
        return templates.replace(
            template,
            state=state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            test_name=test_fn_name,
            test=renamed_test,
            body_name=body_fn_name,
            body=renamed_body)
  def visit_While(self, node):
    """Converts a `while` statement into a `tf.while_loop` call.

    The loop test and body become generated local functions that take the
    loop state (`state_ssf` in the template). The loop state is every symbol
    the body modifies without creating it. Inside the generated functions
    those symbols are referred to by freshly generated names, where those
    names differ from the originals.

    Args:
      node: the `While` node to convert; it must carry a
        `NodeAnno.BODY_SCOPE` annotation.

    Returns:
      The statements produced by `templates.replace`.

    Raises:
      ValueError: if the body modifies no symbol that it does not also
        create, so the loop would have no state to carry.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    state = list(body_closure)
    if not state:
      # Without loop state the generated tf.while_loop call would have no
      # loop variables; fail here with a clear message instead.
      # TODO(mdan): Implement this properly.
      # To complete this statement, we need to check whether any variable
      # created inside the body scope is used before being modified outside the
      # scope. This should be done during activity analysis, and in general
      # should cover the case where variables may not be initialized.
      raise ValueError('cannot convert while loop: no outputs')

    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Only symbols whose new name differs from the original need renaming.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    if len(state) == 1:
      # A single loop variable is passed bare rather than as a tuple.
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = tf.while_loop(test_name, body_name, [state])
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=self.context.namer.new_symbol('loop_test', all_referenced),
        test=test,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
    def visit_While(self, node):
        """Converts a `while` statement into a `py2tf_utils.run_while` call.

        The loop test and body become generated local functions that take the
        loop state (`state_ssf` in the template). The loop state is every
        symbol the body modifies without creating it. Inside the generated
        functions those symbols are referred to by freshly generated names,
        where those names differ from the originals.

        Args:
            node: the `While` node to convert; it must carry a
                `NodeAnno.BODY_SCOPE` annotation.

        Returns:
            The statements produced by `templates.replace`.

        Raises:
            ValueError: if the body modifies no symbol that it does not also
                create, so the loop would have no state to carry.
        """
        self.generic_visit(node)

        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        body_closure = body_scope.modified - body_scope.created
        all_referenced = body_scope.referenced

        state = list(body_closure)
        if not state:
            # Without loop state the generated run_while call would have no
            # loop variables; fail here with a clear message instead.
            # TODO(mdan): Implement this properly.
            # To complete this statement, we need to check whether any variable
            # created inside the body scope is used before being modified outside the
            # scope. This should be done during activity analysis, and in general
            # should cover the case where variables may not be initialized.
            raise ValueError('cannot convert while loop: no outputs')

        state_ssf = [
            self.context.namer.new_symbol(s.ssf(), all_referenced)
            for s in state
        ]
        # Only symbols whose new name differs from the original need renaming.
        ssf_map = {
            name: ssf
            for name, ssf in zip(state, state_ssf) if str(name) != ssf
        }

        if len(state) == 1:
            # A single loop variable is passed bare rather than as a tuple.
            state = state[0]
            state_ssf = state_ssf[0]
            state_ast_tuple = state
        else:
            state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

        node_body = ast_util.rename_symbols(node.body, ssf_map)
        test = ast_util.rename_symbols(node.test, ssf_map)

        template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state])
    """
        node = templates.replace(
            template,
            state=state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            test_name=self.context.namer.new_symbol('loop_test',
                                                    all_referenced),
            test=test,
            body_name=self.context.namer.new_symbol('loop_body',
                                                    all_referenced),
            body=node_body)

        return node
 def _visit_and_reindent(self, nodes):
   """Visits a sequence of statements, re-nesting them when requested.

   Every statement in `nodes` is passed through `self.visit`. The result is
   added to the current destination list, which starts out as the returned
   `new_nodes`. A single node is appended; a list/tuple of nodes is spliced
   in.

   A visited statement may carry an `anno.Basic.INDENT_BLOCK_REMAINDER`
   annotation holding a `(new_dest, new_alias_map)` pair. When it does:
   - the annotation is removed;
   - `new_alias_map` is merged with the aliases already in effect;
   - every *following* statement is renamed with the merged aliases and
     added to `new_dest` instead.
   NOTE(review): presumably `new_dest` is a block nested inside the
   annotated statement, so the rest of the sequence ends up indented under
   it. Confirm against the converters that set this annotation.

   Args:
     nodes: the statements to visit, in order.

   Returns:
     The top-level list of converted statements.

   Raises:
     ValueError: if a re-nesting was requested but the final destination
       list is empty.
   """
   new_nodes = []
   # The list that converted statements are currently being added to.
   current_dest = new_nodes
   # Renames collected from INDENT_BLOCK_REMAINDER annotations; applied to
   # every statement visited after such an annotation was seen.
   alias_map = {}
   reindent_requested = False
   for n in nodes:
     n = self.visit(n)
     # NOTE: the order in which these statements execute is important; in
     # particular, watch out for ending up with cycles in the AST.
     if alias_map:
       n = ast_util.rename_symbols(n, alias_map)
     # A visitor may expand one statement into several.
     if isinstance(n, (list, tuple)):
       current_dest.extend(n)
     else:
       current_dest.append(n)
     # NOTE(review): when self.visit returned a list, this checks the list
     # itself rather than its elements - confirm that is intended.
     if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
       reindent_requested = True
       new_dest, new_alias_map = anno.getanno(
           n, anno.Basic.INDENT_BLOCK_REMAINDER)
       anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
       # Merge the aliases already in effect into the new map; on a key
       # collision the existing alias wins. Note this mutates the dict that
       # was stored in the annotation.
       new_alias_map.update(alias_map)
       alias_map = new_alias_map
       # Subsequent statements go into the list supplied by the annotation.
       current_dest = new_dest
   if reindent_requested and not current_dest:
     # TODO(mdan): There may still be something that could be done.
     raise ValueError('Unable to insert statement into the computation flow: '
                      'it is not followed by any computation which '
                      'the statement could gate.')
   return new_nodes
Example #6
0
 def _visit_and_reindent(self, nodes):
     """Visits a sequence of statements, re-nesting them when requested.

     Every statement in `nodes` is passed through `self.visit`. The result is
     added to the current destination list, which starts out as the returned
     `new_nodes`. A single node is appended; a list/tuple of nodes is spliced
     in.

     A visited statement may carry an `anno.Basic.INDENT_BLOCK_REMAINDER`
     annotation holding a `(new_dest, new_alias_map)` pair. When it does:
     - the annotation is removed;
     - `new_alias_map` is merged with the aliases already in effect;
     - every *following* statement is renamed with the merged aliases and
       added to `new_dest` instead.
     NOTE(review): presumably `new_dest` is a block nested inside the
     annotated statement, so the rest of the sequence ends up indented under
     it. Confirm against the converters that set this annotation.

     Args:
         nodes: the statements to visit, in order.

     Returns:
         The top-level list of converted statements.

     Raises:
         ValueError: if a re-nesting was requested but the final destination
             list is empty.
     """
     new_nodes = []
     # The list that converted statements are currently being added to.
     current_dest = new_nodes
     # Renames collected from INDENT_BLOCK_REMAINDER annotations; applied to
     # every statement visited after such an annotation was seen.
     alias_map = {}
     reindent_requested = False
     for n in nodes:
         n = self.visit(n)
         # NOTE: the order in which these statements execute is important; in
         # particular, watch out for ending up with cycles in the AST.
         if alias_map:
             n = ast_util.rename_symbols(n, alias_map)
         # A visitor may expand one statement into several.
         if isinstance(n, (list, tuple)):
             current_dest.extend(n)
         else:
             current_dest.append(n)
         # NOTE(review): when self.visit returned a list, this checks the
         # list itself rather than its elements - confirm that is intended.
         if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
             reindent_requested = True
             new_dest, new_alias_map = anno.getanno(
                 n, anno.Basic.INDENT_BLOCK_REMAINDER)
             anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
             # Merge the aliases already in effect into the new map; on a key
             # collision the existing alias wins. Note this mutates the dict
             # that was stored in the annotation.
             new_alias_map.update(alias_map)
             alias_map = new_alias_map
             # Subsequent statements go into the list supplied by the
             # annotation.
             current_dest = new_dest
     if reindent_requested and not current_dest:
         # TODO(mdan): There may still be something that could be done.
         raise ValueError(
             'Unable to insert statement into the computation flow: '
             'it is not followed by any computation which '
             'the statement could gate.')
     return new_nodes
    def test_rename_symbols(self):
        """Checks that rename_symbols replaces whole qualified names.

        `a` and the `b.c` prefix of `b.c.d` are renamed, while the bare `b` is
        left alone. Each renamed node keeps its expression context
        (Load/Store), and the trailing `.d` attribute of `b.c.d` survives.
        """
        node = ast.Tuple([
            ast.Name('a', ast.Load()),
            ast.Name('b', ast.Load()),
            ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
            ast.Attribute(ast.Attribute(ast.Name('b', None), 'c', ast.Load()),
                          'd', None)
        ], None)
        node = qual_names.resolve(node)
        node = ast_util.rename_symbols(
            node, {
                qual_names.QN('a'): qual_names.QN('renamed_a'),
                qual_names.QN('b.c'): qual_names.QN('renamed_b_c'),
            })

        self.assertEqual(node.elts[0].id, 'renamed_a')
        self.assertIsInstance(node.elts[0].ctx, ast.Load)
        self.assertEqual(node.elts[1].id, 'b')
        self.assertEqual(node.elts[2].id, 'renamed_b_c')
        self.assertIsInstance(node.elts[2].ctx, ast.Store)
        self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
        self.assertIsInstance(node.elts[3].value.ctx, ast.Load)
        # Only the `b.c` prefix is renamed; the outer `.d` access remains.
        self.assertEqual(node.elts[3].attr, 'd')
  def test_rename_symbols(self):
    """Checks that rename_symbols replaces whole qualified names.

    `a` and the `b.c` prefix of `b.c.d` are renamed, while the bare `b` is left
    alone. Each renamed node keeps its expression context (Load/Store), and
    the trailing `.d` attribute of `b.c.d` survives.
    """
    node = ast.Tuple([
        ast.Name('a', ast.Load()),
        ast.Name('b', ast.Load()),
        ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
        ast.Attribute(
            ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd',
            None)
    ], None)
    node = qual_names.resolve(node)
    node = ast_util.rename_symbols(
        node,
        {
            qual_names.QN('a'): qual_names.QN('renamed_a'),
            qual_names.QN('b.c'): qual_names.QN('renamed_b_c'),
        })

    self.assertEqual(node.elts[0].id, 'renamed_a')
    self.assertIsInstance(node.elts[0].ctx, ast.Load)
    self.assertEqual(node.elts[1].id, 'b')
    self.assertEqual(node.elts[2].id, 'renamed_b_c')
    self.assertIsInstance(node.elts[2].ctx, ast.Store)
    self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
    self.assertIsInstance(node.elts[3].value.ctx, ast.Load)
    # Only the `b.c` prefix is renamed; the outer `.d` access remains.
    self.assertEqual(node.elts[3].attr, 'd')
    def visit_If(self, node):
        """Lowers an `if` statement to two branch functions and a cond call.

        Symbols that either branch modifies without creating are aliased to
        fresh names inside the generated branch functions. This avoids errors
        caused by the local variables created in those functions.

        Args:
            node: the `If` node to convert; it must carry both
                `NodeAnno.BODY_SCOPE` and `NodeAnno.ORELSE_SCOPE` annotations.

        Returns:
            The branch definitions built by `self._create_cond_branch` for
            the `if` and `else` branches, followed by the statements built by
            `self._create_cond_expr`.

        Raises:
            ValueError: if one branch creates a symbol the other does not.
        """
        self.generic_visit(node)

        if_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        else_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

        # Both branches must introduce exactly the same new symbols.
        if if_scope.created - else_scope.created:
            raise ValueError(
                'The if branch creates new symbols that the else branch does not.'
            )
        if else_scope.created - if_scope.created:
            raise ValueError(
                'The else branch creates new symbols that the if branch does not.'
            )

        written = if_scope.modified | else_scope.modified
        modified = tuple(written)
        all_referenced = if_scope.referenced | else_scope.referenced
        namer = self.context.namer

        # Alias the closure variables inside the conditional functions to avoid
        # errors caused by the local variables created in the branch functions.
        aliased_orig_names = tuple(
            written - (if_scope.created | else_scope.created))
        aliased_new_names = tuple(
            namer.new_symbol(s.ssf(), all_referenced)
            for s in aliased_orig_names)
        alias_map = dict(zip(aliased_orig_names, aliased_new_names))
        node_body = ast_util.rename_symbols(node.body, alias_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

        # Several modified symbols form a tuple; a single one is used bare.
        # When nothing is modified the cond is left without results, which in
        # turn should trigger the side effect guards.
        if len(modified) > 1:
            results = gast.Tuple([s.ast() for s in modified], None)
        else:
            results = modified[0] if modified else None

        true_name = namer.new_symbol('if_true', all_referenced)
        false_name = namer.new_symbol('if_false', all_referenced)

        if modified:
            # Inside the branch functions, aliased symbols go by their new
            # names.
            branch_returns = tuple(alias_map.get(s, s) for s in modified)
        else:
            # A dummy value that ensures the cond still has a return value.
            branch_returns = templates.replace('tf.ones(())')[0].value

        def make_branch(fn_name, statements):
            return self._create_cond_branch(
                fn_name,
                aliased_orig_names=aliased_orig_names,
                aliased_new_names=aliased_new_names,
                body=statements,
                returns=branch_returns)

        true_def = make_branch(true_name, node_body)
        false_def = make_branch(false_name, node_orelse)
        return true_def + false_def + self._create_cond_expr(
            results, node.test, true_name, false_name)
  def visit_If(self, node):
    """Lowers an `if` statement to two branch functions and a cond call.

    Symbols that either branch modifies without creating are aliased to fresh
    names inside the generated branch functions. This avoids errors caused by
    the local variables created in those functions.

    Args:
      node: the `If` node to convert; it must carry both
        `NodeAnno.BODY_SCOPE` and `NodeAnno.ORELSE_SCOPE` annotations.

    Returns:
      The branch definitions built by `self._create_cond_branch` for the `if`
      and `else` branches, followed by the statements built by
      `self._create_cond_expr`.

    Raises:
      ValueError: if one branch creates a symbol the other does not.
    """
    self.generic_visit(node)

    if_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    else_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    # Both branches must introduce exactly the same new symbols.
    if if_scope.created - else_scope.created:
      raise ValueError(
          'The if branch creates new symbols that the else branch does not.')
    if else_scope.created - if_scope.created:
      raise ValueError(
          'The else branch creates new symbols that the if branch does not.')

    written = if_scope.modified | else_scope.modified
    modified = tuple(written)
    all_referenced = if_scope.referenced | else_scope.referenced
    namer = self.context.namer

    # Alias the closure variables inside the conditional functions to avoid
    # errors caused by the local variables created in the branch functions.
    aliased_orig_names = tuple(
        written - (if_scope.created | else_scope.created))
    aliased_new_names = tuple(
        namer.new_symbol(s.ssf(), all_referenced) for s in aliased_orig_names)
    alias_map = dict(zip(aliased_orig_names, aliased_new_names))
    node_body = ast_util.rename_symbols(node.body, alias_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

    # Several modified symbols form a tuple; a single one is used bare. When
    # nothing is modified the cond is left without results, which in turn
    # should trigger the side effect guards.
    if len(modified) > 1:
      results = gast.Tuple([s.ast() for s in modified], None)
    else:
      results = modified[0] if modified else None

    true_name = namer.new_symbol('if_true', all_referenced)
    false_name = namer.new_symbol('if_false', all_referenced)

    if modified:
      # Inside the branch functions, aliased symbols go by their new names.
      branch_returns = tuple(alias_map.get(s, s) for s in modified)
    else:
      # A dummy value that ensures the cond still has a return value.
      branch_returns = templates.replace('tf.ones(())')[0].value

    def make_branch(fn_name, statements):
      return self._create_cond_branch(
          fn_name,
          aliased_orig_names=aliased_orig_names,
          aliased_new_names=aliased_new_names,
          body=statements,
          returns=branch_returns)

    true_def = make_branch(true_name, node_body)
    false_def = make_branch(false_name, node_orelse)
    return true_def + false_def + self._create_cond_expr(
        results, node.test, true_name, false_name)
  def visit_If(self, node):
    """Converts an `if` statement into a `tf.cond` call.

    Both branches become generated local functions passed to `tf.cond`.
    Symbols that either branch modifies without creating are aliased to fresh
    names inside those functions. Each function returns the values of all
    modified symbols, and the `tf.cond` result is assigned back to them.

    If neither branch modifies any symbol, both functions return a dummy
    `tf.ones(())` and the `tf.cond` result is discarded.

    Args:
      node: the `If` node to convert; it must carry `NodeAnno.BODY_SCOPE` and
        `NodeAnno.ORELSE_SCOPE` annotations.

    Returns:
      The statements produced by `templates.replace`.

    Raises:
      ValueError: if one branch creates a symbol that the other does not.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    if body_scope.created - orelse_scope.created:
      raise ValueError(
          'The if branch creates new symbols that the else branch does not.')
    if orelse_scope.created - body_scope.created:
      raise ValueError(
          'The else branch creates new symbols that the if branch does not.')

    modified = body_scope.modified | orelse_scope.modified
    all_modified = tuple(modified)
    all_referenced = body_scope.referenced | orelse_scope.referenced

    # Alias the closure variables inside the conditional functions
    # to avoid errors caused by the local variables created in the branch
    # functions.
    need_alias = modified - (body_scope.created | orelse_scope.created)
    aliased_orig_names = tuple(need_alias)
    aliased_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), all_referenced)
        for s in aliased_orig_names)
    alias_map = dict(zip(aliased_orig_names, aliased_new_names))
    node_body = ast_util.rename_symbols(node.body, alias_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

    body_name = self.context.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)

    if not all_modified:
      # Nothing flows out of the branches, so nothing needs aliasing. The
      # branches would otherwise return no values at all, so both return a
      # dummy value that gives tf.cond something to return. The cond is
      # emitted as a bare expression and its result is discarded.
      template = """
        def body_name():
          body
          return tf.ones(())
        def orelse_name():
          orelse
          return tf.ones(())
        tf.cond(test, body_name, orelse_name)
      """
      return templates.replace(
          template,
          test=node.test,
          body_name=body_name,
          body=node_body,
          orelse_name=orelse_name,
          orelse=node_orelse)

    if len(all_modified) == 1:
      results = all_modified[0]
    else:
      results = gast.Tuple([s.ast() for s in all_modified], None)
    # Inside the branch functions, aliased symbols go by their new names.
    all_results = tuple(alias_map.get(s, s) for s in all_modified)

    if aliased_orig_names:
      template = """
        def body_name():
          aliased_new_names, = aliased_orig_names,
          body
          return (all_results,)
        def orelse_name():
          aliased_new_names, = aliased_orig_names,
          orelse
          return (all_results,)
        results = tf.cond(test, body_name, orelse_name)
      """
      return templates.replace(
          template,
          test=node.test,
          body_name=body_name,
          body=node_body,
          orelse_name=orelse_name,
          orelse=node_orelse,
          aliased_orig_names=aliased_orig_names,
          aliased_new_names=aliased_new_names,
          all_results=all_results,
          results=results)
    else:
      template = """
        def body_name():
          body
          return (all_results,)
        def orelse_name():
          orelse
          return (all_results,)
        results = tf.cond(test, body_name, orelse_name)
      """
      return templates.replace(
          template,
          test=node.test,
          body_name=body_name,
          body=node_body,
          orelse_name=orelse_name,
          orelse=node_orelse,
          all_results=all_results,
          results=results)
Example #12
0
    def visit_If(self, node):
        """Converts an `if` statement into a `tf.cond` call.

        Both branches become generated local functions passed to `tf.cond`.
        Symbols that either branch modifies without creating are aliased to
        fresh names inside those functions. Each function returns the values
        of all modified symbols, and the `tf.cond` result is assigned back to
        them.

        If neither branch modifies any symbol, both functions return a dummy
        `tf.ones(())` and the `tf.cond` result is discarded.

        Args:
            node: the `If` node to convert; it must carry
                `NodeAnno.BODY_SCOPE` and `NodeAnno.ORELSE_SCOPE` annotations.

        Returns:
            The statements produced by `templates.replace`.

        Raises:
            ValueError: if one branch creates a symbol that the other does
                not.
        """
        self.generic_visit(node)

        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

        if body_scope.created - orelse_scope.created:
            raise ValueError(
                'The if branch creates new symbols that the else branch does not.'
            )
        if orelse_scope.created - body_scope.created:
            raise ValueError(
                'The else branch creates new symbols that the if branch does not.'
            )

        modified = body_scope.modified | orelse_scope.modified
        all_modified = tuple(modified)
        all_referenced = body_scope.referenced | orelse_scope.referenced

        # Alias the closure variables inside the conditional functions
        # to avoid errors caused by the local variables created in the branch
        # functions.
        need_alias = modified - (body_scope.created | orelse_scope.created)
        aliased_orig_names = tuple(need_alias)
        aliased_new_names = tuple(
            self.context.namer.new_symbol(s.ssf(), all_referenced)
            for s in aliased_orig_names)
        alias_map = dict(zip(aliased_orig_names, aliased_new_names))
        node_body = ast_util.rename_symbols(node.body, alias_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

        body_name = self.context.namer.new_symbol('if_true', all_referenced)
        orelse_name = self.context.namer.new_symbol('if_false',
                                                    all_referenced)

        if not all_modified:
            # Nothing flows out of the branches, so nothing needs aliasing.
            # The branches would otherwise return no values at all, so both
            # return a dummy value that gives tf.cond something to return.
            # The cond is emitted as a bare expression and its result is
            # discarded.
            template = """
        def body_name():
          body
          return tf.ones(())
        def orelse_name():
          orelse
          return tf.ones(())
        tf.cond(test, body_name, orelse_name)
      """
            return templates.replace(template,
                                     test=node.test,
                                     body_name=body_name,
                                     body=node_body,
                                     orelse_name=orelse_name,
                                     orelse=node_orelse)

        if len(all_modified) == 1:
            results = all_modified[0]
        else:
            results = gast.Tuple([s.ast() for s in all_modified], None)
        # Inside the branch functions, aliased symbols go by their new names.
        all_results = tuple(alias_map.get(s, s) for s in all_modified)

        if aliased_orig_names:
            template = """
        def body_name():
          aliased_new_names, = aliased_orig_names,
          body
          return (all_results,)
        def orelse_name():
          aliased_new_names, = aliased_orig_names,
          orelse
          return (all_results,)
        results = tf.cond(test, body_name, orelse_name)
      """
            return templates.replace(
                template,
                test=node.test,
                body_name=body_name,
                body=node_body,
                orelse_name=orelse_name,
                orelse=node_orelse,
                aliased_orig_names=aliased_orig_names,
                aliased_new_names=aliased_new_names,
                all_results=all_results,
                results=results)
        else:
            template = """
        def body_name():
          body
          return (all_results,)
        def orelse_name():
          orelse
          return (all_results,)
        results = tf.cond(test, body_name, orelse_name)
      """
            return templates.replace(template,
                                     test=node.test,
                                     body_name=body_name,
                                     body=node_body,
                                     orelse_name=orelse_name,
                                     orelse=node_orelse,
                                     all_results=all_results,
                                     results=results)