Exemplo n.º 1
0
    def visit_Subscript(self, node):
        """Rewrite a subscript read `target[key]` as an `ag__.get_item` call.

        Subscripts whose slice is not a `gast.Index`, or whose context is not
        `gast.Load`, are returned as produced by `generic_visit`.
        """
        node = self.generic_visit(node)

        # Index writes are handled at a higher level, one at which the rvalue
        # is also available, so only index reads are rewritten here.
        is_index_read = (isinstance(node.slice, gast.Index) and
                         isinstance(node.ctx, gast.Load))
        if not is_index_read:
            return node

        # The element dtype comes from a set_element_type directive on the
        # subscripted value; without one, a literal `None` expression is used.
        element_dtype = self.get_definition_directive(
            node.value,
            directives.set_element_type,
            'dtype',
            default=templates.replace_as_expression('None'))

        # NOTE(review): the gast.Index wrapper itself is passed as the key;
        # confirm whether node.slice.value is what the template should receive.
        get_item_template = """
      ag__.get_item(
          target,
          key,
          opts=ag__.GetItemOpts(element_dtype=dtype))
    """
        return templates.replace_as_expression(
            get_item_template,
            target=node.value,
            key=node.slice,
            dtype=element_dtype)
Exemplo n.º 2
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Build the `ag__.list_pop` statement replacing a `target.pop()` call.

    Args:
      original_call_node: the pop call node; its `func` must be a
        `gast.Attribute`.
      pop_var_name: name bound to the popped value in the generated code.

    Returns:
      Whatever `templates.replace` produces for the list_pop template.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    # An explicit pop index is forwarded; otherwise a literal `None` is used.
    call_args = original_call_node.args
    pop_element = (
        call_args[0] if call_args else parser.parse_expression('None'))

    # The call will be something like "target.pop()", and the annotations are
    # hooked to target, hence the func.value.
    list_node = original_call_node.func.value
    element_dtype = anno.getanno(
        list_node,
        'element_type',
        default=templates.replace_as_expression('None'))
    element_shape = anno.getanno(
        list_node,
        'element_shape',
        default=templates.replace_as_expression('None'))

    pop_template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        pop_template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=element_dtype,
        shape=element_shape)
Exemplo n.º 3
0
 def test_replace_as_expression_restrictions(self):
   """A template holding two statements must raise ValueError."""
   two_statements = """
     foo(a)
     bar(b)
   """
   self.assertRaises(ValueError, templates.replace_as_expression,
                     two_statements)
Exemplo n.º 4
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Build the `ag__.list_pop` statement replacing a `target.pop()` call.

    Args:
      original_call_node: the pop call node; its `func` must be a
        `gast.Attribute`.
      pop_var_name: name bound to the popped value in the generated code.

    Returns:
      Whatever `templates.replace` produces for the list_pop template.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    # An explicit pop index is forwarded; otherwise a literal `None` is used.
    call_args = original_call_node.args
    pop_element = (
        call_args[0] if call_args else parser.parse_expression('None'))

    # The call will be something like "target.pop()", and the annotations are
    # hooked to target, hence the func.value.
    list_node = original_call_node.func.value
    element_dtype = anno.getanno(
        list_node,
        'element_type',
        default=templates.replace_as_expression('None'))
    element_shape = anno.getanno(
        list_node,
        'element_shape',
        default=templates.replace_as_expression('None'))

    pop_template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        pop_template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=element_dtype,
        shape=element_shape)
Exemplo n.º 5
0
 def test_replace_as_expression_restrictions(self):
     """A template holding two statements must raise ValueError."""
     two_statements = """
   foo(a)
   bar(b)
 """
     self.assertRaises(ValueError, templates.replace_as_expression,
                       two_statements)
Exemplo n.º 6
0
    def visit_Subscript(self, node):
        """Rewrite a subscript read `target[key]` as an `ag__.get_item` call.

        Subscripts whose context is not `gast.Load` are returned as produced
        by `generic_visit`.

        Raises:
          NotImplementedError: if the slice is not a `gast.Index`.
        """
        node = self.generic_visit(node)
        if not isinstance(node.slice, gast.Index):
            # TODO(mdan): It might make more sense to wave them through.
            raise NotImplementedError('non-index slice')

        is_read = isinstance(node.ctx, gast.Load)
        if not is_read:
            # Index writes are handled at a higher level, one at which the
            # rvalue is also available.
            return node

        # Fall back to a literal `None` expression when the subscripted value
        # carries no 'element_type' annotation.
        element_dtype = anno.getanno(
            node.value,
            'element_type',
            default=templates.replace_as_expression('None'))

        # NOTE(review): the gast.Index wrapper itself is passed as the key;
        # confirm whether node.slice.value is what the template should receive.
        get_item_template = """
      ag__.get_item(
          target,
          key,
          opts=ag__.GetItemOpts(element_dtype=dtype))
    """
        return templates.replace_as_expression(
            get_item_template,
            target=node.value,
            key=node.slice,
            dtype=element_dtype)
Exemplo n.º 7
0
  def visit_For(self, node):
    """Visit a `for` loop and, if its body used break, guard it with a flag.

    When `_track_body` reports that break was used, the loop is replaced by a
    flag initialization followed by the loop, and the loop node is annotated
    with an 'extra_test' expression (`not <flag>`).
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.context.namer.new_symbol(
        'break__', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._track_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    node.orelse = self._guard_if_present(node.orelse, break_var)
    init_and_loop_template = """
        var_name = False
        for_stmt
      """
    new_nodes = templates.replace(
        init_and_loop_template,
        var_name=break_var,
        for_stmt=node)
    loop_test = templates.replace_as_expression(
        'not var_name', var_name=break_var)
    # new_nodes[0] is the flag initialization, new_nodes[1] the loop itself.
    anno.setanno(new_nodes[1], 'extra_test', loop_test)
    return new_nodes
Exemplo n.º 8
0
 def _replace_stack_call(self, node):
     """Replace a one-argument stack call with an `ag__.list_stack` call.

     The original callee (`node.func`) is forwarded as `original_call`, and
     the argument's 'element_type' annotation (or `None`) as the dtype.
     """
     assert len(node.args) == 1
     stacked_list = node.args[0]
     element_dtype = anno.getanno(
         stacked_list,
         'element_type',
         default=templates.replace_as_expression('None'))
     list_stack_template = """
   ag__.list_stack(
       target,
       opts=ag__.ListStackOpts(
           element_dtype=dtype,
           original_call=orig_call))
 """
     return templates.replace_as_expression(
         list_stack_template,
         dtype=element_dtype,
         target=stacked_list,
         orig_call=node.func)
Exemplo n.º 9
0
 def visit_IfExp(self, node):
   """Desugar a conditional expression into an `ag__.utils.run_cond` call.

   Each branch is wrapped in a lambda returning a one-element tuple.
   """
   return templates.replace_as_expression(
       """
       ag__.utils.run_cond(test, lambda: (body,), lambda: (orelse,))
   """,
       test=node.test,
       body=node.body,
       orelse=node.orelse)
 def _as_function(self, func_name, args):
   """Build the call `func_name(args)` and mark it SAFE_BOOLEAN_OPERAND.

   Args:
     func_name: source text of the callee, parsed as an expression.
     args: replacement for the `args` placeholder in the call.

   Returns:
     The generated call expression, annotated with SAFE_BOOLEAN_OPERAND=True.
   """
   callee = parser.parse_expression(func_name)
   call_template = """
     func_name(args)
   """
   call_node = templates.replace_as_expression(
       call_template, func_name=callee, args=args)
   anno.setanno(call_node, SAFE_BOOLEAN_OPERAND, True)
   return call_node
Exemplo n.º 11
0
 def _as_function(self, func_name, args):
     """Build the call `func_name(args)` and mark it SAFE_BOOLEAN_OPERAND.

     Args:
       func_name: source text of the callee, parsed as an expression.
       args: replacement for the `args` placeholder in the call.

     Returns:
       The generated call expression, annotated with
       SAFE_BOOLEAN_OPERAND=True.
     """
     callee = parser.parse_expression(func_name)
     call_template = """
   func_name(args)
 """
     call_node = templates.replace_as_expression(
         call_template, func_name=callee, args=args)
     anno.setanno(call_node, SAFE_BOOLEAN_OPERAND, True)
     return call_node
Exemplo n.º 12
0
 def _replace_stack_call(self, node):
   """Replace a one-argument stack call with an `ag__.list_stack` call.

   The original callee (`node.func`) is forwarded as `original_call`, and the
   argument's 'element_type' annotation (or `None`) as the dtype.
   """
   assert len(node.args) == 1
   stacked_list = node.args[0]
   element_dtype = anno.getanno(
       stacked_list,
       'element_type',
       default=templates.replace_as_expression('None'))
   list_stack_template = """
     ag__.list_stack(
         target,
         opts=ag__.ListStackOpts(
             element_dtype=dtype,
             original_call=orig_call))
   """
   return templates.replace_as_expression(
       list_stack_template,
       dtype=element_dtype,
       target=stacked_list,
       orig_call=node.func)
Exemplo n.º 13
0
 def _convert_builtin(self, f, args, as_expression):
   """Build a call to the `ag__` overload of builtin `f`.

   Args:
     f: the builtin whose overload (via `py_builtins.overload_of`) is called.
     args: replacement for the `args` placeholder in the call.
     as_expression: if true, expand with `templates.replace_as_expression`;
       otherwise with `templates.replace`.

   Returns:
     The result of the chosen template expansion.
   """
   builtin_call_template = """
     ag__.func(args)
   """
   if as_expression:
     expand = templates.replace_as_expression
   else:
     expand = templates.replace
   overload_name = py_builtins.overload_of(f).__name__
   return expand(builtin_call_template, func=overload_name, args=args)
 def _convert_builtin(self, f, args, as_expression):
     """Build a call to the `ag__` overload of builtin `f`.

     Args:
       f: the builtin whose overload (via `py_builtins.overload_of`) is
         called.
       args: replacement for the `args` placeholder in the call.
       as_expression: if true, expand with `templates.replace_as_expression`;
         otherwise with `templates.replace`.

     Returns:
       The result of the chosen template expansion.
     """
     builtin_call_template = """
   ag__.func(args)
 """
     if as_expression:
         expand = templates.replace_as_expression
     else:
         expand = templates.replace
     overload_name = py_builtins.overload_of(f).__name__
     return expand(builtin_call_template, func=overload_name, args=args)
Exemplo n.º 15
0
 def _wrap_to_py_func_single_return(self, node, dtype):
   """Wrap a call node in `ag__.utils.wrap_py_func`.

   Args:
     node: the call node; its func, args and keywords are forwarded.
     dtype: source text of the dtype expression.

   Returns:
     The generated `wrap_py_func` call expression.
   """
   # TODO(mdan): Properly handle varargs, etc.
   dtype_expr = parser.parse_expression(dtype)
   kwargs_dict = ast_util.keywords_to_dict(node.keywords)
   wrap_template = """
     ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
   """
   return templates.replace_as_expression(
       wrap_template,
       func=node.func,
       dtype=dtype_expr,
       args=node.args,
       kwargs=kwargs_dict)
 def _wrap_to_py_func_single_return(self, node, dtype):
     """Wrap a call node in `autograph_utils.wrap_py_func`.

     Args:
       node: the call node; its func, args and keywords are forwarded.
       dtype: source text of the dtype expression.

     Returns:
       The generated `wrap_py_func` call expression.
     """
     # TODO(mdan): Properly handle varargs, etc.
     dtype_expr = parser.parse_expression(dtype)
     kwargs_dict = ast_util.keywords_to_dict(node.keywords)
     wrap_template = """
   autograph_utils.wrap_py_func(func, dtype, (args,), kwargs, False)
 """
     return templates.replace_as_expression(
         wrap_template,
         func=node.func,
         dtype=dtype_expr,
         args=node.args,
         kwargs=kwargs_dict)
Exemplo n.º 17
0
 def _wrap_to_py_func_single_return(self, node, dtype):
     """Wrap a call node in `ag__.utils.wrap_py_func`.

     Args:
       node: the call node; its func, args and keywords are forwarded.
       dtype: source text of the dtype expression.

     Returns:
       The generated `wrap_py_func` call expression.
     """
     # TODO (mdan): Properly handle varargs, etc. id:492
     # https://github.com/imdone/tensorflow/issues/493
     dtype_expr = parser.parse_expression(dtype)
     kwargs_dict = ast_util.keywords_to_dict(node.keywords)
     wrap_template = """
   ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
 """
     return templates.replace_as_expression(
         wrap_template,
         func=node.func,
         dtype=dtype_expr,
         args=node.args,
         kwargs=kwargs_dict)
  def visit_IfExp(self, node):
    """Convert a conditional expression into an `ag__.utils.run_cond` call.

    Each branch is passed to `_create_branch`, and the names it returns are
    used as the true/false callables of `run_cond`.
    """
    # Derive branch names from the test's qualified name when it has one.
    name_root = 'ifexp'
    if anno.hasanno(node.test, anno.Basic.QN):
      name_root = anno.getanno(node.test, anno.Basic.QN).ssf()

    true_branch = self._create_branch(node.body, '%s_true' % name_root)
    false_branch = self._create_branch(node.orelse, '%s_false' % name_root)

    return templates.replace_as_expression(
        'ag__.utils.run_cond(test, true_fn_name, false_fn_name)',
        test=node.test,
        true_fn_name=true_branch,
        false_fn_name=false_branch)
Exemplo n.º 19
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Build the `ag__.list_pop` statement replacing a `target.pop()` call.

    Args:
      original_call_node: the pop call node; its `func` must be a
        `gast.Attribute`.
      pop_var_name: name bound to the popped value in the generated code.

    Returns:
      Whatever `templates.replace` produces for the list_pop template.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    # An explicit pop index is forwarded; otherwise a literal `None` is used.
    call_args = original_call_node.args
    pop_element = (
        call_args[0] if call_args else parser.parse_expression('None'))

    # The call will be something like "target.pop()", and the dtype is hooked to
    # target, hence the func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    list_node = original_call_node.func.value
    element_dtype = self.get_definition_directive(
        list_node,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))
    element_shape = self.get_definition_directive(
        list_node,
        directives.set_element_type,
        'shape',
        default=templates.replace_as_expression('None'))

    pop_template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        pop_template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=element_dtype,
        shape=element_shape)
Exemplo n.º 20
0
  def visit_Subscript(self, node):
    """Rewrite a subscript read `target[key]` as an `ag__.get_item` call.

    Subscripts whose slice is not a `gast.Index`, or whose context is not
    `gast.Load`, are returned as produced by `generic_visit`.
    """
    node = self.generic_visit(node)

    # Index writes are handled at a higher level, one at which the rvalue is
    # also available, so only index reads are rewritten here.
    is_index_read = (isinstance(node.slice, gast.Index) and
                     isinstance(node.ctx, gast.Load))
    if not is_index_read:
      return node

    # The element dtype comes from a set_element_type directive on the
    # subscripted value; without one, a literal `None` expression is used.
    element_dtype = self.get_definition_directive(
        node.value,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))

    get_item_template = """
      ag__.get_item(
          target,
          key,
          opts=ag__.GetItemOpts(element_dtype=dtype))
    """
    # The key is the expression wrapped by the Index node, not the wrapper.
    return templates.replace_as_expression(
        get_item_template,
        target=node.value,
        key=node.slice.value,
        dtype=element_dtype)
Exemplo n.º 21
0
    def _generate_pop_operation(self, original_call_node, pop_var_name):
        """Build the `ag__.list_pop` statement replacing a `target.pop()` call.

        Args:
          original_call_node: the pop call node; its `func` must be a
            `gast.Attribute`.
          pop_var_name: name bound to the popped value in the generated code.

        Returns:
          Whatever `templates.replace` produces for the list_pop template.
        """
        assert isinstance(original_call_node.func, gast.Attribute)

        # An explicit pop index is forwarded; otherwise a literal `None` is
        # used.
        call_args = original_call_node.args
        pop_element = (
            call_args[0] if call_args else parser.parse_expression('None'))

        # The call will be something like "target.pop()", and the dtype is hooked to
        # target, hence the func.value.
        # TODO(mdan): For lists of lists, this won't work.
        # The reason why it won't work is because it's unclear how to annotate
        # the list as a "list of lists with a certain element type" when using
        # operations like `l.pop().pop()`.
        list_node = original_call_node.func.value
        element_dtype = self.get_definition_directive(
            list_node,
            directives.set_element_type,
            'dtype',
            default=templates.replace_as_expression('None'))
        element_shape = self.get_definition_directive(
            list_node,
            directives.set_element_type,
            'shape',
            default=templates.replace_as_expression('None'))

        pop_template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
        return templates.replace(
            pop_template,
            target=list_node,
            pop_var_name=pop_var_name,
            element=pop_element,
            dtype=element_dtype,
            shape=element_shape)
Exemplo n.º 22
0
  def visit_Subscript(self, node):
    """Rewrite a subscript read `target[key]` as an `ag__.get_item` call.

    Subscripts whose context is not `gast.Load` are returned as produced by
    `generic_visit`.

    Raises:
      NotImplementedError: if the slice is not a `gast.Index`.
    """
    node = self.generic_visit(node)
    if not isinstance(node.slice, gast.Index):
      # TODO(mdan): It might make more sense to wave them through.
      raise NotImplementedError('non-index slice')

    is_read = isinstance(node.ctx, gast.Load)
    if not is_read:
      # Index writes are handled at a higher level, one at which the rvalue is
      # also available.
      return node

    # Fall back to a literal `None` expression when the subscripted value
    # carries no 'element_type' annotation.
    element_dtype = anno.getanno(
        node.value,
        'element_type',
        default=templates.replace_as_expression('None'))

    # NOTE(review): the gast.Index wrapper itself is passed as the key; confirm
    # whether node.slice.value is what the template should receive.
    get_item_template = """
      ag__.get_item(
          target,
          key,
          opts=ag__.GetItemOpts(element_dtype=dtype))
    """
    return templates.replace_as_expression(
        get_item_template,
        target=node.value,
        key=node.slice,
        dtype=element_dtype)
    def visit_IfExp(self, node):
        """Convert a conditional expression into an `ag__.utils.run_cond` call.

        Each branch is passed to `_create_branch`, and the names it returns
        are used as the true/false callables of `run_cond`.
        """
        # Derive branch names from the test's qualified name when it has one.
        name_root = 'ifexp'
        if anno.hasanno(node.test, anno.Basic.QN):
            name_root = anno.getanno(node.test, anno.Basic.QN).ssf()

        true_branch = self._create_branch(node.body, '%s_true' % name_root)
        false_branch = self._create_branch(node.orelse,
                                           '%s_false' % name_root)

        return templates.replace_as_expression(
            'ag__.utils.run_cond(test, true_fn_name, false_fn_name)',
            test=node.test,
            true_fn_name=true_branch,
            false_fn_name=false_branch)
Exemplo n.º 24
0
  def _empty_list(self, node):
    """Build a `tf.TensorArray` expression for a node annotated with a dtype.

    Args:
      node: the node whose 'element_type' annotation selects the dtype.

    Returns:
      The expression produced by `templates.replace_as_expression` for an
      empty, dynamically sized `tf.TensorArray` of the annotated dtype.

    Raises:
      NotImplementedError: if the node has no 'element_type' annotation, or if
        the annotation is not a `dtypes.DType`.
    """
    if not anno.hasanno(node, 'element_type'):
      raise NotImplementedError(
          'type inference for empty lists is not yet supported; '
          'use set_element_type(<list>, <dtype>) to continue')
    dtype = anno.getanno(node, 'element_type')
    if not isinstance(dtype, dtypes.DType):
      # TODO(mdan): Allow non-TF dtypes?
      # That would be consistent with the dynamic dispatch pattern, but
      # we must make sure that doesn't become confusing.
      # Format through a 1-tuple: a tuple-valued annotation would otherwise be
      # unpacked as the format arguments and raise TypeError (or print only
      # its element) instead of reporting the unsupported type.
      raise NotImplementedError(
          'element type "%s" not yet supported' % (dtype,))

    dtype_name = dtype.name
    # TODO(mdan): Does it ever make sense not to use tensor lists?
    template = """
      tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True)
    """
    return templates.replace_as_expression(template, dtype_name=dtype_name)
Exemplo n.º 25
0
    def _empty_list(self, node):
        """Build a `tf.TensorArray` expression for a node annotated with a dtype.

        Args:
          node: the node whose 'element_type' annotation selects the dtype.

        Returns:
          The expression produced by `templates.replace_as_expression` for an
          empty, dynamically sized `tf.TensorArray` of the annotated dtype.

        Raises:
          NotImplementedError: if the node has no 'element_type' annotation,
            or if the annotation is not a `dtypes.DType`.
        """
        if not anno.hasanno(node, 'element_type'):
            raise NotImplementedError(
                'type inference for empty lists is not yet supported; '
                'use utils.set_element_type(<list>, <dtype>) to continue')
        dtype = anno.getanno(node, 'element_type')
        if not isinstance(dtype, dtypes.DType):
            # TODO(mdan): Allow non-TF dtypes?
            # That would be consistent with the dynamic dispatch pattern, but
            # we must make sure that doesn't become confusing.
            # Format through a 1-tuple: a tuple-valued annotation would
            # otherwise be unpacked as the format arguments and raise
            # TypeError (or print only its element) instead of reporting the
            # unsupported type.
            raise NotImplementedError('element type "%s" not yet supported' %
                                      (dtype,))

        dtype_name = dtype.name
        # TODO(mdan): Does it ever make sense not to use tensor lists?
        template = """
      tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True)
    """
        return templates.replace_as_expression(template, dtype_name=dtype_name)
Exemplo n.º 26
0
    def visit_For(self, node):
        """Visit a `for` loop, tracking break usage in its body.

        A `[used, name]` entry is pushed on `self.break_uses` while the body
        is visited. If the entry's flag is set afterwards, the loop gets an
        'extra_cond' annotation (`not <name>`) and is returned preceded by
        the node from `_create_break_init`; otherwise the loop is returned
        alone.
        """
        self.generic_visit(node.target)
        self.generic_visit(node.iter)

        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.context.namer.new_symbol('break_requested',
                                                  body_scope.referenced)

        self.break_uses.append([False, break_var])
        node.body = self._manual_visit_list(node.body)
        result = node
        if self.break_uses[-1][0]:
            loop_cond = templates.replace_as_expression('not var_name',
                                                        var_name=break_var)
            anno.setanno(node, 'extra_cond', loop_cond)
            # The initializer is created while the entry is still on the stack.
            result = [self._create_break_init(), node]
        self.break_uses.pop()

        # The else clause is visited after the entry is popped.
        for stmt in node.orelse:
            self.generic_visit(stmt)
        return result
Exemplo n.º 27
0
  def visit_For(self, node):
    """Visit a `for` loop, tracking break usage in its body.

    A `[used, name]` entry is pushed on `self.break_uses` while the body is
    visited. If the entry's flag is set afterwards, the loop gets an
    'extra_cond' annotation (`not <name>`) and is returned preceded by the
    node from `_create_break_init`; otherwise the loop is returned alone.
    """
    self.generic_visit(node.target)
    self.generic_visit(node.iter)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.context.namer.new_symbol('break_requested',
                                              body_scope.referenced)

    self.break_uses.append([False, break_var])
    node.body = self._manual_visit_list(node.body)
    result = node
    if self.break_uses[-1][0]:
      loop_cond = templates.replace_as_expression(
          'not var_name', var_name=break_var)
      anno.setanno(node, 'extra_cond', loop_cond)
      # The initializer is created while the entry is still on the stack.
      result = [self._create_break_init(), node]
    self.break_uses.pop()

    # The else clause is visited after the entry is popped.
    for stmt in node.orelse:
      self.generic_visit(stmt)
    return result
Exemplo n.º 28
0
  def visit_For(self, node):
    """Visit a `for` loop and, if its body used break, rebuild it with a flag.

    When `_track_body` reports that break was used, the loop is regenerated
    from a template that initializes the flag, references it inside the loop
    body and guards the else clause; the new loop node is annotated with an
    'extra_test' expression (`not <flag>`).
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._track_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    loop_test = templates.replace_as_expression(
        'not var_name', var_name=break_var)

    # The extra test is hidden in the AST, which will confuse the static
    # analysis. To mitigate that, we insert a no-op statement that ensures
    # the control variable is marked as used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    guarded_loop_template = """
        var_name = False
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
    new_nodes = templates.replace(
        guarded_loop_template,
        var_name=break_var,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=guarded_orelse)

    # new_nodes[0] is the flag initialization, new_nodes[1] the loop itself.
    anno.setanno(new_nodes[1], 'extra_test', loop_test)
    return new_nodes
Exemplo n.º 29
0
    def visit_For(self, node):
        """Visit a `for` loop, tracking break usage in its body.

        A `[used, name]` entry is pushed on `self.break_uses` while the loop
        is visited. If the entry's flag is set afterwards, the loop is
        replaced by a flag initialization followed by the loop, and the loop
        is annotated with an 'extra_cond' expression (`not <name>`).
        """
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.context.namer.new_symbol('break_requested',
                                                  body_scope.referenced)

        self.break_uses.append([False, break_var])
        node = self.generic_visit(node)
        break_seen = self.break_uses[-1][0]
        if break_seen:
            init_and_loop_template = """
        var_name = False
        original_for
      """
            node = templates.replace(
                init_and_loop_template,
                var_name=break_var,
                original_for=node)
            loop_cond = templates.replace_as_expression(
                'not var_name', var_name=break_var)
            # node[0] is the flag initialization, node[1] the loop itself.
            anno.setanno(node[1], 'extra_cond', loop_cond)
        self.break_uses.pop()

        return node
Exemplo n.º 30
0
    def visit_For(self, node):
        """Visit a `for` loop and, if its body used break, rebuild it with a flag.

        When `_process_body` reports that break was used, the loop is
        regenerated from a template that initializes the flag with
        `tf.constant(False)`, references it inside the loop body and guards
        the else clause; the new loop node is annotated with an 'extra_test'
        expression (`not <flag>`).
        """
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

        node.target = self.visit(node.target)
        node.iter = self.visit(node.iter)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break in the else clause applies to the containing scope.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            return node

        # Python's else clause only triggers if the loop exited cleanly (e.g.
        # break did not trigger).
        guarded_orelse = self._guard_if_present(node.orelse, break_var)
        loop_test = templates.replace_as_expression('not var_name',
                                                    var_name=break_var)

        # The extra code is hidden in the AST, which will confuse the static
        # analysis. To mitigate that, we insert a no-op statement that ensures
        # the control variable is marked as used.
        # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
        guarded_loop_template = """
        var_name = tf.constant(False)
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
        new_nodes = templates.replace(
            guarded_loop_template,
            var_name=break_var,
            iter_=node.iter,
            target=node.target,
            body=node.body,
            orelse=guarded_orelse)

        # new_nodes[0] is the flag initialization, new_nodes[1] the loop.
        anno.setanno(new_nodes[1], 'extra_test', loop_test)
        return new_nodes
Exemplo n.º 31
0
    def _replace_pop_call(self, node):
        """Swap a `target.pop()` call for a fresh variable name and record it.

        The (call node, variable name) pair is appended to the POP_USES local
        state; the returned expression is a reference to the new variable.
        """
        # Expressions that use pop() are converted to a statement + expression.
        #
        # For example:
        #
        #   print(target.pop())
        #
        # ... is converted to:
        #
        #   target, target_pop = ag__.list_pop(target)
        #   print(target_pop)
        #
        # Here, we just generate the variable name and swap it in,
        # and _generate_pop_operation will handle the rest.
        #
        # Multiple uses of pop() are allowed:
        #
        #   print(target.pop(), target.pop())
        #   print(target.pop().pop())
        #
        assert isinstance(node.func, gast.Attribute)
        args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
        popped_from = node.func.value

        # Attempt to use a related name if one exists. Otherwise use something
        # generic.
        base_name = 'list_'
        if anno.hasanno(popped_from, anno.Basic.QN):
            base_name = anno.getanno(popped_from, anno.Basic.QN).ssf()
        pop_var_name = self.ctx.namer.new_symbol(base_name,
                                                 args_scope.referenced)

        recorded_pops = self.get_local(POP_USES, [])
        recorded_pops.append((node, pop_var_name))
        self.set_local(POP_USES, recorded_pops)

        return templates.replace_as_expression('var_name',
                                               var_name=pop_var_name)
Exemplo n.º 32
0
  def visit_For(self, node):
    """Visit a `for` loop, tracking break usage in its body.

    A `[used, name]` entry is pushed on `self.break_uses` while the loop is
    visited. If the entry's flag is set afterwards, the loop is replaced by a
    flag initialization followed by the loop, and the loop is annotated with
    an 'extra_cond' expression (`not <name>`).
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.context.namer.new_symbol('break_requested',
                                              body_scope.referenced)

    self.break_uses.append([False, break_var])
    node = self.generic_visit(node)
    break_seen = self.break_uses[-1][0]
    if break_seen:
      init_and_loop_template = """
        var_name = False
        original_for
      """
      node = templates.replace(
          init_and_loop_template,
          var_name=break_var,
          original_for=node)
      loop_cond = templates.replace_as_expression(
          'not var_name', var_name=break_var)
      # node[0] is the flag initialization, node[1] the loop itself.
      anno.setanno(node[1], 'extra_cond', loop_cond)
    self.break_uses.pop()

    return node
Exemplo n.º 33
0
  def _replace_pop_call(self, node):
    """Swap a `target.pop()` call for a fresh variable name and record it.

    The (call node, variable name) pair is appended to the POP_USES local
    state; the returned expression is a reference to the new variable.
    """
    # Expressions that use pop() are converted to a statement + expression.
    #
    # For example:
    #
    #   print(target.pop())
    #
    # ... is converted to:
    #
    #   target, target_pop = ag__.list_pop(target)
    #   print(target_pop)
    #
    # Here, we just generate the variable name and swap it in,
    # and _generate_pop_operation will handle the rest.
    #
    # Multiple uses of pop() are allowed:
    #
    #   print(target.pop(), target.pop())
    #   print(target.pop().pop())
    #
    assert isinstance(node.func, gast.Attribute)
    args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    popped_from = node.func.value

    # Attempt to use a related name if one exists. Otherwise use something
    # generic.
    base_name = 'list_'
    if anno.hasanno(popped_from, anno.Basic.QN):
      base_name = anno.getanno(popped_from, anno.Basic.QN).ssf()
    pop_var_name = self.ctx.namer.new_symbol(base_name, args_scope.referenced)

    recorded_pops = self.get_local(POP_USES, [])
    recorded_pops.append((node, pop_var_name))
    self.set_local(POP_USES, recorded_pops)

    return templates.replace_as_expression('var_name', var_name=pop_var_name)
Exemplo n.º 34
0
    def visit_If(self, node):
        """Convert an `if` statement into two branch definitions plus a cond.

        The symbols modified by either branch that are live after the
        statement (or are composites with a live owner) become the outputs of
        the conditional. Symbols modified in a branch and already defined
        before the statement are renamed inside that branch.

        Returns:
          The concatenation of the results of `_create_cond_branch` for the
          body, `_create_cond_branch` for the orelse, and `_create_cond_expr`.

        Raises:
          ValueError: if the two branches do not create the same set of new
            output variables.
        """
        node = self.generic_visit(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
        defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
        live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

        # Collect the symbols that the conditional has to return.
        modified_in_cond = body_scope.modified | orelse_scope.modified
        returned_from_cond = set()
        for s in modified_in_cond:
            if s in live_out:
                returned_from_cond.add(s)
            elif s.is_composite():
                # Special treatment for compound objects: if any of their owner entities
                # are live, then they are outputs as well.
                if any(owner in live_out for owner in s.owner_set):
                    returned_from_cond.add(s)

        need_alias_in_body = body_scope.modified & defined_in
        need_alias_in_orelse = orelse_scope.modified & defined_in

        # `-` binds tighter than `&`, so these read as
        # modified & (returned_from_cond - defined_in): returned symbols that
        # the branch modifies and that were not defined before the statement.
        created_in_body = body_scope.modified & returned_from_cond - defined_in
        created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

        if created_in_body != created_in_orelse:
            raise ValueError(
                'if statement may not initialize all variables: the true branch'
                ' creates %s, while the false branch creates %s. Make sure all'
                ' these variables are initialized either in both'
                ' branches or before the if statement.' %
                (self._fmt_symbol_list(created_in_body),
                 self._fmt_symbol_list(created_in_orelse)))

        # Alias the closure variables inside the conditional functions, to allow
        # the functions access to the respective variables.
        # We will alias variables independently for body and orelse scope,
        # because different branches might write different variables.
        aliased_body_orig_names = tuple(need_alias_in_body)
        aliased_orelse_orig_names = tuple(need_alias_in_orelse)
        aliased_body_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
            for s in aliased_body_orig_names)
        aliased_orelse_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
            for s in aliased_orelse_orig_names)

        # Original symbol -> alias, per branch.
        alias_body_map = dict(
            zip(aliased_body_orig_names, aliased_body_new_names))
        alias_orelse_map = dict(
            zip(aliased_orelse_orig_names, aliased_orelse_new_names))

        node_body = ast_util.rename_symbols(node.body, alias_body_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

        # Freeze the iteration order so that the cond results and both
        # branches' return values are built from the same sequence.
        returned_from_cond = tuple(returned_from_cond)
        if returned_from_cond:
            if len(returned_from_cond) == 1:
                # TODO(mdan): Move this quirk into the operator implementation.
                cond_results = returned_from_cond[0]
            else:
                cond_results = gast.Tuple(
                    [s.ast() for s in returned_from_cond], None)

            # Each branch returns the alias where the symbol was renamed in
            # that branch, and the original symbol otherwise.
            returned_from_body = tuple(
                alias_body_map[s] if s in need_alias_in_body else s
                for s in returned_from_cond)
            returned_from_orelse = tuple(
                alias_orelse_map[s] if s in need_alias_in_orelse else s
                for s in returned_from_cond)

        else:
            # When the cond would return no value, we leave the cond called without
            # results. That in turn should trigger the side effect guards. The
            # branch functions will return a dummy value that ensures cond
            # actually has some return value as well.
            cond_results = None
            # TODO(mdan): This doesn't belong here; it's specific to the operator.
            returned_from_body = templates.replace_as_expression(
                'tf.constant(1)')
            returned_from_orelse = templates.replace_as_expression(
                'tf.constant(1)')

        body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
        orelse_name = self.ctx.namer.new_symbol('if_false',
                                                orelse_scope.referenced)

        body_def = self._create_cond_branch(
            body_name,
            aliased_orig_names=aliased_body_orig_names,
            aliased_new_names=aliased_body_new_names,
            body=node_body,
            returns=returned_from_body)
        orelse_def = self._create_cond_branch(
            orelse_name,
            aliased_orig_names=aliased_orelse_orig_names,
            aliased_new_names=aliased_orelse_new_names,
            body=node_orelse,
            returns=returned_from_orelse)
        cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                           orelse_name)

        return body_def + orelse_def + cond_expr
Exemplo n.º 35
0
 def visit_List(self, node):
   """Wrap a list node (after visiting its children) in `ag__.new_list`."""
   visited = self.generic_visit(node)
   return templates.replace_as_expression(
       """
     ag__.new_list(elements)
   """,
       elements=visited)
Exemplo n.º 36
0
  def visit_If(self, node):
    """Rewrites an `if` statement as two branch functions plus a cond call.

    A symbol modified in either branch becomes an output of the conditional
    when it is live after the statement, or when it is composite and any of
    its owners is live. A symbol that a branch modifies and that was already
    defined before the statement is renamed (aliased) inside that branch.

    Args:
      node: The `if` statement node. The BODY_SCOPE, ORELSE_SCOPE,
        DEFINED_VARS_IN and LIVE_VARS_OUT annotations are read from it.

    Returns:
      The true-branch definition, the false-branch definition and the cond
      expression, concatenated in that order.

    Raises:
      ValueError: If the two branches do not create the same set of new
        output symbols.
    """
    node = self.generic_visit(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    def _is_cond_output(symbol):
      """Tells whether a modified symbol must be returned from the cond."""
      if symbol in live_out:
        return True
      # Special treatment for compound objects: if any of their owner
      # entities are live, then they are outputs as well.
      return symbol.is_composite() and any(
          owner in live_out for owner in symbol.owner_set)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = {s for s in modified_in_cond if _is_cond_output(s)}

    need_alias_in_body = body_scope.modified & defined_in
    need_alias_in_orelse = orelse_scope.modified & defined_in

    # Outputs that did not exist before the statement; each branch must
    # create exactly the same subset of them.
    newly_created = returned_from_cond - defined_in
    created_in_body = body_scope.modified & newly_created
    created_in_orelse = orelse_scope.modified & newly_created

    if created_in_body != created_in_orelse:
      raise ValueError(
          'if statement may not initialize all variables: the true branch'
          ' creates %s, while the false branch creates %s. Make sure all'
          ' these variables are initialized either in both'
          ' branches or before the if statement.' %
          (self._fmt_symbol_list(created_in_body),
           self._fmt_symbol_list(created_in_orelse)))

    # The closure variables are aliased inside the conditional functions so
    # that the functions can access them. Each branch gets its own aliases
    # because the branches may write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(orig.ssf(), body_scope.referenced)
        for orig in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(orig.ssf(), orelse_scope.referenced)
        for orig in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    returned_from_cond = tuple(returned_from_cond)
    if not returned_from_cond:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): This doesn't belong here; it's specific to the operator.
      returned_from_body = templates.replace_as_expression('1')
      returned_from_orelse = templates.replace_as_expression('1')
    else:
      if len(returned_from_cond) == 1:
        # TODO(mdan): Move this quirk into the operator implementation.
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # The alias maps are keyed by exactly the symbols that need aliasing,
      # so any symbol absent from a map is returned under its own name.
      returned_from_body = tuple(
          alias_body_map.get(s, s) for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map.get(s, s) for s in returned_from_cond)

    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    cond_expr = self._create_cond_expr(
        cond_results, node.test, body_name, orelse_name)

    return body_def + orelse_def + cond_expr
Exemplo n.º 37
0
 def visit_List(self, node):
   """Replaces a list literal with a call to `ag__.new_list`.

   The children of `node` are visited first; the visited node is then
   passed to the template as the `elements` argument.

   Args:
     node: The list literal node being visited.

   Returns:
     The expression produced by `templates.replace_as_expression`.
   """
   visited = self.generic_visit(node)
   new_list_template = """
     ag__.new_list(elements)
   """
   return templates.replace_as_expression(
       new_list_template, elements=visited)