예제 #1
0
  def visit_For(self, node):
    """Rewrites a `for` statement as a call to `ag__.for_stmt`.

    The loop body and the extra loop test are emitted as local functions
    generated from the template below. Symbols that the body modifies but does
    not create form the loop state: the generated functions receive them as
    parameters under fresh names, the body function returns their new values,
    and the result of `ag__.for_stmt` is assigned back to the original names.

    Args:
      node: the `For` node being visited.

    Returns:
      The result of `templates.replace` for the template below.
      NOTE(review): presumably a list of statements; confirm in `templates`.
    """
    # Convert nested code first; the generated functions wrap the result.
    self.generic_visit(node)

    # NOTE(review): the name suggests this rejects loops whose body creates
    # variables that are live after the loop; confirm at its definition.
    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    # Loop state: symbols modified in the body but not created there.
    body_closure = body_scope.modified - body_scope.created
    # Every fresh name generated below avoids clashing with these symbols.
    all_referenced = body_scope.referenced

    state = list(body_closure)

    # One fresh name per state symbol, derived from `s.ssf()`; these become
    # the state parameters of the generated functions.
    state_ssf = [
        self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Rename only the symbols whose fresh name differs from the original.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # With exactly one state symbol, the template placeholders receive that
    # symbol itself instead of a list / tuple of symbols.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    # Inside the generated functions the state is referenced by fresh names.
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    # An additional loop condition may be attached to the node as the
    # 'extra_test' annotation; without one the test is the constant `True`.
    if anno.hasanno(node, 'extra_test'):
      extra_test = anno.getanno(node, 'extra_test')
      extra_test = ast_util.rename_symbols(extra_test, ssf_map)
    else:
      extra_test = parser.parse_expression('True')

    # The loop target is bound by assignment inside the body function rather
    # than as a function parameter: the target may be a tuple, and tuple
    # parameters were removed in Python 3 (PEP 3113).
    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
예제 #2
0
    def visit_For(self, node):
        """Rewrites a `for` statement as a call to `ag__.for_stmt`.

        The loop body and the extra loop test are emitted as local functions
        generated from the template below. Symbols that the body modifies but
        does not create form the loop state: the generated functions receive
        them as parameters under fresh names, the body function returns their
        new values, and the result of `ag__.for_stmt` is assigned back to the
        original names.

        Args:
            node: the `For` node being visited.

        Returns:
            The result of `templates.replace` for the template below.
            NOTE(review): presumably a list of statements; confirm in
            `templates`.
        """
        # Convert nested code first; the generated functions wrap the result.
        self.generic_visit(node)

        # NOTE(review): the name suggests this rejects loops whose body creates
        # variables that are live after the loop; confirm at its definition.
        self._validate_no_live_vars_created(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        # Loop state: symbols modified in the body but not created there.
        body_closure = body_scope.modified - body_scope.created
        # Every fresh name generated below avoids clashing with these symbols.
        all_referenced = body_scope.referenced

        state = list(body_closure)

        # One fresh name per state symbol, derived from `s.ssf()`; these become
        # the state parameters of the generated functions.
        state_ssf = [
            self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
        ]
        # Rename only the symbols whose fresh name differs from the original.
        ssf_map = {
            name: ssf
            for name, ssf in zip(state, state_ssf) if str(name) != ssf
        }

        # With exactly one state symbol, the template placeholders receive that
        # symbol itself instead of a list / tuple of symbols.
        if len(state) == 1:
            state = state[0]
            state_ssf = state_ssf[0]
            state_ast_tuple = state
        else:
            state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

        # Inside the generated functions the state is referenced by fresh
        # names.
        node_body = ast_util.rename_symbols(node.body, ssf_map)
        # An additional loop condition may be attached to the node as the
        # 'extra_test' annotation; without one the test is the constant `True`.
        if anno.hasanno(node, 'extra_test'):
            extra_test = anno.getanno(node, 'extra_test')
            extra_test = ast_util.rename_symbols(extra_test, ssf_map)
        else:
            extra_test = parser.parse_expression('True')

        # The loop target is bound by assignment inside the body function
        # rather than as a function parameter: the target may be a tuple, and
        # tuple parameters were removed in Python 3 (PEP 3113).
        template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
        node = templates.replace(
            template,
            state=state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=self.ctx.namer.new_symbol('extra_test',
                                                      all_referenced),
            extra_test_expr=extra_test,
            body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
            body=node_body)

        return node
예제 #3
0
  def visit_For(self, node):
    """Rewrites a `for` statement as a call to `ag__.for_loop`.

    The loop body and the extra loop condition are emitted as local functions
    generated from the template below. Symbols that the body modifies but does
    not create form the loop state: the generated functions receive them as
    parameters under fresh names, the body function returns their new values,
    and the result of `ag__.for_loop` is assigned back to the original names.

    Args:
      node: the `For` node being visited.

    Returns:
      The result of `templates.replace` for the template below.
      NOTE(review): presumably a list of statements; confirm in `templates`.
    """
    # Convert nested code first; the generated functions wrap the result.
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Loop state: symbols modified in the body but not created there.
    body_closure = body_scope.modified - body_scope.created
    # Every fresh name generated below avoids clashing with these symbols.
    all_referenced = body_scope.referenced

    state = list(body_closure)

    # One fresh name per state symbol, derived from `s.ssf()`; these become
    # the state parameters of the generated functions.
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Rename only the symbols whose fresh name differs from the original.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # With exactly one state symbol, the template placeholders receive that
    # symbol itself instead of a list / tuple of symbols.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    # Inside the generated functions the state is referenced by fresh names.
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    # An additional loop condition may be attached to the node as the
    # 'extra_cond' annotation; without one the condition is `True`.
    if anno.hasanno(node, 'extra_cond'):
      extra_cond = anno.getanno(node, 'extra_cond')
      extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
    else:
      extra_cond = parser.parse_expression('True')

    # The loop target is bound by assignment inside the body function instead
    # of being used directly as a parameter: a target such as `a, b` would
    # otherwise produce a tuple parameter, which Python 3 rejects (PEP 3113).
    template = """
      def extra_cond_name(state_ssf):
        return extra_cond_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_loop(
          iterated, extra_cond_name, body_name, (state,))
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iterated=node.iter,
        iterate=node.target,
        extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                      all_referenced),
        extra_cond_expr=extra_cond,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
예제 #4
0
  def visit_For(self, node):
    """Rewrites a `for` statement as a call to `__ops.for_loop`.

    The loop body and the extra loop condition are emitted as local functions
    generated from the template below. Symbols that the body modifies but does
    not create form the loop state: the generated functions receive them as
    parameters under fresh names, the body function returns their new values,
    and the result of `__ops.for_loop` is assigned back to the original names.

    Args:
      node: the `For` node being visited.

    Returns:
      The result of `templates.replace` for the template below.
      NOTE(review): presumably a list of statements; confirm in `templates`.
    """
    # Convert nested code first; the generated functions wrap the result.
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Loop state: symbols modified in the body but not created there.
    body_closure = body_scope.modified - body_scope.created
    # Every fresh name generated below avoids clashing with these symbols.
    all_referenced = body_scope.referenced

    state = list(body_closure)

    # One fresh name per state symbol, derived from `s.ssf()`; these become
    # the state parameters of the generated functions.
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Rename only the symbols whose fresh name differs from the original.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # With exactly one state symbol, the template placeholders receive that
    # symbol itself instead of a list / tuple of symbols.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    # Inside the generated functions the state is referenced by fresh names.
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    # An additional loop condition may be attached to the node as the
    # 'extra_cond' annotation; without one the condition is `True`.
    if anno.hasanno(node, 'extra_cond'):
      extra_cond = anno.getanno(node, 'extra_cond')
      extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
    else:
      extra_cond = parser.parse_expression('True')

    # The loop target is bound by assignment inside the body function instead
    # of being used directly as a parameter: a target such as `a, b` would
    # otherwise produce a tuple parameter, which Python 3 rejects (PEP 3113).
    template = """
      def extra_cond_name(state_ssf):
        return extra_cond_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = __ops.for_loop(
          iterated, extra_cond_name, body_name, (state,))
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iterated=node.iter,
        iterate=node.target,
        extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                      all_referenced),
        extra_cond_expr=extra_cond,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
예제 #5
0
 def _visit_and_reindent(self, nodes):
   """Visits a sequence of statements, honoring reindent requests.

   Each node is visited with `self.visit`; a visit may return a single node
   or a list/tuple of nodes, which are spliced in. A visited node annotated
   with `anno.Basic.INDENT_BLOCK_REMAINDER` carries a pair
   `(destination_list, alias_map)`: every node that follows is appended to
   that destination list instead of the current one, and has the accumulated
   aliases applied with `ast_util.rename_symbols`.
   NOTE(review): the destination list presumably lives inside the annotated
   node, so the remaining statements end up nested under it; confirm where
   the annotation is set.

   Args:
     nodes: the statements to visit, in order.

   Returns:
     The list of top-level statements (the first destination list).

   Raises:
     ValueError: if a reindent was requested but the last destination list
       is still empty after all nodes were visited.
   """
   new_nodes = []
   # Destination that visited nodes are currently appended to; starts as the
   # returned list and is switched by reindent requests.
   current_dest = new_nodes
   # Renames applied to every node visited after a reindent request.
   alias_map = {}
   reindent_requested = False
   for n in nodes:
     n = self.visit(n)
     # NOTE: the order in which these statements execute is important; in
     # particular, watch out for ending up with cycles in the AST.
     if alias_map:
       n = ast_util.rename_symbols(n, alias_map)
     # A visit may return several statements; splice them in.
     if isinstance(n, (list, tuple)):
       current_dest.extend(n)
     else:
       current_dest.append(n)
     # The node requesting the reindent is itself placed in the current
     # destination; only the nodes after it go to the new one.
     if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
       reindent_requested = True
       new_dest, new_alias_map = anno.getanno(
           n, anno.Basic.INDENT_BLOCK_REMAINDER)
       # Consume the annotation.
       anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
       # Earlier aliases win over new ones with the same key. Note this
       # mutates the dict that was stored in the (now removed) annotation.
       new_alias_map.update(alias_map)
       alias_map = new_alias_map
       current_dest = new_dest
   if reindent_requested and not current_dest:
     # TODO(mdan): There may still be something that could be done.
     raise ValueError('Unable to insert statement into the computation flow: '
                      'it is not followed by any computation which '
                      'the statement could gate.')
   return new_nodes
예제 #6
0
    def test_rename_symbols_basic(self):
        """Renaming `a` rewrites only that name and keeps `id` a plain str."""
        renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
        tree = qual_names.resolve(parser.parse_str('a + b'))

        tree = ast_util.rename_symbols(tree, renames)

        left_operand = tree.body[0].value.left
        self.assertIsInstance(left_operand.id, str)
        self.assertEqual(compiler.ast_to_source(tree).strip(), 'renamed_a + b')
예제 #7
0
  def test_rename_symbols_attributes(self):
    """A dotted name `b.c` is renamed both as a target and as a prefix."""
    tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
    renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}

    renamed_tree = ast_util.rename_symbols(tree, renames)

    self.assertEqual(
        compiler.ast_to_source(renamed_tree).strip(),
        'renamed_b_c = renamed_b_c.d')
예제 #8
0
    def test_rename_symbols_attributes(self):
        """A dotted name `b.c` is renamed both as a target and as a prefix."""
        tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
        renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}

        renamed_tree = ast_util.rename_symbols(tree, renames)
        # Only the first element of the returned pair (the source) is checked.
        generated, _ = compiler.ast_to_source(renamed_tree)

        self.assertEqual(generated.strip(), 'renamed_b_c = renamed_b_c.d')
예제 #9
0
    def test_rename_symbols_annotations(self):
        """An annotation on the root node survives renaming, by identity."""
        tree = qual_names.resolve(parser.parse_str('a[i]'))
        anno.setanno(tree, 'foo', 'bar')
        annotation_before = anno.getanno(tree, 'foo')

        renames = {qual_names.QN('a'): qual_names.QN('b')}
        tree = ast_util.rename_symbols(tree, renames)

        self.assertIs(anno.getanno(tree, 'foo'), annotation_before)
예제 #10
0
  def test_rename_symbols_basic(self):
    """Renaming `a` rewrites only that name and keeps `id` a plain str."""
    renames = {qual_names.QN('a'): qual_names.QN('renamed_a')}
    tree = qual_names.resolve(parser.parse_str('a + b'))

    tree = ast_util.rename_symbols(tree, renames)

    left_operand = tree.body[0].value.left
    self.assertIsInstance(left_operand.id, str)
    rendered = compiler.ast_to_source(tree)
    self.assertEqual(rendered.strip(), 'renamed_a + b')
예제 #11
0
  def test_rename_symbols_annotations(self):
    """An annotation on the root node survives renaming, by identity."""
    tree = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(tree, 'foo', 'bar')
    annotation_before = anno.getanno(tree, 'foo')

    renames = {qual_names.QN('a'): qual_names.QN('b')}
    tree = ast_util.rename_symbols(tree, renames)

    self.assertIs(anno.getanno(tree, 'foo'), annotation_before)
예제 #12
0
  def test_rename_symbols(self):
    """Renames plain and attribute symbols while keeping each node's ctx."""
    load_a = ast.Name('a', ast.Load())
    load_b = ast.Name('b', ast.Load())
    store_b_c = ast.Attribute(ast.Name('b', None), 'c', ast.Store())
    b_c_d = ast.Attribute(
        ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
    tree = qual_names.resolve(
        ast.Tuple([load_a, load_b, store_b_c, b_c_d], None))

    b_dot_c = qual_names.QN(qual_names.QN('b'), attr='c')
    renames = {
        qual_names.QN('a'): qual_names.QN('renamed_a'),
        b_dot_c: qual_names.QN('renamed_b_c'),
    }
    tree = ast_util.rename_symbols(tree, renames)

    elts = tree.elts
    # `a` is renamed and keeps its Load context.
    self.assertEqual(elts[0].id, 'renamed_a')
    self.assertTrue(isinstance(elts[0].ctx, ast.Load))
    # `b` alone is not in the rename map.
    self.assertEqual(elts[1].id, 'b')
    # `b.c` collapses into a single renamed name, as target and as prefix.
    self.assertEqual(elts[2].id, 'renamed_b_c')
    self.assertTrue(isinstance(elts[2].ctx, ast.Store))
    self.assertEqual(elts[3].value.id, 'renamed_b_c')
    self.assertTrue(isinstance(elts[3].value.ctx, ast.Load))
예제 #13
0
    def test_rename_symbols(self):
        """Renames plain and attribute symbols while keeping each node's ctx."""
        inner_b_c = ast.Attribute(ast.Name('b', None), 'c', ast.Load())
        elements = [
            ast.Name('a', ast.Load()),
            ast.Name('b', ast.Load()),
            ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
            ast.Attribute(inner_b_c, 'd', None),
        ]
        renames = {
            qual_names.QN('a'): qual_names.QN('renamed_a'),
            qual_names.QN(qual_names.QN('b'), attr='c'):
                qual_names.QN('renamed_b_c'),
        }

        tree = qual_names.resolve(ast.Tuple(elements, None))
        tree = ast_util.rename_symbols(tree, renames)

        elts = tree.elts
        # `a` is renamed and keeps its Load context.
        self.assertEqual(elts[0].id, 'renamed_a')
        self.assertTrue(isinstance(elts[0].ctx, ast.Load))
        # `b` alone is not in the rename map.
        self.assertEqual(elts[1].id, 'b')
        # `b.c` collapses into a single renamed name, as target and as prefix.
        self.assertEqual(elts[2].id, 'renamed_b_c')
        self.assertTrue(isinstance(elts[2].ctx, ast.Store))
        self.assertEqual(elts[3].value.id, 'renamed_b_c')
        self.assertTrue(isinstance(elts[3].value.ctx, ast.Load))
예제 #14
0
  def visit_While(self, node):
    """Rewrites a `while` statement as a call to `ag__.while_loop`.

    The loop test and body are emitted as local functions generated from the
    template below. Symbols that the body modifies but does not create form
    the loop state: the generated functions receive them as parameters under
    fresh names, the body function returns their new values, and the result
    of `ag__.while_loop` is assigned back to the original names. Symbols the
    test depends on that the body does not create are passed separately as
    `extra_deps`.

    Args:
      node: the `While` node being visited.

    Returns:
      The result of `templates.replace` for the template below.

    Raises:
      ValueError: if the body modifies no pre-existing symbols, i.e. the loop
        would have no state.
    """
    # Convert nested code first; the generated functions wrap the result.
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Loop state: symbols modified in the body but not created there.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)
    # Symbols from the `support_set` of everything the test references, minus
    # those the body creates. NOTE(review): `support_set` presumably holds the
    # base names a (possibly composite) symbol depends on; confirm.
    cond_closure = set()
    for s in cond_scope.referenced:
      for root in s.support_set:
        if root not in body_scope.created:
          cond_closure.add(root)

    state = list(body_closure)
    if not state:
      # TODO (mdan): Implement this properly. id:486
      # https://github.com/imdone/tensorflow/issues/487
      # To complete this statement, we need to check whether any variable
      # created inside the body scope is used before being modified outside the
      # scope. This should be done during activity analysis, and in general
      # should cover the case where variables may not be initialized.
      raise ValueError('cannot convert while loop: no outputs')

    # One fresh name per state symbol, derived from `s.ssf()`; these become
    # the state parameters of the generated functions.
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Rename only the symbols whose fresh name differs from the original.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # With exactly one state symbol, the template placeholders receive that
    # symbol itself instead of a list / tuple of symbols.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    # Both the body and the test refer to the state by its fresh names.
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_loop(
          test_name, body_name, (state,), (extra_deps,))
    """
    # `body_scope.referenced` here is the same set as `all_referenced`.
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=self.context.namer.new_symbol('loop_test',
                                                body_scope.referenced),
        test=test,
        body_name=self.context.namer.new_symbol('loop_body',
                                                body_scope.referenced),
        body=node_body,
        extra_deps=tuple(s.ast() for s in cond_closure),
    )

    return node
예제 #15
0
  def visit_If(self, node):
    """Rewrites an `if` statement as two branch functions and a conditional.

    Each branch is emitted as a function via `self._create_cond_branch`, and
    the selection between them via `self._create_cond_expr`. Every symbol
    modified by either branch is returned from the conditional.

    Args:
      node: the `If` node being visited.

    Returns:
      The concatenation of the true-branch definition, the false-branch
      definition and the conditional expression built by those helpers.

    Raises:
      ValueError: if the two branches do not create the same set of symbols.
    """
    # Convert nested code first; the branch functions wrap the result.
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    if body_scope.created - orelse_scope.created:
      raise ValueError(
          'The if branch creates new symbols that the else branch does not.')
    if orelse_scope.created - body_scope.created:
      raise ValueError(
          'The else branch creates new symbols that the if branch does not.')

    # Converting to a tuple fixes one order, shared by `results` and
    # `body_returns` below so the two agree.
    modified = tuple(body_scope.modified | orelse_scope.modified)
    all_referenced = body_scope.referenced | orelse_scope.referenced

    # Alias the closure variables inside the conditional functions
    # to avoid errors caused by the local variables created in the branch
    # functions.
    need_alias = (
        (body_scope.modified | orelse_scope.modified) -
        (body_scope.created | orelse_scope.created))
    aliased_orig_names = tuple(need_alias)
    aliased_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), all_referenced)
        for s in aliased_orig_names)
    # The same aliases are used in both branches.
    alias_map = dict(zip(aliased_orig_names, aliased_new_names))
    node_body = ast_util.rename_symbols(node.body, alias_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

    if not modified:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      results = None
    elif len(modified) == 1:
      results = modified[0]
    else:
      results = gast.Tuple([s.ast() for s in modified], None)

    body_name = self.context.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)
    # Each branch returns the aliased name where one exists, else the symbol.
    if modified:
      body_returns = tuple(
          alias_map[s] if s in aliased_orig_names else s for s in modified)
    else:
      body_returns = templates.replace('tf.ones(())')[0].value

    # NOTE(review): both branches receive the same `body_returns`; when nothing
    # is modified this is one AST node object used twice. Confirm that
    # `_create_cond_branch` copies it.
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=tuple(aliased_orig_names),
        aliased_new_names=tuple(aliased_new_names),
        body=node_body,
        returns=body_returns)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=tuple(aliased_orig_names),
        aliased_new_names=tuple(aliased_new_names),
        body=node_orelse,
        returns=body_returns)
    cond_expr = self._create_cond_expr(results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
예제 #16
0
    def visit_If(self, node):
        """Rewrites an `if` statement as two branch functions and a conditional.

        Each branch is emitted as a function via `self._create_cond_branch`,
        and the selection between them via `self._create_cond_expr`. The
        symbols returned from the conditional are those defined (created or
        modified) by either branch that are live after the `if`, or that have
        a live symbol in their `support_set`.

        Args:
            node: the `If` node being visited.

        Returns:
            The concatenation of the true-branch definition, the false-branch
            definition and the conditional expression built by those helpers.

        Raises:
            ValueError: if a live symbol is created by only one of the
                branches, or if a branch would have to return a symbol that it
                neither closes over nor creates.
        """
        self.generic_visit(node)

        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
        body_defs = body_scope.created | body_scope.modified
        orelse_defs = orelse_scope.created | orelse_scope.modified
        live = anno.getanno(node, 'live_out')

        # We'll need to check if we're closing over variables that are defined
        # elsewhere in the function
        # NOTE: we can only detect syntactic closure in the scope
        # of the code passed in. If the AutoGraph'd function itself closes
        # over other variables, this analysis won't take that into account.
        defined = anno.getanno(node, 'defined_in')

        # We only need to return variables that are
        # - modified by one or both branches
        # - live (or has a live parent) at the end of the conditional
        modified = []
        for def_ in body_defs | orelse_defs:
            def_with_parents = set((def_, )) | def_.support_set
            if live & def_with_parents:
                modified.append(def_)

        # We need to check if live created variables are balanced
        # in both branches
        created = live & (body_scope.created | orelse_scope.created)

        # The if statement is illegal if there are variables that are created,
        # that are also live, but both branches don't create them.
        if created:
            if created != (body_scope.created & live):
                raise ValueError(
                    'The main branch does not create all live symbols that the else '
                    'branch does.')
            if created != (orelse_scope.created & live):
                raise ValueError(
                    'The else branch does not create all live symbols that the main '
                    'branch does.')

        # Alias the closure variables inside the conditional functions
        # to avoid errors caused by the local variables created in the branch
        # functions.
        # We will alias variables independently for body and orelse scope,
        # because different branches might write different variables.
        aliased_body_orig_names = tuple(body_scope.modified -
                                        body_scope.created)
        aliased_orelse_orig_names = tuple(orelse_scope.modified -
                                          orelse_scope.created)
        aliased_body_new_names = tuple(
            self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
            for s in aliased_body_orig_names)
        aliased_orelse_new_names = tuple(
            self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
            for s in aliased_orelse_orig_names)

        alias_body_map = dict(
            zip(aliased_body_orig_names, aliased_body_new_names))
        alias_orelse_map = dict(
            zip(aliased_orelse_orig_names, aliased_orelse_new_names))

        node_body = ast_util.rename_symbols(node.body, alias_body_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

        if not modified:
            # When the cond would return no value, we leave the cond called without
            # results. That in turn should trigger the side effect guards. The
            # branch functions will return a dummy value that ensures cond
            # actually has some return value as well.
            results = None
        elif len(modified) == 1:
            results = modified[0]
        else:
            results = gast.Tuple([s.ast() for s in modified], None)

        body_name = self.context.namer.new_symbol('if_true',
                                                  body_scope.referenced)
        orelse_name = self.context.namer.new_symbol('if_false',
                                                    orelse_scope.referenced)
        if modified:

            def build_returns(aliased_names, alias_map, scope, branch_name):
                """Builds list of return variables for a branch of a conditional.

                `branch_name` ('true' or 'false') only appears in the error
                message, so that it names the branch that actually failed.
                """
                returns = []
                for s in modified:
                    if s in aliased_names:
                        returns.append(alias_map[s])
                    elif s in scope.created | defined:
                        returns.append(s)
                    else:
                        raise ValueError(
                            'Attempting to return variable "%s" from the %s branch of '
                            'a conditional, but it was not closed over, or created in '
                            'this branch.' % (str(s), branch_name))
                return tuple(returns)

            body_returns = build_returns(aliased_body_orig_names,
                                         alias_body_map, body_scope, 'true')
            orelse_returns = build_returns(aliased_orelse_orig_names,
                                           alias_orelse_map, orelse_scope,
                                           'false')

        else:
            # Build one dummy return node per branch so that the two branch
            # functions do not share a single AST node object.
            body_returns = templates.replace('tf.ones(())')[0].value
            orelse_returns = templates.replace('tf.ones(())')[0].value

        body_def = self._create_cond_branch(
            body_name,
            aliased_orig_names=tuple(aliased_body_orig_names),
            aliased_new_names=tuple(aliased_body_new_names),
            body=node_body,
            returns=body_returns)
        orelse_def = self._create_cond_branch(
            orelse_name,
            aliased_orig_names=tuple(aliased_orelse_orig_names),
            aliased_new_names=tuple(aliased_orelse_new_names),
            body=node_orelse,
            returns=orelse_returns)
        cond_expr = self._create_cond_expr(results, node.test, body_name,
                                           orelse_name)

        return body_def + orelse_def + cond_expr
예제 #17
0
  def visit_If(self, node):
    """Rewrites an `if` statement as two branch functions and a conditional.

    Each branch is emitted as a function via `self._create_cond_branch`, and
    the selection between them via `self._create_cond_expr`. A symbol
    modified in either branch is returned from the conditional when it is in
    `live_out`, or, for composite symbols, when any of its `owner_set` is.

    Args:
      node: the `If` node being visited.

    Returns:
      The concatenation of the true-branch definition, the false-branch
      definition and the conditional expression built by those helpers.

    Raises:
      ValueError: if the two branches would initialize different sets of
        returned symbols that are not defined before the `if`.
    """
    node = self.generic_visit(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    for s in modified_in_cond:
      if s in live_out:
        returned_from_cond.add(s)
      elif s.is_composite():
        # Special treatment for compound objects: if any of their owner entities
        # are live, then they are outputs as well.
        if any(owner in live_out for owner in s.owner_set):
          returned_from_cond.add(s)

    # Symbols that already exist before the `if` and are written in a branch.
    need_alias_in_body = body_scope.modified & defined_in
    need_alias_in_orelse = orelse_scope.modified & defined_in

    # `-` binds tighter than `&`, so these read as
    # `modified & (returned_from_cond - defined_in)`, which is the same set as
    # `(modified & returned_from_cond) - defined_in`: returned symbols that a
    # branch initializes.
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    if created_in_body != created_in_orelse:
      raise ValueError(
          'if statement may not initialize all variables: the true branch'
          ' creates %s, while the false branch creates %s. Make sure all'
          ' these variables are initialized either in both'
          ' branches or before the if statement.' %
          (self._fmt_symbol_list(created_in_body),
           self._fmt_symbol_list(created_in_orelse)))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # Converting to a tuple fixes one order, shared by `cond_results` and both
    # branches' return values so they line up.
    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
      if len(returned_from_cond) == 1:
        # TODO(mdan): Move this quirk into the operator implementation.
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Each branch returns its aliased name where one exists, else the symbol.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): This doesn't belong here; it's specific to the operator.
      returned_from_body = templates.replace_as_expression('1')
      returned_from_orelse = templates.replace_as_expression('1')

    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
예제 #18
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every method defined directly on `c` with `function_to_graph`
    and emits a new class definition that contains the converted methods.
    Each base class is either imported (when `is_whitelisted_for_graph`
    accepts it) or referenced by its compiled class name.

    Args:
        c: the class to convert.
        program_ctx: the conversion program context. It creates the namer and
            is updated with the names that namer allocates.

    Returns:
        A tuple `(node, class_name, class_namespace)`:
        - `node`: a `gast.Module` holding any base-class imports followed by
          the converted class definition;
        - `class_name`: the name of the converted class;
        - `class_namespace`: the merged namespaces of all converted methods.

    Raises:
        ValueError: if `c` has no member functions or methods.
    """
    converted_members = {}

    def method_filter(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    # The namespaces of all converted methods are merged into this fresh dict,
    # so no individual method's namespace is mutated.
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        class_namespace.update(namespace)
        converted_members[m] = node
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # NOTE(review): `isinstance(object, base)` holds only for `object` and
        # `type`; `base is object` may be the intent. Kept as-is to preserve
        # the existing handling of `type` bases.
        if isinstance(object, base):
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class. `ClassDef.bases` must
    # hold expression nodes, not plain strings, so wrap each name in a Name.
    bases = [
        gast.Name(id=base_name, ctx=gast.Load(), annotation=None)
        for base_name in base_names
    ]
    output_nodes.append(
        gast.ClassDef(class_name,
                      bases=bases,
                      keywords=[],
                      body=list(converted_members.values()),
                      decorator_list=[]))
    node = gast.Module(output_nodes)

    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    node = qual_names.resolve(node)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    node = ast_util.rename_symbols(node, renames)

    return node, class_name, class_namespace
예제 #19
0
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Only the methods defined directly by `c` are converted; inherited ones are
  skipped. Base classes are either imported (when whitelisted for graph
  mode) or marked for conversion themselves.

  Args:
    c: The class to convert.
    program_ctx: The program context. Provides the namer and receives the
      updated name map.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`: the list of AST
    nodes (any base class imports, followed by the converted `ClassDef`), the
    name of the converted class, and the merged namespaces of the converted
    methods.

  Raises:
    ValueError: if `c` has no member functions or methods.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  # Starts as a fresh dict, so merging never mutates the namespace dict that
  # was returned for an individual method.
  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c,
        rewrite_errors=False)
    class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: `object` is kept as-is. If the superclass is of
  # a whitelisted type, an absolute import line is generated. Otherwise, it is
  # marked for conversion (as a side effect of the call to
  # namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Identity check: `isinstance(object, base)` would also match `type`,
    # turning a metaclass base into `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
예제 #20
0
  def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    The loop test and the loop body become local functions that take the
    loop state (the symbols modified but not created in the body) as
    arguments. Inside those functions the state symbols are renamed to fresh
    names so that the functions do not close over them.

    Args:
      node: The `While` AST node; must carry the BODY_SCOPE and COND_SCOPE
        annotations.

    Returns:
      The replacement nodes produced by `templates.replace`.

    Raises:
      ValueError: if the body does not modify any pre-existing symbol.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    body_closure = body_scope.modified - body_scope.created

    cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)
    # Generated names must not shadow anything the body OR the loop test
    # refers to: the test is renamed with the same map as the body, and the
    # helper functions are emitted into the enclosing scope.
    all_referenced = body_scope.referenced | cond_scope.referenced

    # Root symbols read by the test that the body does not create; they are
    # passed to while_stmt as extra dependencies.
    cond_closure = set()
    for s in cond_scope.referenced:
      for root in s.support_set:
        if root not in body_scope.created:
          cond_closure.add(root)

    state = list(body_closure)
    if not state:
      # TODO(mdan): Implement this properly.
      # To complete this statement, we need to check whether any variable
      # created inside the body scope is used before being modified outside the
      # scope. This should be done during activity analysis, and in general
      # should cover the case where variables may not be initialized.
      raise ValueError('cannot convert while loop: no outputs')

    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_stmt(
          test_name, body_name, (state,), (extra_deps,))
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=self.context.namer.new_symbol('loop_test', all_referenced),
        test=test,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body,
        extra_deps=tuple(s.ast() for s in cond_closure),
    )

    return node
예제 #21
0
  def visit_If(self, node):
    """Converts an `if` statement into two branch functions and a cond call.

    The branch functions return the symbols that are created or modified by
    either branch and that are live after the statement (directly, or through
    a symbol in their support set). Symbols that a branch modifies without
    creating are passed into that branch's function under fresh aliases.

    Args:
      node: The `If` AST node; must carry the BODY_SCOPE and ORELSE_SCOPE
        annotations as well as 'live_out' and 'defined_in'.

    Returns:
      The body branch definition, the else branch definition and the cond
      expression, concatenated.

    Raises:
      ValueError: if live symbols created in one branch are not created in
        the other, or if a returned symbol is neither aliased, created in the
        branch, nor defined before the statement.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
    body_defs = body_scope.created | body_scope.modified
    orelse_defs = orelse_scope.created | orelse_scope.modified
    live = anno.getanno(node, 'live_out')

    # We'll need to check if we're closing over variables that are defined
    # elsewhere in the function
    # NOTE: we can only detect syntactic closure in the scope
    # of the code passed in. If the AutoGraph'd function itself closes
    # over other variables, this analysis won't take that into account.
    defined = anno.getanno(node, 'defined_in')

    # We only need to return variables that are
    # - modified by one or both branches
    # - live (or has a live parent) at the end of the conditional
    modified = []
    for def_ in body_defs | orelse_defs:
      def_with_parents = {def_} | def_.support_set
      if live & def_with_parents:
        modified.append(def_)

    # We need to check if live created variables are balanced
    # in both branches
    created = live & (body_scope.created | orelse_scope.created)

    # The if statement is illegal if there are variables that are created,
    # that are also live, but both branches don't create them.
    if created:
      if created != (body_scope.created & live):
        raise ValueError(
            'The main branch does not create all live symbols that the else '
            'branch does.')
      if created != (orelse_scope.created & live):
        raise ValueError(
            'The else branch does not create all live symbols that the main '
            'branch does.')

    # Alias the closure variables inside the conditional functions
    # to avoid errors caused by the local variables created in the branch
    # functions.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
    aliased_orelse_orig_names = tuple(orelse_scope.modified -
                                      orelse_scope.created)
    aliased_body_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    if not modified:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      results = None
    elif len(modified) == 1:
      results = modified[0]
    else:
      results = gast.Tuple([s.ast() for s in modified], None)

    # Both branch functions are defined in the enclosing scope, so each name
    # must avoid every symbol that either branch refers to.
    all_referenced = body_scope.referenced | orelse_scope.referenced
    body_name = self.context.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.context.namer.new_symbol('if_false', all_referenced)
    if modified:

      def build_returns(aliased_names, alias_map, scope, branch):
        """Builds list of return variables for a branch of a conditional.

        `branch` is only used to label the error message.
        """
        visible = scope.created | defined
        returns = []
        for s in modified:
          if s in aliased_names:
            returns.append(alias_map[s])
          elif s in visible:
            returns.append(s)
          else:
            raise ValueError(
                'Attempting to return variable "%s" from the %s branch of '
                'a conditional, but it was not closed over, or created in '
                'this branch.' % (str(s), branch))
        return tuple(returns)

      body_returns = build_returns(aliased_body_orig_names, alias_body_map,
                                   body_scope, 'true')
      orelse_returns = build_returns(aliased_orelse_orig_names,
                                     alias_orelse_map, orelse_scope, 'false')

    else:
      body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value

    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=body_returns)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=orelse_returns)
    cond_expr = self._create_cond_expr(results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
예제 #22
0
  def visit_If(self, node):
    """Lowers an `if` statement to two branch functions and a cond call.

    Both branches must create the same set of new symbols. Every symbol
    modified in either branch is returned from both branch functions, and
    the symbols that are modified but not created by the branches are passed
    into the functions under fresh aliases.

    Args:
      node: The `If` AST node, annotated with BODY_SCOPE and ORELSE_SCOPE.

    Returns:
      The true branch definition, the false branch definition and the cond
      expression, concatenated.

    Raises:
      ValueError: if one branch creates symbols that the other does not.
    """
    self.generic_visit(node)

    true_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    false_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    if true_scope.created - false_scope.created:
      raise ValueError(
          'The if branch creates new symbols that the else branch does not.')
    if false_scope.created - true_scope.created:
      raise ValueError(
          'The else branch creates new symbols that the if branch does not.')

    written = true_scope.modified | false_scope.modified
    written_syms = tuple(written)
    referenced = true_scope.referenced | false_scope.referenced

    # Closure variables get fresh names inside the branch functions, to avoid
    # errors caused by the local variables those functions create.
    closure_syms = tuple(written - (true_scope.created | false_scope.created))
    fresh_syms = tuple(
        self.context.namer.new_symbol(sym.ssf(), referenced)
        for sym in closure_syms)
    rename_map = dict(zip(closure_syms, fresh_syms))
    true_body = ast_util.rename_symbols(node.body, rename_map)
    false_body = ast_util.rename_symbols(node.orelse, rename_map)

    if not written_syms:
      # Nothing to return: cond is called without results, which should
      # trigger the side effect guards. The branch functions return a dummy
      # value so that cond still has some return value.
      results = None
    elif len(written_syms) == 1:
      results = written_syms[0]
    else:
      results = gast.Tuple([sym.ast() for sym in written_syms], None)

    true_fn = self.context.namer.new_symbol('if_true', referenced)
    false_fn = self.context.namer.new_symbol('if_false', referenced)

    if written_syms:
      # Both branches return the same symbols, using their aliases if any.
      branch_returns = tuple(rename_map.get(sym, sym) for sym in written_syms)
    else:
      branch_returns = templates.replace('tf.ones(())')[0].value

    def build_branch(fn_name, fn_body):
      return self._create_cond_branch(
          fn_name,
          aliased_orig_names=closure_syms,
          aliased_new_names=fresh_syms,
          body=fn_body,
          returns=branch_returns)

    true_def = build_branch(true_fn, true_body)
    false_def = build_branch(false_fn, false_body)
    cond_expr = self._create_cond_expr(results, node.test, true_fn, false_fn)
    return true_def + false_def + cond_expr
예제 #23
0
    def visit_If(self, node):
        """Converts an `if` statement into two branch functions and a cond call.

        The branch functions return the symbols that are modified in either
        branch and are live after the statement; a composite symbol is also
        returned when any of its owners is live. Symbols that a branch
        modifies and that are defined before the statement are passed into
        that branch's function under fresh aliases.

        Args:
            node: The `If` AST node; must carry the BODY_SCOPE and
                ORELSE_SCOPE annotations and the DEFINED_VARS_IN and
                LIVE_VARS_OUT static annotations.

        Returns:
            The body branch definition, the else branch definition and the
            cond expression, concatenated.

        Raises:
            ValueError: if the two branches would initialize different sets
                of returned variables that are not defined before the
                statement.
        """
        node = self.generic_visit(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
        defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
        live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

        modified_in_cond = body_scope.modified | orelse_scope.modified
        returned_from_cond = set()
        for s in modified_in_cond:
            if s in live_out:
                returned_from_cond.add(s)
            elif s.is_composite():
                # Special treatment for compound objects: if any of their
                # owner entities are live, then they are outputs as well.
                if any(owner in live_out for owner in s.owner_set):
                    returned_from_cond.add(s)

        need_alias_in_body = body_scope.modified & defined_in
        need_alias_in_orelse = orelse_scope.modified & defined_in

        # `-` binds tighter than `&`; either grouping yields the returned
        # symbols modified in the branch that are not defined beforehand.
        created_in_body = body_scope.modified & returned_from_cond - defined_in
        created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

        if created_in_body != created_in_orelse:
            raise ValueError(
                'if statement may not initialize all variables: the true branch'
                ' creates %s, while the false branch creates %s. Make sure all'
                ' these variables are initialized either in both'
                ' branches or before the if statement.' %
                (self._fmt_symbol_list(created_in_body),
                 self._fmt_symbol_list(created_in_orelse)))

        # Alias the closure variables inside the conditional functions, to allow
        # the functions access to the respective variables.
        # We will alias variables independently for body and orelse scope,
        # because different branches might write different variables.
        aliased_body_orig_names = tuple(need_alias_in_body)
        aliased_orelse_orig_names = tuple(need_alias_in_orelse)
        aliased_body_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
            for s in aliased_body_orig_names)
        aliased_orelse_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
            for s in aliased_orelse_orig_names)

        alias_body_map = dict(
            zip(aliased_body_orig_names, aliased_body_new_names))
        alias_orelse_map = dict(
            zip(aliased_orelse_orig_names, aliased_orelse_new_names))

        node_body = ast_util.rename_symbols(node.body, alias_body_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

        returned_from_cond = tuple(returned_from_cond)
        if returned_from_cond:
            if len(returned_from_cond) == 1:
                # TODO(mdan): Move this quirk into the operator implementation.
                cond_results = returned_from_cond[0]
            else:
                cond_results = gast.Tuple(
                    [s.ast() for s in returned_from_cond], None)

            returned_from_body = tuple(
                alias_body_map[s] if s in need_alias_in_body else s
                for s in returned_from_cond)
            returned_from_orelse = tuple(
                alias_orelse_map[s] if s in need_alias_in_orelse else s
                for s in returned_from_cond)

        else:
            # When the cond would return no value, we leave the cond called without
            # results. That in turn should trigger the side effect guards. The
            # branch functions will return a dummy value that ensures cond
            # actually has some return value as well.
            cond_results = None
            # TODO(mdan): This doesn't belong here; it's specific to the operator.
            returned_from_body = templates.replace_as_expression(
                'tf.constant(1)')
            returned_from_orelse = templates.replace_as_expression(
                'tf.constant(1)')

        # Both branch functions are defined in the enclosing scope, so each
        # name must avoid every symbol that either branch refers to.
        all_referenced = body_scope.referenced | orelse_scope.referenced
        body_name = self.ctx.namer.new_symbol('if_true', all_referenced)
        orelse_name = self.ctx.namer.new_symbol('if_false', all_referenced)

        body_def = self._create_cond_branch(
            body_name,
            aliased_orig_names=aliased_body_orig_names,
            aliased_new_names=aliased_body_new_names,
            body=node_body,
            returns=returned_from_body)
        orelse_def = self._create_cond_branch(
            orelse_name,
            aliased_orig_names=aliased_orelse_orig_names,
            aliased_new_names=aliased_orelse_new_names,
            body=node_orelse,
            returns=returned_from_orelse)
        cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                           orelse_name)

        return body_def + orelse_def + cond_expr