Beispiel #1
0
  def test_get_definition_directive_multiple_consistent(self):
    """Two definitions that agree on `test_arg` resolve to that shared value."""
    directive_key = object

    def test_fn():
      a = 1
      if a:
        a = 2
      return a

    node, ctx = self.prepare(test_fn, {})
    # The value of `return a`, the third statement of test_fn.
    symbol_a = node.body[2].value
    defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
    # Both definitions carry the same 'test_arg' but a different 'other_arg';
    # only 'test_arg' is queried below.
    for definition, other_source in ((defs[0], 'bar'), (defs[1], 'baz')):
      definition.directives[directive_key] = {
          'test_arg': parser.parse_expression('foo'),
          'other_arg': parser.parse_expression(other_source),
      }

    converter_under_test = TestConverter(ctx)
    value = converter_under_test.get_definition_directive(
        symbol_a, directive_key, 'test_arg', None)
    self.assertEqual(value.id, 'foo')
Beispiel #2
0
  def to_ast(self, ctx, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      ctx: EntityContext, the entity with which this AST needs to be consistent.
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, `self.internal_convert_user_code` is
        used.

    Returns:
      ast.Node

    Raises:
      ValueError: if `strip_decorators` holds a weak reference whose referent
        no longer exists.
    """
    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def as_qualified_name(o):
      name = inspect_utils.getqualifiedname(ctx.info.namespace, o, max_depth=1)
      if not name:
        if isinstance(o, weakref.ref):
          # `o` might already be a weak reference, if this object was
          # constructed from code generated by `to_ast` itself.
          # If so, unpack it.
          referent = o()
          if referent is None:
            # Without this check a dead reference would surface below as an
            # opaque AttributeError on `None.__name__`.
            raise ValueError(
                'Entity referenced by {} no longer exists'.format(o))
          o = referent
        # TODO(mdan): This needs to account for the symbols defined locally.
        name = ctx.namer.new_symbol(o.__name__, ())
        # Registered through a weak reference (presumably so the generated
        # program does not keep the entity alive).
        ctx.program.add_symbol(name, weakref.ref(o))
      return name

    def as_tuple_expression(elements):
      # Each element is followed by a comma: `(a)` would parse as a bare
      # parenthesized expression, while `(a, )` is a one-element tuple.
      # No elements yields the empty tuple `()`.
      return parser.parse_expression(
          '({})'.format(''.join('{}, '.format(e) for e in elements)))

    def list_of_names(values):
      return as_tuple_expression(as_qualified_name(v) for v in values)

    def list_of_features(values):
      return as_tuple_expression('ag__.{}'.format(str(v)) for v in values)

    if internal_convert_user_code is None:
      internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        verbose_val=parser.parse_expression(str(int(self.verbose))),
        strip_decorators_val=list_of_names(self._strip_decorators),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value
Beispiel #3
0
  def to_ast(self, namespace, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      namespace: Dict[str, Any], the namespace to use when serializing values to
        names.
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, `self.internal_convert_user_code` is
        used.

    Returns:
      ast.Node

    Raises:
      ValueError: if `ConversionOptions` or an entry of `strip_decorators`
        cannot be located in `namespace`.
    """
    template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def as_qualified_name(o):
      name = inspect_utils.getqualifiedname(namespace, o)
      if not name:
        raise ValueError('Could not locate entity {} in {}'.format(
            o, namespace))
      return name

    def as_tuple_expression(elements):
      # Each element is followed by a comma: `(a)` would parse as a bare
      # parenthesized expression, while `(a, )` is a one-element tuple.
      # No elements yields the empty tuple `()`.
      return parser.parse_expression(
          '({})'.format(''.join('{}, '.format(e) for e in elements)))

    def list_of_names(values):
      return as_tuple_expression(as_qualified_name(v) for v in values)

    def list_of_features(values):
      # Features are emitted in the order of `Feature.__members__`.
      # NOTE(review): membership is tested against member *names*; confirm
      # `optional_features` holds values comparable to those names.
      return as_tuple_expression(
          'ag__.Feature.{}'.format(v)
          for v in Feature.__members__
          if v in values)

    # Fall back to this object's own setting only when the caller did not
    # supply an explicit override. (The check was previously inverted, which
    # discarded explicit values and emitted `None` when omitted.)
    if internal_convert_user_code is None:
      internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        constructor_name=parser.parse_expression(
            as_qualified_name(ConversionOptions)),
        recursive_val=parser.parse_expression(str(self.recursive)),
        verbose_val=parser.parse_expression(str(int(self.verbose))),
        strip_decorators_val=list_of_names(self._strip_decorators),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value
  def visit_arguments(self, node):
    """Replaces every argument default value with a literal `None`.

    Entries of `kw_defaults` that are already `None` (keyword-only arguments
    without a default) are kept as they are. Both lists are updated in place.
    """
    node.defaults[:] = [parser.parse_expression('None') for _ in node.defaults]
    node.kw_defaults[:] = [
        default if default is None else parser.parse_expression('None')
        for default in node.kw_defaults
    ]
    # Only the top level function is modified - no need to visit the children.
    return node
Beispiel #5
0
  def visit_Call(self, node):
    """Rewrites a call based on what is statically known about its target.

    If the callee has no `live_val` annotation:
      * `super(...)` calls are kept as-is.
      * `super(...).method(...)` calls are kept as-is when
        `ctx.info.owner_type` is set.
      * Everything else gets a dynamic conversion.

    If the callee has a `live_val` annotation:
      * Compilable targets are renamed, or only visited if they should not be
        compiled.
      * Known NumPy functions are wrapped into a single-return py_func.
      * Builtins and namedtuples are kept.

    Raises:
      NotImplementedError: if the live target fits none of these categories.
    """
    if not anno.hasanno(node.func, 'live_val'):
      # Special cases.
      # TODO(mdan): These need a systematic review - there may be more.

      # 1. super() calls are preserved. The class conversion mechanism will
      # ensure that they return the correct value.
      if ast_util.matches(node, parser.parse_expression('super(_)')):
        return node

      # 2. super().method calls are preserved as well, when the conversion
      # processes the entire class.
      is_super_method_call = ast_util.matches(
          node, parser.parse_expression('super(_)._(_)'))
      if is_super_method_call and self.ctx.info.owner_type is not None:
        return node

      return self._insert_dynamic_conversion(node)

    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None

    if self._function_is_compilable(target_entity):
      if not self._should_compile(node, target_fqn):
        return self.generic_visit(node)
      return self._rename_compilable_function(node)

    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

    if inspect_utils.isbuiltin(target_entity):
      # Any builtin that passed the builtins converter is assumed to be safe
      # for graph mode.
      return node

    if inspect_utils.isnamedtuple(target_entity):
      # Although not compilable, these are assumed safe for graph mode.
      return self.generic_visit(node)

    # TODO(mdan): Insert dynamic conversion here instead.
    raise NotImplementedError(
        'py_func with return values (unknown function)')
Beispiel #6
0
 def _insert_dynamic_conversion(self, node):
     """Replaces a call with an `ag__.converted_call` performed at runtime.

     Positional arguments go through the template; keyword arguments are
     copied onto the resulting call node.
     """
     # TODO(mdan): Pass information on the statically compiled functions.
     # Knowing which functions were compiled statically would avoid
     # unnecessary compilation. For example, `a` below is compiled twice:
     #
     #   def a():
     #     v = b
     #     b()
     #   def b():
     #     a()
     #
     # This is really a problem with recursive calls, which currently can
     # only be gated by a static condition, and should be rare.
     # TODO(mdan): It probably makes sense to use dynamic conversion every time.
     # That would first need a reasonable caching mechanism.
     template = """
   ag__.converted_call(
       func,
       ag__.ConversionOptions.new(recursive=recursive_val),
       args)
 """
     recursive_val = parser.parse_expression(str(self.ctx.program.recursive))
     new_call = templates.replace(
         template, func=node.func, recursive_val=recursive_val,
         args=node.args)[0].value
     # TODO(mdan): Improve the template mechanism to better support this.
     new_call.keywords = node.keywords
     return new_call
Beispiel #7
0
 def _insert_dynamic_conversion(self, node):
   """Replaces a call with an `ag__.converted_call` performed at runtime.

   Positional arguments go through the template; keyword arguments are
   copied onto the resulting call node.
   """
   # TODO(mdan): Pass information on the statically compiled functions.
   # Knowing which functions were compiled statically would avoid
   # unnecessary compilation. For example, `a` below is compiled twice:
   #
   #   def a():
   #     v = b
   #     b()
   #   def b():
   #     a()
   #
   # This is really a problem with recursive calls, which currently can
   # only be gated by a static condition, and should be rare.
   # TODO(mdan): It probably makes sense to use dynamic conversion every time.
   # That would first need a reasonable caching mechanism.
   template = """
     ag__.converted_call(
         func,
         ag__.ConversionOptions.new(recursive=recursive_val),
         args)
   """
   recursive_val = parser.parse_expression(str(self.ctx.program.recursive))
   new_call = templates.replace(
       template, func=node.func, recursive_val=recursive_val,
       args=node.args)[0].value
   # TODO(mdan): Improve the template mechanism to better support this.
   new_call.keywords = node.keywords
   return new_call
Beispiel #8
0
    def visit_FunctionDef(self, node):
        """Visits a function definition and rewrites its decorator list.

        When `self.state[_Function].level < 2` (top-level functions), the
        decorator list is cleared. Otherwise `ag__.do_not_convert_internal` is
        appended to it.
        """
        self.state[_Function].enter()
        # Note: if the conversion process ever creates helper functions, this
        # assumption will no longer hold.
        assert anno.hasanno(node, 'function_context_name'), (
            'The function_scopes converter always creates a scope for functions.'
        )
        context_name = anno.getanno(node, 'function_context_name')
        self.state[_Function].context_name = context_name

        node.args = self.visit(node.args)
        node.body = self.visit_block(node.body)

        is_top_level = self.state[_Function].level < 2
        if not is_top_level:
            # Inner functions are already converted; the extra decorator
            # prevents converting them twice. Double conversion would also
            # work, but this saves the overhead.
            node.decorator_list.append(
                parser.parse_expression('ag__.do_not_convert_internal'))
        else:
            # Conversion is just-in-time, so by the time it happens the
            # decorators of a top-level function are already set to be
            # applied; they are dropped here.
            node.decorator_list = []

        if node.returns:
            node.returns = self.visit(node.returns)

        self.state[_Function].exit()
        return node
Beispiel #9
0
  def visit_Call(self, node):
    """Routes a call through `ag__.converted_call`.

    Two kinds of call are left in place, with only their children visited:
    calls into the `ag__` module, and `print` calls when the
    BUILTIN_FUNCTIONS feature is not in use. For attribute calls, the
    attribute name and the owner expression are passed to `converted_call`
    separately.
    """
    # TODO(mdan): Refactor converted_call as a 'Call' operator.
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    # Internal 'ag__' calls are never converted, though their arguments
    # might be.
    if full_name.startswith('ag__.'):
      return self.generic_visit(node)
    options = self.ctx.program.options
    if full_name == 'print' and not options.uses(
        converter.Feature.BUILTIN_FUNCTIONS):
      return self.generic_visit(node)

    if isinstance(node.func, gast.Attribute):
      # obj.method(...) becomes converted_call('method', obj, ...).
      func, owner = gast.Str(node.func.attr), node.func.value
    else:
      func, owner = node.func, parser.parse_expression('None')

    template = """
      ag__.converted_call(func, owner, options, args)
    """
    new_call = templates.replace_as_expression(
        template,
        func=func,
        owner=owner,
        options=options.to_ast(
            self.ctx, internal_convert_user_code=options.recursive),
        args=node.args)
    # TODO(mdan): Improve the template mechanism to better support this.
    new_call.keywords = node.keywords
    return new_call
Beispiel #10
0
 def _kwargs_to_dict(self, node):
   """Ties together all keyword and **kwarg arguments in a single dict.

   Returns a `dict(...)` call node that carries the call's keywords, or a
   `None` expression when the call has no keywords.
   """
   if not node.keywords:
     return parser.parse_expression('None')
   dict_name = gast.Name('dict', gast.Load(), None)
   return gast.Call(dict_name, args=(), keywords=node.keywords)
    def visit_Return(self, node):
        """Replaces a `return` with assignments to the function's return vars.

        Every block on the `_Block` stack, from the innermost outwards up to
        and including the first function block, gets `return_used` and
        `create_guard_next` set to True.
        """
        for enclosing in reversed(self.state[_Block].stack):
            enclosing.return_used = True
            enclosing.create_guard_next = True
            if enclosing.is_function:
                break

        if node.value:
            retval = node.value
        else:
            retval = parser.parse_expression('None')

        # If evaluating `<expr>` in `return <expr>` raises, the return is
        # aborted; the try/except below resets the flag so the variables stay
        # consistent. The bare `except:` in the generated code is deliberate:
        # it re-raises whatever was thrown.
        template = """
      try:
        do_return_var_name = True
        retval_var_name = retval
      except:
        do_return_var_name = False
        raise
    """
        return templates.replace(
            template,
            do_return_var_name=self.state[_Function].do_return_var_name,
            retval_var_name=self.state[_Function].retval_var_name,
            retval=retval)
Beispiel #12
0
    def visit_Call(self, node):
        """Routes a call through `ag__.converted_call`.

        Two kinds of call are left in place, with only their children visited:
        calls into the `ag__` module, and `print` calls when the
        BUILTIN_FUNCTIONS feature is not in use. For attribute calls, the
        attribute name and the owner expression are passed to
        `converted_call` separately.
        """
        # TODO(mdan): Refactor converted_call as a 'Call' operator.
        full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
        # Internal 'ag__' calls are never converted, though their arguments
        # might be.
        if full_name.startswith('ag__.'):
            return self.generic_visit(node)
        options = self.ctx.program.options
        if full_name == 'print' and not options.uses(
                converter.Feature.BUILTIN_FUNCTIONS):
            return self.generic_visit(node)

        if isinstance(node.func, gast.Attribute):
            # obj.method(...) becomes converted_call('method', obj, ...).
            func, owner = gast.Str(node.func.attr), node.func.value
        else:
            func, owner = node.func, parser.parse_expression('None')

        template = """
      ag__.converted_call(func, owner, options, args)
    """
        new_call = templates.replace_as_expression(
            template,
            func=func,
            owner=owner,
            options=options.to_ast(
                self.ctx, internal_convert_user_code=options.recursive),
            args=node.args)
        # TODO(mdan): Improve the template mechanism to better support this.
        new_call.keywords = node.keywords
        return new_call
  def visit_For(self, node):
    """Lowers a `for` loop into a call to `ag__.for_stmt`.

    The optional `extra_test` annotation and the loop body become local
    functions. When `_get_loop_state` reports loop state, that state is passed
    to those functions and assigned back from the result of `ag__.for_stmt`.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    if not anno.hasanno(node, 'extra_test'):
      extra_test = parser.parse_expression('True')
    else:
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), ssf_map)

    # Generate both helper names up front, in the same order for both
    # branches below.
    extra_test_name = self.ctx.namer.new_symbol('extra_test', reserved_symbols)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    if not loop_state:
      template = """
        def extra_test_name():
          return extra_test_expr
        def body_name(loop_vars):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return ()
        ag__.for_stmt(iter_, extra_test_name, body_name, ())
      """
      return templates.replace(
          template,
          iter_=node.iter,
          iterate=node.target,
          extra_test_name=extra_test_name,
          extra_test_expr=extra_test,
          body_name=body_name,
          body=node_body)

    template = """
        def extra_test_name(state_ssf):
          return extra_test_expr
        def body_name(loop_vars, state_ssf):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
      """
    return templates.replace(
        template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=extra_test_name,
        extra_test_expr=extra_test,
        body_name=body_name,
        body=node_body)
Beispiel #14
0
    def visit_For(self, node):
        """Lowers a `for` loop into a call to `ag__.for_stmt`.

        The optional `extra_test` annotation and the loop body become local
        functions. When `_get_loop_state` reports loop state, that state is
        passed to those functions and assigned back from the result of
        `ag__.for_stmt`.
        """
        self.generic_visit(node)

        loop_state, reserved_symbols = self._get_loop_state(node)
        loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
            loop_state, reserved_symbols)
        node_body = ast_util.rename_symbols(node.body, ssf_map)
        if not anno.hasanno(node, 'extra_test'):
            extra_test = parser.parse_expression('True')
        else:
            extra_test = ast_util.rename_symbols(
                anno.getanno(node, 'extra_test'), ssf_map)

        # Generate both helper names up front, in the same order for both
        # branches below.
        extra_test_name = self.ctx.namer.new_symbol('extra_test',
                                                    reserved_symbols)
        body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

        if not loop_state:
            template = """
        def extra_test_name():
          return extra_test_expr
        def body_name(loop_vars):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return ()
        ag__.for_stmt(iter_, extra_test_name, body_name, ())
      """
            return templates.replace(template,
                                     iter_=node.iter,
                                     iterate=node.target,
                                     extra_test_name=extra_test_name,
                                     extra_test_expr=extra_test,
                                     body_name=body_name,
                                     body=node_body)

        template = """
        def extra_test_name(state_ssf):
          return extra_test_expr
        def body_name(loop_vars, state_ssf):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
      """
        return templates.replace(template,
                                 state=loop_state,
                                 state_ssf=state_ssf,
                                 state_ast_tuple=state_ast_tuple,
                                 iter_=node.iter,
                                 iterate=node.target,
                                 extra_test_name=extra_test_name,
                                 extra_test_expr=extra_test,
                                 body_name=body_name,
                                 body=node_body)
 def _as_function(self, func_name, args):
     """Builds the call `func_name(args)`, annotated as a safe boolean operand.

     Args:
       func_name: str, source text of the callee.
       args: the argument node(s) substituted into the call.

     Returns:
       The call expression, with SAFE_BOOLEAN_OPERAND set to True.
     """
     template = """
   func_name(args)
 """
     callee = parser.parse_expression(func_name)
     call = templates.replace_as_expression(template, func_name=callee, args=args)
     anno.setanno(call, SAFE_BOOLEAN_OPERAND, True)
     return call
 def _as_function(self, func_name, args):
   """Builds the call `func_name(args)`, annotated as a safe boolean operand.

   Args:
     func_name: str, source text of the callee.
     args: the argument node(s) substituted into the call.

   Returns:
     The call expression, with SAFE_BOOLEAN_OPERAND set to True.
   """
   template = """
     func_name(args)
   """
   callee = parser.parse_expression(func_name)
   call = templates.replace_as_expression(template, func_name=callee, args=args)
   anno.setanno(call, SAFE_BOOLEAN_OPERAND, True)
   return call
Beispiel #17
0
  def visit_For(self, node):
    """Lowers a `for` loop into a call to `ag__.for_stmt`.

    The loop state is the symbols that the body modifies but does not create.
    Each state symbol gets a fresh SSF name from the namer. The body and the
    optional `extra_test` are renamed accordingly and wrapped in local
    functions.
    """
    self.generic_visit(node)
    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    all_referenced = body_scope.referenced
    state = list(body_scope.modified - body_scope.created)

    state_ssf = []
    ssf_map = {}
    for symbol in state:
      ssf_name = self.ctx.namer.new_symbol(symbol.ssf(), all_referenced)
      state_ssf.append(ssf_name)
      # Only symbols whose SSF name differs need renaming in the body.
      if str(symbol) != ssf_name:
        ssf_map[symbol] = ssf_name

    if len(state) == 1:
      # A single state symbol is used directly instead of as a tuple.
      state, = state
      state_ssf, = state_ssf
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([symbol.ast() for symbol in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    if not anno.hasanno(node, 'extra_test'):
      extra_test = parser.parse_expression('True')
    else:
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), ssf_map)

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
        body=node_body)
Beispiel #18
0
    def visit_For(self, node):
        """Lowers a `for` loop into a call to `ag__.for_stmt`.

        The loop state is the symbols that the body modifies but does not
        create. Each state symbol gets a fresh SSF name from the namer. The
        body and the optional `extra_test` are renamed accordingly and wrapped
        in local functions.
        """
        self.generic_visit(node)
        self._validate_no_live_vars_created(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        all_referenced = body_scope.referenced
        state = list(body_scope.modified - body_scope.created)

        state_ssf = []
        ssf_map = {}
        for symbol in state:
            ssf_name = self.ctx.namer.new_symbol(symbol.ssf(), all_referenced)
            state_ssf.append(ssf_name)
            # Only symbols whose SSF name differs need renaming in the body.
            if str(symbol) != ssf_name:
                ssf_map[symbol] = ssf_name

        if len(state) == 1:
            # A single state symbol is used directly instead of as a tuple.
            state, = state
            state_ssf, = state_ssf
            state_ast_tuple = state
        else:
            state_ast_tuple = gast.Tuple([symbol.ast() for symbol in state],
                                         None)

        node_body = ast_util.rename_symbols(node.body, ssf_map)
        if not anno.hasanno(node, 'extra_test'):
            extra_test = parser.parse_expression('True')
        else:
            extra_test = ast_util.rename_symbols(
                anno.getanno(node, 'extra_test'), ssf_map)

        template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
        return templates.replace(
            template,
            state=state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=self.ctx.namer.new_symbol('extra_test',
                                                      all_referenced),
            extra_test_expr=extra_test,
            body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
            body=node_body)
 def test_keywords_to_dict(self):
     """keywords_to_dict yields a dict node that compiles and evaluates."""
     call = parser.parse_expression('f(a=b, c=1, d=\'e\')')
     dict_node = ast_util.keywords_to_dict(call.keywords)
     # Return the dict node from a small function and compile everything, so
     # the generated node is checked for usability end to end.
     fn_node = parser.parse_str('def f(b): pass').body[0]
     fn_node.body.append(ast.Return(dict_node))
     module, _ = compiler.ast_to_object(fn_node)
     self.assertDictEqual(module.f(3), {'a': 3, 'c': 1, 'd': 'e'})
 def test_lambda_in_function_call(self):
   """A lambda substituted into a call keeps Param/Load contexts."""
   template = """
     a = foo(arg)
   """
   lambda_list = parser.parse_expression('[lambda i: i]')
   assign = templates.replace(template, arg=lambda_list)[0]
   # The lambda is the sole element of the list passed as the first argument.
   lambda_node = assign.value.args[0].elts[0]
   self.assertIsInstance(lambda_node.args.args[0].ctx, gast.Param)
   self.assertIsInstance(lambda_node.body.ctx, gast.Load)
Beispiel #21
0
 def test_star_comprehension_in_function_call(self):
   """A starred comprehension argument keeps Store/Load contexts."""
   template = """
     a = foo(func, args)
   """
   call = parser.parse_expression('bar(*[i for i in range(j)])')
   assign = templates.replace(template, func=call.func, args=call.args)[0]
   # The second argument is the starred list; `.value` unwraps the star.
   comprehension = assign.value.args[1].value
   self.assertIsInstance(comprehension.generators[0].target.ctx, gast.Store)
   self.assertIsInstance(comprehension.elt.ctx, gast.Load)
Beispiel #22
0
 def test_keywords_to_dict(self):
   """keywords_to_dict yields a dict node that compiles and evaluates."""
   call = parser.parse_expression('f(a=b, c=1, d=\'e\')')
   dict_node = ast_util.keywords_to_dict(call.keywords)
   # Return the dict node from a small function and compile everything, so
   # the generated node is checked for usability end to end.
   fn_node = parser.parse_str('def f(b): pass').body[0]
   fn_node.body.append(ast.Return(dict_node))
   module, _ = compiler.ast_to_object(fn_node)
   self.assertDictEqual(module.f(3), {'a': 3, 'c': 1, 'd': 'e'})
Beispiel #23
0
 def test_lambda_in_function_call(self):
     """A lambda substituted into a call keeps Param/Load contexts."""
     template = """
   a = foo(arg)
 """
     lambda_list = parser.parse_expression('[lambda i: i]')
     assign = templates.replace(template, arg=lambda_list)[0]
     # The lambda is the sole element of the list passed as the first argument.
     lambda_node = assign.value.args[0].elts[0]
     self.assertIsInstance(lambda_node.args.args[0].ctx, gast.Param)
     self.assertIsInstance(lambda_node.body.ctx, gast.Load)
Beispiel #24
0
 def test_star_comprehension_in_function_call(self):
     """A starred comprehension argument keeps Store/Load contexts."""
     template = """
   a = foo(func, args)
 """
     call = parser.parse_expression('bar(*[i for i in range(j)])')
     assign = templates.replace(template, func=call.func, args=call.args)[0]
     # The second argument is the starred list; `.value` unwraps the star.
     comprehension = assign.value.args[1].value
     self.assertIsInstance(comprehension.generators[0].target.ctx, gast.Store)
     self.assertIsInstance(comprehension.elt.ctx, gast.Load)
Beispiel #25
0
    def test_index_access_multiple_definitions(self):
        """Conflicting element dtypes across definitions raise ValueError."""
        def test_fn(l):
            if l:
                l = []
            return l[1]

        node, ctx = self.prepare(test_fn, {})
        param_def, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
        param_def.directives[directives.set_element_type] = {
            'dtype': parser.parse_expression('tf.int32')
        }
        # The `l = []` assignment inside the `if` is the second definition.
        reassign_target = node.body[0].body[0].targets[0]
        reassign_def, = anno.getanno(reassign_target, anno.Static.DEFINITIONS)
        reassign_def.directives[directives.set_element_type] = {
            'dtype': parser.parse_expression('tf.float32')
        }
        with self.assertRaises(ValueError):
            slices.transform(node, ctx)
Beispiel #26
0
  def test_replace_tuple_context(self):
    """A tuple substituted as an assignment target gets Store contexts."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    replacement = parser.parse_expression('(a, b)')
    fn_def = templates.replace(template, foo=replacement)[0]
    target = fn_def.body[0].targets[0]
    self.assertIsInstance(target.ctx, gast.Store)
    for index in (0, 1):
      self.assertIsInstance(target.elts[index].ctx, gast.Store)
Beispiel #27
0
    def test_replace_name_with_dict(self):
        """A dict literal can replace a name that is then subscripted."""
        template = """
      def test_fn():
        return foo['bar']
    """

        dict_literal = parser.parse_expression('{\'bar\': 3}')
        fn_def = templates.replace(template, foo=dict_literal)[0]
        module, _, _ = loader.load_ast(fn_def)
        self.assertEqual(3, module.test_fn())
Beispiel #28
0
  def test_replace_tuple_context(self):
    """A tuple substituted as an assignment target gets Store contexts."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    replacement = parser.parse_expression('(a, b)')
    fn_def = templates.replace(template, foo=replacement)[0]
    target = fn_def.body[0].targets[0]
    self.assertIsInstance(target.ctx, gast.Store)
    for index in (0, 1):
      self.assertIsInstance(target.elts[index].ctx, gast.Store)
Beispiel #29
0
    def test_replace_expression_context(self):
        """Names in a substituted expression statement get Load contexts."""
        template = """
      def test_fn():
        foo
    """

        expr = parser.parse_expression('a + 2 * b / -c')
        fn_def = templates.replace(template, foo=expr)[0]
        statement = fn_def.body[0]
        # `a + 2 * b / -c` parses as `a + ((2 * b) / -c)`, so `.left` is `a`
        # and `.right.left.right` is `b`.
        self.assertIsInstance(statement.left.ctx, gast.Load)
        self.assertIsInstance(statement.right.left.right.ctx, gast.Load)
Beispiel #30
0
  def test_index_access_multiple_definitions(self):
    """Conflicting element dtypes across definitions are rejected."""

    def test_fn(l):
      if l:
        l = []
      return l[1]

    node, ctx = self.prepare(test_fn, {})
    param_def, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
    param_def.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.int32')
    }
    # The `l = []` assignment inside the `if` is the second definition.
    reassign_target = node.body[0].body[0].targets[0]
    reassign_def, = anno.getanno(reassign_target, anno.Static.DEFINITIONS)
    reassign_def.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.float32')
    }
    with self.assertRaises(transformer.AutographParseError):
      slices.transform(node, ctx)
Beispiel #31
0
  def test_replace_name_with_dict(self):
    """A dict literal can replace a name that is then subscripted."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEquals` is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(3, result.test_fn())
Beispiel #32
0
  def test_replace_expression_context(self):
    """Names in a substituted expression statement get Load contexts."""
    template = """
      def test_fn():
        foo
    """

    expr = parser.parse_expression('a + 2 * b / -c')
    fn_def = templates.replace(template, foo=expr)[0]
    statement = fn_def.body[0]
    # `a + 2 * b / -c` parses as `a + ((2 * b) / -c)`, so `.left` is `a`
    # and `.right.left.right` is `b`.
    self.assertIsInstance(statement.left.ctx, gast.Load)
    self.assertIsInstance(statement.right.left.right.ctx, gast.Load)
Beispiel #33
0
    def to_ast(self, namespace):
        """Returns a representation of this object as an AST node.

        The AST node encodes a constructor that would create an object with
        the same contents.

        Args:
          namespace: Dict[str, Any], the namespace to use when serializing
            values to names.

        Returns:
          ast.Node

        Raises:
          ValueError: if `ConversionOptions` or an entry of `strip_decorators`
            cannot be located in `namespace`.
        """
        template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorator_names,
          force_conversion=force_conversion_val)
    """

        def as_qualified_name(o):
            name = inspect_utils.getqualifiedname(namespace, o)
            if not name:
                raise ValueError('Could not locate entity {} in {}'.format(
                    o, namespace))
            return name

        # Each name is followed by a comma so that a single decorator still
        # yields a tuple: `(a)` is a bare expression, `(a, )` is a tuple.
        # No decorators yields `()`.
        strip_decorators_code = '({})'.format(''.join(
            '{}, '.format(as_qualified_name(o)) for o in self.strip_decorators))

        expr_ast = templates.replace(
            template,
            constructor_name=parser.parse_expression(
                as_qualified_name(ConversionOptions)),
            recursive_val=parser.parse_expression(str(self.recursive)),
            verbose_val=parser.parse_expression(str(self.verbose)),
            strip_decorator_names=parser.parse_expression(
                strip_decorators_code),
            force_conversion_val=parser.parse_expression(
                str(self.force_conversion)))
        return expr_ast[0].value
Beispiel #34
0
    def test_replace_index(self):
        """Subscript parts of a replaced target expression get Load contexts."""
        template = """
      def test_fn():
        foo = 0
    """

        replacement = parser.parse_expression('foo(a[b]).bar')
        fn_node = templates.replace(template, foo=replacement)[0]
        assign_target = fn_node.body[0].targets[0]
        call_arg = assign_target.value.args[0]
        self.assertIsInstance(call_arg.ctx, gast.Load)
        self.assertIsInstance(call_arg.slice.ctx, gast.Load)
Beispiel #35
0
  def test_replace_index(self):
    """Subscript parts of a replaced target expression get Load contexts."""
    template = """
      def test_fn():
        foo = 0
    """

    replacement = parser.parse_expression('foo(a[b]).bar')
    fn_node = templates.replace(template, foo=replacement)[0]
    assign_target = fn_node.body[0].targets[0]
    call_arg = assign_target.value.args[0]
    self.assertIsInstance(call_arg.ctx, gast.Load)
    self.assertIsInstance(call_arg.slice.value.ctx, gast.Load)
Beispiel #36
0
    def test_get_definition_directive_basic(self):
        """A directive attached to the sole definition of a symbol is found."""

        directive_key = object

        def test_fn():
            a = 1
            return a

        node, ctx = self.prepare(test_fn, {})
        returned_a = node.body[1].value
        only_def, = anno.getanno(returned_a, anno.Static.ORIG_DEFINITIONS)
        only_def.directives[directive_key] = {
            'test_arg': parser.parse_expression('foo'),
            'other_arg': parser.parse_expression('bar'),
        }
        converter_obj = TestConverter(ctx)
        found = converter_obj.get_definition_directive(
            returned_a, directive_key, 'test_arg', None)
        self.assertEqual(found.id, 'foo')
Beispiel #37
0
 def _wrap_to_py_func_single_return(self, node, dtype):
   """Wraps a call so that it is routed through `ag__.utils.wrap_py_func`.

   Args:
     node: the call node to wrap; its `func`, `args` and `keywords` fields
       are read (presumably a `gast.Call` -- confirm at call sites).
     dtype: str, source text parsed with `parser.parse_expression` and passed
       as the `dtype` argument of the generated call.

   Returns:
     The expression node produced by `templates.replace_as_expression`.
   """
   # TODO(mdan): Properly handle varargs, etc.
   dtype_node = parser.parse_expression(dtype)
   keywords_dict = ast_util.keywords_to_dict(node.keywords)
   template = """
     ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
   """
   return templates.replace_as_expression(
       template,
       func=node.func,
       dtype=dtype_node,
       args=node.args,
       kwargs=keywords_dict)
Beispiel #38
0
  def to_ast(self, namespace):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      namespace: Dict[str, Any], the namespace to use when serializing values to
        names.

    Returns:
      ast.Node

    Raises:
      ValueError: if `ConversionOptions` or any entry of
        `self.strip_decorators` cannot be located in `namespace`.
    """
    template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorator_names,
          force_conversion=force_conversion_val)
    """

    def as_qualified_name(o):
      name = inspect_utils.getqualifiedname(namespace, o)
      if not name:
        raise ValueError('Could not locate entity {} in {}'.format(
            o, namespace))
      return name

    # Build a tuple literal. The trailing comma matters: for a single decorator
    # `(name)` would parse as the bare expression `name`, not a one-element
    # tuple. An empty collection yields `()`.
    decorator_names = [as_qualified_name(o) for o in self.strip_decorators]
    if decorator_names:
      strip_decorators_code = '({},)'.format(', '.join(decorator_names))
    else:
      strip_decorators_code = '()'

    expr_ast = templates.replace(
        template,
        constructor_name=parser.parse_expression(
            as_qualified_name(ConversionOptions)),
        recursive_val=parser.parse_expression(str(self.recursive)),
        verbose_val=parser.parse_expression(str(self.verbose)),
        strip_decorator_names=parser.parse_expression(strip_decorators_code),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)))
    return expr_ast[0].value
Beispiel #39
0
  def test_replace_attribute_context(self):
    """An attribute-chain target gets Store outside and Load inside."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    replacement = parser.parse_expression('a.b.c')
    fn_node = templates.replace(template, foo=replacement)[0]
    target = fn_node.body[0].targets[0]
    self.assertIsInstance(target.ctx, gast.Store)
    self.assertIsInstance(target.value.ctx, gast.Load)
    self.assertIsInstance(target.value.value.ctx, gast.Load)
Beispiel #40
0
    def test_replace_attribute_context(self):
        """An attribute-chain target gets Store outside and Load inside."""
        template = """
      def test_fn(foo):
        foo = 0
    """

        replacement = parser.parse_expression('a.b.c')
        fn_node = templates.replace(template, foo=replacement)[0]
        target = fn_node.body[0].targets[0]
        self.assertIsInstance(target.ctx, gast.Store)
        self.assertIsInstance(target.value.ctx, gast.Load)
        self.assertIsInstance(target.value.value.ctx, gast.Load)
    def to_ast(self, internal_convert_user_code=None):
        """Returns a representation of this object as an AST node.

        The AST node encodes a constructor that would create an object with
        the same contents.

        Args:
          internal_convert_user_code: Optional[bool], allows overriding the
            corresponding value.

        Returns:
          ast.Node
        """
        # Options equal to STANDARD_OPTIONS are emitted as a reference to
        # `ag__.STD` instead of a full constructor call (presumably that name
        # is bound to the standard options in generated code -- confirm).
        if self == STANDARD_OPTIONS:
            return parser.parse_expression('ag__.STD')

        template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

        def list_of_features(values):
            # Emit a tuple literal. The trailing comma keeps a single feature
            # a one-element tuple; `(x)` alone would parse as just `x`.
            names = ['ag__.{}'.format(str(v)) for v in values]
            if not names:
                return parser.parse_expression('()')
            return parser.parse_expression('({},)'.format(', '.join(names)))

        if internal_convert_user_code is None:
            internal_convert_user_code = self.internal_convert_user_code

        expr_ast = templates.replace(
            template,
            recursive_val=parser.parse_expression(str(self.recursive)),
            force_conversion_val=parser.parse_expression(
                str(self.force_conversion)),
            internal_convert_user_code_val=parser.parse_expression(
                str(internal_convert_user_code)),
            optional_features_val=list_of_features(self.optional_features))
        return expr_ast[0].value
Beispiel #42
0
    def _as_function(self, func_name, args, args_as_lambda=False):
        """Builds an AST expression that calls `func_name` with `args`.

        Args:
          func_name: str, source text of the callee, parsed with
            `parser.parse_expression`.
          args: the argument AST nodes; at most two are supported.
          args_as_lambda: bool, if true each argument is first wrapped in a
            zero-argument lambda (`lambda: arg`).

        Returns:
          The call expression, annotated with `SAFE_BOOLEAN_OPERAND` = True.

        Raises:
          NotImplementedError: if more than two arguments are given.
        """
        if args_as_lambda:
            lambda_template = """
          lambda: arg
        """
            args = [
                templates.replace_as_expression(lambda_template, arg=a)
                for a in args
            ]

        # A falsy `args` (e.g. an empty list) means a zero-argument call.
        arg_count = len(args) if args else 0
        if arg_count == 0:
            template = """
        func_name()
      """
            placeholders = {}
        elif arg_count == 1:
            template = """
        func_name(arg)
      """
            placeholders = {'arg': args[0]}
        elif arg_count == 2:
            template = """
        func_name(arg1, arg2)
      """
            placeholders = {'arg1': args[0], 'arg2': args[1]}
        else:
            raise NotImplementedError('{} arguments for {}'.format(
                arg_count, func_name))

        replacement = templates.replace_as_expression(
            template,
            func_name=parser.parse_expression(func_name),
            **placeholders)
        anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
        return replacement
Beispiel #43
0
  def test_replace_complex_context(self):
    """Nested containers in a replaced target get Load; the target gets Store."""
    template = """
      def test_fn():
        foo = 0
    """

    replacement = parser.parse_expression('bar(([a, b],)).baz')
    fn_node = templates.replace(template, foo=replacement)[0]
    target = fn_node.body[0].targets[0]
    self.assertIsInstance(target.ctx, gast.Store)
    outer_tuple = target.value.args[0]
    inner_list = outer_tuple.elts[0]
    self.assertIsInstance(inner_list.ctx, gast.Load)
    self.assertIsInstance(inner_list.elts[0].ctx, gast.Load)
    self.assertIsInstance(inner_list.elts[1].ctx, gast.Load)
Beispiel #44
0
    def test_list_pop(self):
        """Popping from a typed tensor list returns the last element."""
        def test_fn():
            l = special_functions.tensor_list([1, 2, 3])
            s = l.pop()
            return s, l

        namespace = {'special_functions': special_functions}
        node, ctx = self.prepare(test_fn, namespace)
        list_target = node.body[0].targets[0]
        def_, = anno.getanno(list_target, anno.Static.ORIG_DEFINITIONS)
        def_.directives[directives.set_element_type] = {
            'dtype': parser.parse_expression('tf.int32'),
            'shape': parser.parse_expression('()'),
        }
        converted = lists.transform(node, ctx)

        with self.compiled(converted, namespace, dtypes.int32) as result:
            with self.cached_session():
                popped, remaining = result.test_fn()
                stacked = list_ops.tensor_list_stack(remaining, dtypes.int32)
                self.assertAllEqual(self.evaluate(stacked), [1, 2])
                self.assertAllEqual(self.evaluate(popped), 3)
Beispiel #45
0
    def test_replace_complex_context(self):
        """Nested containers in a replaced target get Load; target gets Store."""
        template = """
      def test_fn():
        foo = 0
    """

        replacement = parser.parse_expression('bar(([a, b],)).baz')
        fn_node = templates.replace(template, foo=replacement)[0]
        target = fn_node.body[0].targets[0]
        self.assertIsInstance(target.ctx, gast.Store)
        outer_tuple = target.value.args[0]
        inner_list = outer_tuple.elts[0]
        self.assertIsInstance(inner_list.ctx, gast.Load)
        self.assertIsInstance(inner_list.elts[0].ctx, gast.Load)
        self.assertIsInstance(inner_list.elts[1].ctx, gast.Load)
Beispiel #46
0
  def visit_FunctionDef(self, node):
    """Wraps a function's body in an `ag__.FunctionScope` context manager.

    Also rewrites the decorator list: it is cleared when the function scope
    level is <= 2, otherwise `ag__.autograph_artifact` is appended. A leading
    docstring-like statement is kept outside the wrapper.

    Args:
      node: the function definition node (presumably a `gast.FunctionDef`).

    Returns:
      The node returned by `self.generic_visit`, with its `decorator_list`
      and `body` rewritten.
    """
    with self.state[_Function] as fn_scope:
      scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

      # Pick a fresh name for the function context object that does not
      # collide with any symbol referenced in the body, and record it both on
      # the per-function state and as a node annotation.
      function_context_name = self.ctx.namer.new_symbol('fscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      # Convert the children first, so the wrapping below applies to the
      # already-converted body.
      node = self.generic_visit(node)

      if fn_scope.level <= 2:
        # Top-level functions lose their decorator because the conversion is
        # always just-in-time and by the time it happens the decorators are
        # already set to be applied.
        node.decorator_list = []
      else:
        # TODO(mdan): Fix the tests so that we can always add this decorator.
        # Inner functions are converted already, so we insert a decorator to
        # prevent double conversion. Double conversion would work too, but this
        # saves the overhead.
        node.decorator_list.append(
            parser.parse_expression('ag__.autograph_artifact'))

      # Detach a leading constant expression statement (the docstring) so it
      # stays the first statement of the function instead of being moved into
      # the `with` block below.
      # NOTE(review): this matches any `gast.Constant`, not only strings, so a
      # leading bare number or `...` is also kept outside the wrapper.
      docstring_node = None
      if node.body:
        first_statement = node.body[0]
        if (isinstance(first_statement, gast.Expr) and
            isinstance(first_statement.value, gast.Constant)):
          docstring_node = first_statement
          node.body = node.body[1:]

      template = """
        with ag__.FunctionScope(
            function_name, context_name, options) as function_context:
          body
      """
      wrapped_body = templates.replace(
          template,
          function_name=gast.Constant(node.name, kind=None),
          context_name=gast.Constant(function_context_name, kind=None),
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          body=node.body)

      # Put the detached docstring back in front of the wrapped body.
      if docstring_node is not None:
        wrapped_body = [docstring_node] + wrapped_body

      node.body = wrapped_body

      return node
  def _as_function(self, func_name, args, args_as_lambda=False):
    """Builds an AST expression that calls `func_name` with `args`.

    Args:
      func_name: str, source text of the callee, parsed with
        `parser.parse_expression`.
      args: the argument AST nodes; at most two are supported.
      args_as_lambda: bool, if true each argument is first wrapped in a
        zero-argument lambda (`lambda: arg`).

    Returns:
      The call expression, annotated with `SAFE_BOOLEAN_OPERAND` = True.

    Raises:
      NotImplementedError: if more than two arguments are given.
    """
    if args_as_lambda:
      lambda_template = """
          lambda: arg
        """
      args = [
          templates.replace_as_expression(lambda_template, arg=a)
          for a in args
      ]

    # A falsy `args` (e.g. an empty list) means a zero-argument call.
    arg_count = len(args) if args else 0
    if arg_count == 0:
      template = """
        func_name()
      """
      placeholders = {}
    elif arg_count == 1:
      template = """
        func_name(arg)
      """
      placeholders = {'arg': args[0]}
    elif arg_count == 2:
      template = """
        func_name(arg1, arg2)
      """
      placeholders = {'arg1': args[0], 'arg2': args[1]}
    else:
      raise NotImplementedError('{} arguments for {}'.format(
          arg_count, func_name))

    replacement = templates.replace_as_expression(
        template,
        func_name=parser.parse_expression(func_name),
        **placeholders)
    anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
    return replacement
Beispiel #48
0
  def test_list_pop(self):
    """Popping from a typed tensor list returns the last element."""

    def test_fn():
      l = special_functions.tensor_list([1, 2, 3])
      s = l.pop()
      return s, l

    namespace = {'special_functions': special_functions}
    node, ctx = self.prepare(test_fn, namespace)
    list_target = node.body[0].targets[0]
    def_, = anno.getanno(list_target, anno.Static.ORIG_DEFINITIONS)
    def_.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.int32'),
        'shape': parser.parse_expression('()'),
    }
    converted = lists.transform(node, ctx)

    with self.compiled(converted, namespace, dtypes.int32) as result:
      with self.cached_session():
        popped, remaining = result.test_fn()
        stacked = list_ops.tensor_list_stack(remaining, dtypes.int32)
        self.assertAllEqual(self.evaluate(stacked), [1, 2])
        self.assertAllEqual(self.evaluate(popped), 3)
Beispiel #49
0
  def test_get_definition_directive_multiple_inconsistent(self):
    """Conflicting directive values across definitions raise ValueError."""

    directive_key = object

    def test_fn():
      a = 1
      if a:
        a = 2
      return a

    node, ctx = self.prepare(test_fn, {})
    returned_a = node.body[2].value
    defs = anno.getanno(returned_a, anno.Static.ORIG_DEFINITIONS)
    first_def = defs[0]
    second_def = defs[1]
    first_def.directives[directive_key] = {
        'test_arg': parser.parse_expression('foo'),
    }
    second_def.directives[directive_key] = {
        'test_arg': parser.parse_expression('bar'),
    }
    converter_obj = TestConverter(ctx)
    with self.assertRaises(ValueError):
      converter_obj.get_definition_directive(
          returned_a, directive_key, 'test_arg', None)
Beispiel #50
0
  def to_ast(self, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value.

    Returns:
      ast.Node
    """
    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def list_of_features(values):
      # Emit a tuple literal. The trailing comma keeps a single feature a
      # one-element tuple; `(x)` alone would parse as just `x`.
      names = ['ag__.{}'.format(str(v)) for v in values]
      if not names:
        return parser.parse_expression('()')
      return parser.parse_expression('({},)'.format(', '.join(names)))

    if internal_convert_user_code is None:
      internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value
Beispiel #51
0
  def test_get_definition_directive_default(self):
    """With no directive set, the supplied default is returned."""

    directive_key = object

    def test_fn():
      a = 1
      return a

    node, ctx = self.prepare(test_fn, {})
    returned_a = node.body[1].value
    converter_obj = TestConverter(ctx)
    fallback = parser.parse_expression('default')
    found = converter_obj.get_definition_directive(
        returned_a, directive_key, 'test_arg', fallback)
    self.assertEqual(found.id, 'default')
Beispiel #52
0
    def test_get_definition_directive_default(self):
        """With no directive set, the supplied default is returned."""

        directive_key = object

        def test_fn():
            a = 1
            return a

        node, ctx = self.prepare(test_fn, {})
        returned_a = node.body[1].value
        converter_obj = TestConverter(ctx)
        fallback = parser.parse_expression('default')
        found = converter_obj.get_definition_directive(
            returned_a, directive_key, 'test_arg', fallback)
        self.assertEqual(found.id, 'default')
Beispiel #53
0
    def test_get_definition_directive_default(self):
        """With no directive set, the supplied default is returned."""

        directive_key = object

        def f():
            a = 1
            return a

        _, node, ctx = self.transform(f, (), include_ast=True)

        returned_a = node.body[1].value
        converter_obj = TestConverter(ctx)
        fallback = parser.parse_expression('default')
        found = converter_obj.get_definition_directive(
            returned_a, directive_key, 'test_arg', fallback)
        self.assertEqual(found.id, 'default')
Beispiel #54
0
    def test_replace_name_with_call(self):
        """A placeholder name can be replaced by a call expression."""
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        call_expr = parser.parse_expression('f()(b)')
        fn_node = templates.replace(template, foo=call_expr)[0]
        loaded_module, _, _ = loader.load_ast(fn_node)
        self.assertEqual(15, loaded_module.test_fn())
Beispiel #55
0
  def test_replace_name_with_call(self):
    """A placeholder name can be replaced by a call expression."""
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEqual`, not the deprecated `assertEquals` alias (removed in
    # Python 3.12).
    self.assertEqual(15, result.test_fn())
Beispiel #56
0
  def visit_Return(self, node):
    """Replaces a `return` statement with assignments to return-state vars.

    Marks the enclosing `_Return` state as used, then emits statements that
    set the function's do-return variable to True and store the returned
    value (`None` for a bare `return`) in its return-value variable.

    Args:
      node: the return statement node.

    Returns:
      The statements produced by `templates.replace`.
    """
    self.state[_Return].used = True

    if node.value:
      retval = node.value
    else:
      retval = parser.parse_expression('None')

    fn_state = self.state[_Function]
    template = """
      do_return_var_name = True
      retval_var_name = retval
    """
    return templates.replace(
        template,
        do_return_var_name=fn_state.do_return_var_name,
        retval_var_name=fn_state.retval_var_name,
        retval=retval)
Beispiel #57
0
    def test_list_stack(self):
        """Stacking a typed list literal yields its elements as a tensor."""
        def test_fn():
            l = [1, 2, 3]
            return tf.stack(l)

        node, ctx = self.prepare(test_fn, {})
        def_, = anno.getanno(node.body[0].targets[0],
                             anno.Static.ORIG_DEFINITIONS)
        def_.directives[directives.set_element_type] = {
            'dtype': parser.parse_expression('tf.int32')
        }
        node = lists.transform(node, ctx)

        with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
            # Evaluate via `self.evaluate` (as the other list tests do) rather
            # than `sess.run`, which only works in graph mode.
            with self.cached_session():
                self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
Beispiel #58
0
  def test_replace_call_keyword(self):
    """Keyword placeholders accept keyword lists and reject other values."""
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement needs its own context manager: inside a single
    # `assertRaises` block, the second call would never run once the first
    # one raised.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)
Beispiel #59
0
  def test_list_stack(self):
    """Stacking a typed list literal yields its elements as a tensor."""

    def test_fn():
      l = [1, 2, 3]
      return tf.stack(l)

    node, ctx = self.prepare(test_fn, {})
    list_target = node.body[0].targets[0]
    def_, = anno.getanno(list_target, anno.Static.ORIG_DEFINITIONS)
    def_.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.int32')
    }
    converted = lists.transform(node, ctx)

    with self.compiled(converted, {}, array_ops.stack, dtypes.int32) as result:
      with self.cached_session():
        self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])