Example #1
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds the statements that replace a `target.pop(...)` call.

    Args:
      original_call_node: Call node of the form `target.pop(...)`; its `func`
        must be an Attribute.
      pop_var_name: name that receives the popped value.

    Returns:
      The result of `templates.replace` for an `ag__.list_pop` assignment.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    # The first positional argument (if any) is forwarded as the element to
    # pop; a bare `pop()` forwards the expression `None`.
    pop_element = (original_call_node.args[0] if original_call_node.args
                   else parser.parse_expression('None'))

    # For "target.pop()" the type annotations are attached to `target`,
    # which is func.value.
    list_node = original_call_node.func.value
    dtype = anno.getanno(list_node, 'element_type',
                         default=templates.replace_as_expression('None'))
    shape = anno.getanno(list_node, 'element_shape',
                         default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)
Example #2
0
    def visit_Call(self, node):
        """Propagates element type/shape from type-annotation marker calls.

        A call whose callee carries a 'live_val' annotation identical to
        `self.context.type_annotation_func` is treated as
        `marker(target, dtype[, shape])`. The dtype and shape expressions
        (shape defaults to the expression `None`) are stored as the
        'element_type' and 'element_shape' annotations of the call node and
        of the definition of `target` returned by `self.scope.getval`.

        Args:
            node: the Call node being visited.

        Returns:
            The result of `self.generic_visit(node)`.

        Raises:
            ValueError: if the marker call does not have two or three
                positional arguments, or if its first argument has no QN
                annotation (i.e. is not a symbol).
        """
        if (anno.hasanno(node.func, 'live_val') and
                anno.getanno(node.func, 'live_val') is
                self.context.type_annotation_func):
            marker = self.context.type_annotation_func

            arg_count = len(node.args)
            if not 2 <= arg_count <= 3:
                raise ValueError(
                    '"%s" must have either two or three parameters' % marker)
            if arg_count == 2:
                target_arg, type_arg = node.args
                shape_arg = parser.parse_expression('None')
            else:
                target_arg, type_arg, shape_arg = node.args
            if not anno.hasanno(target_arg, anno.Basic.QN):
                raise ValueError(
                    'the first argument of "%s" must be a symbol' % marker)
            # TODO(mdan): This is vulnerable to symbol renaming.
            element_type = type_arg
            element_shape = shape_arg

            target_symbol = anno.getanno(target_arg, anno.Basic.QN)
            # Annotating the symbol's definition is intended to make later
            # uses of the symbol receive the same type annotation.
            definition = self.scope.getval(target_symbol)
            for annotated in (node, definition):
                anno.setanno(annotated, 'element_type', element_type)
                anno.setanno(annotated, 'element_shape', element_shape)
            # TODO(mdan): Should we update references between definition and here?
        return self.generic_visit(node)
Example #3
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Replaces a `target.pop(...)` call with an `ag__.list_pop` assignment.

    Args:
      original_call_node: Call node whose `func` is an Attribute
        (`target.pop`).
      pop_var_name: name bound to the popped value in the generated code.

    Returns:
      Whatever `templates.replace` produces for the template below.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    call_args = original_call_node.args
    if call_args:
      pop_element = call_args[0]
    else:
      # No argument: pass the expression `None` through to list_pop.
      pop_element = parser.parse_expression('None')

    # dtype/shape annotations hang off the list expression (func.value).
    list_node = original_call_node.func.value
    dtype = anno.getanno(
        list_node, 'element_type',
        default=templates.replace_as_expression('None'))
    shape = anno.getanno(
        list_node, 'element_shape',
        default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)
  def visit_Call(self, node):
    """Propagates element type/shape from `utils.set_element_type` calls.

    A call whose callee has a 'live_val' annotation identical to
    `utils.set_element_type` is treated as
    `set_element_type(target, dtype[, shape])`. The dtype and shape
    expressions (shape defaults to the expression `None`) are stored as the
    'element_type' and 'element_shape' annotations of the call node and of the
    definition of `target` returned by `self.scope.getval`.

    Args:
      node: the Call node being visited.

    Returns:
      The result of `self.generic_visit(node)`.

    Raises:
      ValueError: if the call does not have two or three positional arguments,
        or if its first argument has no QN annotation (is not a symbol).
    """
    if (anno.hasanno(node.func, 'live_val') and
        anno.getanno(node.func, 'live_val') is utils.set_element_type):

      arg_count = len(node.args)
      if not 2 <= arg_count <= 3:
        # Name the function actually being checked, consistently with the
        # error below.
        raise ValueError('"%s" must have either two or three parameters'
                         % utils.set_element_type)
      if arg_count == 2:
        target_arg, type_arg = node.args
        shape_arg = parser.parse_expression('None')
      else:
        target_arg, type_arg, shape_arg = node.args
      if not anno.hasanno(target_arg, anno.Basic.QN):
        raise ValueError('the first argument of "%s" must be a symbol' %
                         utils.set_element_type)
      # TODO(mdan): This is vulnerable to symbol renaming.
      element_type = type_arg
      element_shape = shape_arg

      target_symbol = anno.getanno(target_arg, anno.Basic.QN)
      # Annotating the symbol's definition is intended to make later uses of
      # the symbol receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      for annotated in (node, definition):
        anno.setanno(annotated, 'element_type', element_type)
        anno.setanno(annotated, 'element_shape', element_shape)
      # TODO(mdan): Should we update references between definition and here?
    return self.generic_visit(node)
 def _as_function(self, func_name, args):
     """Returns the expression `func_name(args)` as an AST node.

     `func_name` is source text that is parsed into the callee expression.
     The result is annotated with SAFE_BOOLEAN_OPERAND set to True.
     """
     call_template = """
   func_name(args)
 """
     callee = parser.parse_expression(func_name)
     call_expr = templates.replace_as_expression(
         call_template, func_name=callee, args=args)
     anno.setanno(call_expr, SAFE_BOOLEAN_OPERAND, True)
     return call_expr
 def _as_function(self, func_name, args):
   """Builds the call expression `func_name(args)`.

   Args:
     func_name: source text of the callee, parsed with parser.parse_expression.
     args: replacement for the `args` placeholder of the call.

   Returns:
     The call node, annotated with SAFE_BOOLEAN_OPERAND=True.
   """
   call_template = """
     func_name(args)
   """
   callee = parser.parse_expression(func_name)
   call_expr = templates.replace_as_expression(
       call_template, func_name=callee, args=args)
   anno.setanno(call_expr, SAFE_BOOLEAN_OPERAND, True)
   return call_expr
Example #7
0
  def visit_For(self, node):
    """Rewrites a `for` loop as an assignment from `ag__.for_stmt`.

    Symbols that the loop body modifies but does not create are treated as
    loop state. Inside the generated extra-test and body functions they are
    renamed to fresh symbols. The loop itself becomes an assignment of the
    result of `ag__.for_stmt(iter, extra_test, body, (state,))` back to the
    state symbols.

    Args:
      node: the For node being visited.

    Returns:
      The nodes produced by `templates.replace` for the template below.
    """
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    # Loop state: symbols modified by the body, excluding ones it creates.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    state = list(body_closure)

    # One fresh symbol per state variable. `all_referenced` is passed to the
    # namer, presumably so new names avoid anything the body references.
    state_ssf = [
        self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Rename map for body/extra test; identity renames are skipped.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # A single state variable is used bare rather than as a 1-tuple; zero or
    # several variables become a gast.Tuple.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    # An optional 'extra_test' annotation becomes the return value of the
    # generated extra-test function; without it that function returns True.
    if anno.hasanno(node, 'extra_test'):
      extra_test = anno.getanno(node, 'extra_test')
      extra_test = ast_util.rename_symbols(extra_test, ssf_map)
    else:
      extra_test = parser.parse_expression('True')

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    # NOTE(review): the two new_symbol calls below run after the state
    # symbols were generated; if the namer tracks issued names, reordering
    # them may change the generated names — confirm before refactoring.
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
Example #8
0
    def visit_For(self, node):
        """Rewrites a `for` loop as an assignment from `ag__.for_stmt`.

        Symbols the loop body modifies but does not create become loop state.
        They are renamed to fresh symbols inside the generated extra-test and
        body functions. The loop is replaced by an assignment of
        `ag__.for_stmt(iter, extra_test, body, (state,))` to the state.

        Args:
            node: the For node being visited.

        Returns:
            The nodes produced by `templates.replace` for the template below.
        """
        self.generic_visit(node)

        self._validate_no_live_vars_created(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        # Loop state: symbols modified by the body, excluding ones it creates.
        body_closure = body_scope.modified - body_scope.created
        all_referenced = body_scope.referenced

        state = list(body_closure)

        # One fresh symbol per state variable; `all_referenced` is passed to
        # the namer, presumably so new names avoid names used in the body.
        state_ssf = [
            self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
        ]
        # Rename map for body/extra test; identity renames are skipped.
        ssf_map = {
            name: ssf
            for name, ssf in zip(state, state_ssf) if str(name) != ssf
        }

        # A single state variable is used bare rather than as a 1-tuple; zero
        # or several variables become a gast.Tuple.
        if len(state) == 1:
            state = state[0]
            state_ssf = state_ssf[0]
            state_ast_tuple = state
        else:
            state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

        node_body = ast_util.rename_symbols(node.body, ssf_map)
        # An optional 'extra_test' annotation becomes the return value of the
        # generated extra-test function; without it that function returns True.
        if anno.hasanno(node, 'extra_test'):
            extra_test = anno.getanno(node, 'extra_test')
            extra_test = ast_util.rename_symbols(extra_test, ssf_map)
        else:
            extra_test = parser.parse_expression('True')

        template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
        # NOTE(review): the new_symbol calls below run after the state symbols
        # were generated; reordering them may change generated names if the
        # namer tracks issued names — confirm before refactoring.
        node = templates.replace(
            template,
            state=state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=self.ctx.namer.new_symbol('extra_test',
                                                      all_referenced),
            extra_test_expr=extra_test,
            body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
            body=node_body)

        return node
Example #9
0
    def test_index_access_multiple_definitions(self):
        """Conflicting element dtypes on two definitions of `l` are rejected."""
        def test_fn(l):
            if l:
                l = []
            return l[1]

        node, ctx = self.prepare(test_fn, {})
        param_node = node.args.args[0]
        reassign_target = node.body[0].body[0].targets[0]
        for target, dtype_src in ((param_node, 'tf.int32'),
                                  (reassign_target, 'tf.float32')):
            definition, = anno.getanno(target, anno.Static.DEFINITIONS)
            definition.directives[directives.set_element_type] = {
                'dtype': parser.parse_expression(dtype_src)
            }
        with self.assertRaises(transformer.AutographParseError):
            slices.transform(node, ctx)
 def test_keywords_to_dict(self):
   """keywords_to_dict yields a dict node usable in compiled code."""
   call_keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
   dict_node = ast_util.keywords_to_dict(call_keywords)
   # Assign the generated dict to module variable `d` after `b = 3`, then
   # compile the module to check that the node is well formed.
   module = parser.parse_str('b = 3')
   assign = ast.Assign([ast.Name(id='d', ctx=ast.Store())], dict_node)
   module.body.append(assign)
   compiled, _ = compiler.ast_to_object(module)
   self.assertDictEqual(compiled.d, {'a': 3, 'c': 1, 'd': 'e'})
 def test_keywords_to_dict(self):
     """Checks that the dict built from call keywords compiles and evaluates."""
     call_keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
     dict_node = ast_util.keywords_to_dict(call_keywords)
     # Bind the dict to `d` in a module that defines `b = 3`, then compile.
     module = parser.parse_str('b = 3')
     target = ast.Name(id='d', ctx=ast.Store())
     module.body.append(ast.Assign([target], dict_node))
     compiled, _ = compiler.ast_to_object(module)
     self.assertDictEqual(compiled.d, {'a': 3, 'c': 1, 'd': 'e'})
Example #12
0
 def test_keywords_to_dict(self):
     """The dict node from keywords_to_dict works as a function's return value."""
     call_keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
     dict_node = ast_util.keywords_to_dict(call_keywords)
     # Return the dict from `def f(b)`, so `b` in the dict refers to the
     # argument, and compile the function.
     func_def = parser.parse_str('def f(b): pass').body[0]
     func_def.body.append(ast.Return(dict_node))
     compiled, _ = compiler.ast_to_object(func_def)
     self.assertDictEqual(compiled.f(3), {'a': 3, 'c': 1, 'd': 'e'})
 def test_keywords_to_dict(self):
   """keywords_to_dict output can be returned from a compiled function."""
   call_keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
   dict_node = ast_util.keywords_to_dict(call_keywords)
   # Append `return <dict>` to `def f(b)` and compile; `b` is the argument.
   func_def = parser.parse_str('def f(b): pass').body[0]
   return_stmt = ast.Return(dict_node)
   func_def.body.append(return_stmt)
   compiled, _ = compiler.ast_to_object(func_def)
   self.assertDictEqual(compiled.f(3), {'a': 3, 'c': 1, 'd': 'e'})
  def test_replace_name_with_dict(self):
    """A name placeholder can be replaced by a dict literal expression."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEquals is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(3, result.test_fn())
Example #15
0
  def test_index_access_multiple_definitions(self):
    """Two definitions of `l` with different element dtypes raise an error."""

    def test_fn(l):
      if l:
        l = []
      return l[1]

    node, ctx = self.prepare(test_fn, {})
    func_def = node.body[0]
    param_node = func_def.args.args[0]
    reassign_target = func_def.body[0].body[0].targets[0]
    for target, dtype_src in ((param_node, 'tf.int32'),
                              (reassign_target, 'tf.float32')):
      definition, = anno.getanno(target, anno.Static.DEFINITIONS)
      definition.directives[directives.set_element_type] = {
          'dtype': parser.parse_expression(dtype_src)
      }
    with self.assertRaises(transformer.AutographParseError):
      slices.transform(node, ctx)
    def test_replace_name_with_dict(self):
        """A name placeholder can be replaced by a dict literal expression."""
        template = """
      def test_fn():
        return foo['bar']
    """

        source = parser.parse_expression('{\'bar\': 3}')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(3, result.test_fn())
Example #17
0
  def visit_For(self, node):
    """Rewrites a `for` loop as an assignment from `ag__.for_loop`.

    Symbols the loop body modifies but does not create become loop state.
    They are renamed to fresh symbols inside the generated extra-condition
    and body functions. The loop is replaced by an assignment of
    `ag__.for_loop(iterated, extra_cond, body, (state,))` to the state.

    Args:
      node: the For node being visited.

    Returns:
      The nodes produced by `templates.replace` for the template below.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Loop state: symbols modified by the body, excluding ones it creates.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    state = list(body_closure)

    # One fresh symbol per state variable; `all_referenced` is passed to the
    # namer, presumably so new names avoid names used in the body.
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Rename map for body/extra condition; identity renames are skipped.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # A single state variable is used bare rather than as a 1-tuple; zero or
    # several variables become a gast.Tuple.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    # An optional 'extra_cond' annotation becomes the return value of the
    # generated extra-condition function; without it that function returns
    # True.
    if anno.hasanno(node, 'extra_cond'):
      extra_cond = anno.getanno(node, 'extra_cond')
      extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
    else:
      extra_cond = parser.parse_expression('True')

    template = """
      def extra_cond_name(state_ssf):
        return extra_cond_expr
      def body_name(iterate, state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.for_loop(
          iterated, extra_cond_name, body_name, (state,))
    """
    # NOTE(review): the new_symbol calls below run after the state symbols
    # were generated; reordering them may change generated names if the
    # namer tracks issued names — confirm before refactoring.
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iterated=node.iter,
        iterate=node.target,
        extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                      all_referenced),
        extra_cond_expr=extra_cond,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
Example #18
0
  def visit_For(self, node):
    """Rewrites a `for` loop as an assignment from `__ops.for_loop`.

    Symbols the loop body modifies but does not create become loop state.
    They are renamed to fresh symbols inside the generated extra-condition
    and body functions. The loop is replaced by an assignment of
    `__ops.for_loop(iterated, extra_cond, body, (state,))` to the state.

    Args:
      node: the For node being visited.

    Returns:
      The nodes produced by `templates.replace` for the template below.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Loop state: symbols modified by the body, excluding ones it creates.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    state = list(body_closure)

    # One fresh symbol per state variable; `all_referenced` is passed to the
    # namer, presumably so new names avoid names used in the body.
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Rename map for body/extra condition; identity renames are skipped.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # A single state variable is used bare rather than as a 1-tuple; zero or
    # several variables become a gast.Tuple.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    # An optional 'extra_cond' annotation becomes the return value of the
    # generated extra-condition function; without it that function returns
    # True.
    if anno.hasanno(node, 'extra_cond'):
      extra_cond = anno.getanno(node, 'extra_cond')
      extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
    else:
      extra_cond = parser.parse_expression('True')

    template = """
      def extra_cond_name(state_ssf):
        return extra_cond_expr
      def body_name(iterate, state_ssf):
        body
        return state_ssf,
      state_ast_tuple = __ops.for_loop(
          iterated, extra_cond_name, body_name, (state,))
    """
    # NOTE(review): the new_symbol calls below run after the state symbols
    # were generated; reordering them may change generated names if the
    # namer tracks issued names — confirm before refactoring.
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iterated=node.iter,
        iterate=node.target,
        extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                      all_referenced),
        extra_cond_expr=extra_cond,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
 def _wrap_to_py_func_single_return(self, node, dtype):
     """Rewrites call `node` as an `autograph_utils.wrap_py_func` expression.

     Args:
       node: Call node whose function, positional args and keywords are
         forwarded to wrap_py_func (keywords collected into a dict).
       dtype: source text of the dtype argument, parsed into an expression.

     Returns:
       The expression produced by `templates.replace_as_expression`.
     """
     # TODO(mdan): Properly handle varargs, etc.
     wrap_template = """
   autograph_utils.wrap_py_func(func, dtype, (args,), kwargs, False)
 """
     dtype_expr = parser.parse_expression(dtype)
     kwargs_dict = ast_util.keywords_to_dict(node.keywords)
     return templates.replace_as_expression(
         wrap_template, func=node.func, dtype=dtype_expr, args=node.args,
         kwargs=kwargs_dict)
Example #20
0
 def _wrap_to_py_func_single_return(self, node, dtype):
   """Rewrites call `node` as an `ag__.utils.wrap_py_func` expression.

   Args:
     node: Call node whose function, positional args and keywords are
       forwarded to wrap_py_func (keywords collected into a dict).
     dtype: source text of the dtype argument, parsed into an expression.

   Returns:
     The expression produced by `templates.replace_as_expression`.
   """
   # TODO(mdan): Properly handle varargs, etc.
   wrap_template = """
     ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
   """
   dtype_expr = parser.parse_expression(dtype)
   kwargs_dict = ast_util.keywords_to_dict(node.keywords)
   return templates.replace_as_expression(
       wrap_template, func=node.func, dtype=dtype_expr, args=node.args,
       kwargs=kwargs_dict)
    def test_replace_tuple_context(self):
        """A tuple substituted as an assignment target gets Store contexts."""
        template = """
      def test_fn(foo):
        foo = 0
    """

        replacement = parser.parse_expression('(a, b)')
        node = templates.replace(template, foo=replacement)[0]
        tuple_target = node.body[0].targets[0]
        self.assertIsInstance(tuple_target.ctx, gast.Store)
        for i in (0, 1):
            self.assertIsInstance(tuple_target.elts[i].ctx, gast.Store)
Example #22
0
    def test_list_pop(self):
        """Popping from an int32-typed list returns the last element."""
        def test_fn():
            l = [1, 2, 3]
            s = l.pop()
            return s, l

        node, ctx = self.prepare(test_fn, {})
        list_def, = anno.getanno(node.body[0].body[0].targets[0],
                                 anno.Static.ORIG_DEFINITIONS)
        element_directive = {
            'dtype': parser.parse_expression('tf.int32'),
            'shape': parser.parse_expression('()'),
        }
        list_def.directives[directives.set_element_type] = element_directive
        node = lists.transform(node, ctx)

        with self.compiled(node, {}, dtypes.int32) as result:
            with self.test_session() as sess:
                popped, remaining = result.test_fn()
                stacked = list_ops.tensor_list_stack(remaining, dtypes.int32)
                self.assertAllEqual(sess.run(stacked), [1, 2])
                self.assertAllEqual(sess.run(popped), 3)
  def test_replace_attribute_context(self):
    """For target `a.b.c`, only the outer attribute is Store; inner ones Load."""
    template = """
      def test_fn(foo):
        foo = 0
    """

    replacement = parser.parse_expression('a.b.c')
    node = templates.replace(template, foo=replacement)[0]
    attr_target = node.body[0].targets[0]
    self.assertIsInstance(attr_target.ctx, gast.Store)
    self.assertIsInstance(attr_target.value.ctx, gast.Load)
    self.assertIsInstance(attr_target.value.value.ctx, gast.Load)
    def test_replace_attribute_context(self):
        """For target `a.b.c`, the outer attribute is Store, inner ones Load."""
        template = """
      def test_fn(foo):
        foo = 0
    """

        replacement = parser.parse_expression('a.b.c')
        node = templates.replace(template, foo=replacement)[0]
        attr_target = node.body[0].targets[0]
        inner = attr_target.value
        self.assertIsInstance(attr_target.ctx, gast.Store)
        self.assertIsInstance(inner.ctx, gast.Load)
        self.assertIsInstance(inner.value.ctx, gast.Load)
Example #25
0
 def _wrap_to_py_func_single_return(self, node, dtype):
     """Rewrites call `node` as an `ag__.utils.wrap_py_func` expression.

     Args:
       node: Call node whose function, positional args and keywords are
         forwarded to wrap_py_func (keywords collected into a dict).
       dtype: source text of the dtype argument, parsed into an expression.

     Returns:
       The expression produced by `templates.replace_as_expression`.
     """
     # TODO(mdan): Properly handle varargs, etc. id:492
     # https://github.com/imdone/tensorflow/issues/493
     wrap_template = """
   ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
 """
     dtype_expr = parser.parse_expression(dtype)
     kwargs_dict = ast_util.keywords_to_dict(node.keywords)
     return templates.replace_as_expression(
         wrap_template, func=node.func, dtype=dtype_expr, args=node.args,
         kwargs=kwargs_dict)
Example #26
0
  def test_list_pop(self):
    """Popping from an int32-typed list returns the last element."""

    def test_fn():
      l = [1, 2, 3]
      s = l.pop()
      return s, l

    node, ctx = self.prepare(test_fn, {})
    list_def, = anno.getanno(node.body[0].body[0].targets[0],
                             anno.Static.ORIG_DEFINITIONS)
    element_directive = {
        'dtype': parser.parse_expression('tf.int32'),
        'shape': parser.parse_expression('()'),
    }
    list_def.directives[directives.set_element_type] = element_directive
    node = lists.transform(node, ctx)

    with self.compiled(node, {}, dtypes.int32) as result:
      with self.test_session() as sess:
        popped, remaining = result.test_fn()
        stacked = list_ops.tensor_list_stack(remaining, dtypes.int32)
        self.assertAllEqual(sess.run(stacked), [1, 2])
        self.assertAllEqual(sess.run(popped), 3)
    def test_replace_complex_context(self):
        """Only the replaced target is Store; nested sub-expressions are Load."""
        template = """
      def test_fn(foo):
        foo = 0
    """

        replacement = parser.parse_expression('bar(([a, b],)).baz')
        node = templates.replace(template, foo=replacement)[0]
        assign_target = node.body[0].targets[0]
        self.assertIsInstance(assign_target.ctx, gast.Store)
        call_arg = assign_target.value.args[0]
        inner_list = call_arg.elts[0]
        self.assertIsInstance(inner_list.ctx, gast.Load)
        for i in (0, 1):
            self.assertIsInstance(inner_list.elts[i].ctx, gast.Load)
  def test_replace_name_with_call(self):
    """A name placeholder can be replaced by a (nested) call expression."""
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEquals is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(15, result.test_fn())
    def test_replace_name_with_call(self):
        """A name placeholder can be replaced by a (nested) call expression."""
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        source = parser.parse_expression('f()(b)')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(15, result.test_fn())
    def test_list_stack(self):
        """tf.stack over an int32-typed list yields the original values."""
        def test_fn():
            l = [1, 2, 3]
            return tf.stack(l)

        node, ctx = self.prepare(test_fn, {})
        list_def, = anno.getanno(node.body[0].body[0].targets[0],
                                 anno.Static.ORIG_DEFINITIONS)
        list_def.directives[directives.set_element_type] = dict(
            dtype=parser.parse_expression('tf.int32'))
        node = lists.transform(node, ctx)

        with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
            with self.test_session() as sess:
                stacked = result.test_fn()
                self.assertAllEqual(sess.run(stacked), [1, 2, 3])
    def test_replace_call_keyword(self):
        """Keyword placeholders accept keyword lists; other values raise."""
        template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

        source = parser.parse_expression('f(d=3, f=5)')
        node = templates.replace(template, kws=source.keywords)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(9, result.test_fn())

        # Each invalid replacement gets its own assertRaises block: inside a
        # single block, the second call never ran once the first one raised.
        with self.assertRaises(ValueError):
            templates.replace(template, kws=[])
        with self.assertRaises(ValueError):
            templates.replace(template, kws=1)
  def test_replace_call_keyword(self):
    """Keyword placeholders accept keyword lists; other values raise."""
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEquals is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement gets its own assertRaises block: inside a
    # single block, the second call never ran once the first one raised.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)
def apply_to_single_assignments(targets, values, apply_fn):
    """Applies a function to each individual assignment.

    This function can process a possibly-unpacked (e.g. a, b = c, d)
    assignment. It tries to break down the unpacking if possible. In effect,
    it has the same effect as passing the assigned values in SSA form to
    apply_fn.

    Examples:

    The following will result in apply_fn(a, c), apply_fn(b, d):

        a, b = c, d

    The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

        a, b = c

    The following will result in apply_fn(a, (b, c)):

        a = b, c

    Nested tuple/list targets are unpacked recursively.

    Args:
      targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST]; should be
          used with the targets field of an ast.Assign node. Every target
          receives the same `values`.
      values: ast.AST
      apply_fn: Callable[[ast.AST, ast.AST], None], called with the
          respective nodes of each single assignment.
    """
    if not isinstance(targets, (list, tuple)):
        targets = (targets,)
    for target in targets:
        if not isinstance(target, (gast.Tuple, gast.List)):
            apply_fn(target, values)
            continue
        values_unpacked = isinstance(values, (gast.Tuple, gast.List))
        for i, target_el in enumerate(target.elts):
            if values_unpacked:
                value_el = values.elts[i]
            else:
                # The value is not a literal tuple/list: index into it instead.
                idx = parser.parse_expression(str(i))
                value_el = gast.Subscript(values,
                                          gast.Index(idx),
                                          ctx=gast.Load())
            apply_to_single_assignments(target_el, value_el, apply_fn)
Example #34
0
  def test_list_stack(self):
    """tf.stack over an int32-typed list yields the original values."""

    def test_fn():
      l = [1, 2, 3]
      return tf.stack(l)

    node, ctx = self.prepare(test_fn, {})
    list_def, = anno.getanno(node.body[0].targets[0],
                             anno.Static.ORIG_DEFINITIONS)
    list_def.directives[directives.set_element_type] = dict(
        dtype=parser.parse_expression('tf.int32'))
    node = lists.transform(node, ctx)

    with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
      with self.test_session() as sess:
        stacked = result.test_fn()
        self.assertAllEqual(sess.run(stacked), [1, 2, 3])
Example #35
0
def apply_to_single_assignments(targets, values, apply_fn):
  """Applies a function to each individual assignment.

  This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
  It tries to break down the unpacking if possible. In effect, it has the same
  effect as passing the assigned values in SSA form to apply_fn.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  Nested tuple/list targets are unpacked recursively.

  Args:
    targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST]; should be
        used with the targets field of an ast.Assign node. Every target
        receives the same `values`.
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
        respective nodes of each single assignment.
  """
  if not isinstance(targets, (list, tuple)):
    targets = (targets,)
  for target in targets:
    if not isinstance(target, (gast.Tuple, gast.List)):
      apply_fn(target, values)
      continue
    values_unpacked = isinstance(values, (gast.Tuple, gast.List))
    for i, target_el in enumerate(target.elts):
      if values_unpacked:
        value_el = values.elts[i]
      else:
        # The value is not a literal tuple/list: index into it instead.
        idx = parser.parse_expression(str(i))
        value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
      apply_to_single_assignments(target_el, value_el, apply_fn)
Example #36
0
  def test_index_access(self):
    """Indexing an int32-typed list argument reads from the tensor list."""

    def test_fn(l):
      return l[1]

    node, ctx = self.prepare(test_fn, {})
    param_def, = anno.getanno(node.body[0].args.args[0],
                              anno.Static.DEFINITIONS)
    param_def.directives[directives.set_element_type] = dict(
        dtype=parser.parse_expression('tf.int32'))
    node = slices.transform(node, ctx)

    with self.compiled(node, {}, dtypes.int32) as result:
      with self.test_session() as sess:
        scalar_shape = constant_op.constant([], dtype=dtypes.int32)
        tensor_list = list_ops.tensor_list_from_tensor(
            [1, 2], element_shape=scalar_shape)
        element = result.test_fn(tensor_list)
        self.assertEqual(2, sess.run(element))
Example #37
0
    def test_index_access(self):
        """Indexing an int32-typed list argument reads from the tensor list."""
        def test_fn(l):
            return l[1]

        node, ctx = self.prepare(test_fn, {})
        param_def, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
        param_def.directives[directives.set_element_type] = dict(
            dtype=parser.parse_expression('tf.int32'))
        node = slices.transform(node, ctx)

        with self.compiled(node, {}, dtypes.int32) as result:
            with self.test_session() as sess:
                scalar_shape = constant_op.constant([], dtype=dtypes.int32)
                tensor_list = list_ops.tensor_list_from_tensor(
                    [1, 2], element_shape=scalar_shape)
                element = result.test_fn(tensor_list)
                self.assertEqual(2, sess.run(element))
Example #38
0
def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. A node
  matches a pattern if for every node in the tree, either there is a node of
  the same type in pattern, or a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: Union[ast.AST, str]; a string is first parsed into an
      expression with parser.parse_expression.
  Returns:
    bool
  """
  if isinstance(pattern, str):
    pattern = parser.parse_expression(pattern)
  matcher = PatternMatcher(pattern)
  matcher.visit(node)
  return matcher.matches
Example #39
0
def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. A node
  matches a pattern if for every node in the tree, either there is a node of
  the same type in pattern, or a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: Union[ast.AST, str]; a string is first parsed into an
      expression with parser.parse_expression.
  Returns:
    bool
  """
  if isinstance(pattern, str):
    pattern = parser.parse_expression(pattern)
  matcher = PatternMatcher(pattern)
  matcher.visit(node)
  return matcher.matches
Example #40
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Generates the `ag__.list_pop` statement replacing `target.pop(...)`.

    The element dtype and shape are looked up with
    `self.get_definition_directive` under `directives.set_element_type`,
    defaulting to the expression `None`.

    Args:
      original_call_node: Call node of the form `target.pop(...)`; its `func`
        must be an Attribute.
      pop_var_name: name that receives the popped value.

    Returns:
      The result of `templates.replace` for the list_pop template.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    call_args = original_call_node.args
    pop_element = call_args[0] if call_args else parser.parse_expression('None')

    # For "target.pop()" the directive is looked up on `target`, which is
    # func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    list_node = original_call_node.func.value
    dtype, shape = (
        self.get_definition_directive(
            list_node,
            directives.set_element_type,
            directive_arg,
            default=templates.replace_as_expression('None'))
        for directive_arg in ('dtype', 'shape'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)
    def _generate_pop_operation(self, original_call_node, pop_var_name):
        """Generates the `ag__.list_pop` statement replacing `target.pop(...)`.

        The element dtype and shape are looked up with
        `self.get_definition_directive` under `directives.set_element_type`,
        defaulting to the expression `None`.

        Args:
            original_call_node: Call node of the form `target.pop(...)`; its
                `func` must be an Attribute.
            pop_var_name: name that receives the popped value.

        Returns:
            The result of `templates.replace` for the list_pop template.
        """
        assert isinstance(original_call_node.func, gast.Attribute)

        call_args = original_call_node.args
        if call_args:
            pop_element = call_args[0]
        else:
            pop_element = parser.parse_expression('None')

        # For "target.pop()" the directive is looked up on `target`, which is
        # func.value.
        # TODO(mdan): For lists of lists, this won't work.
        # The reason why it won't work is because it's unclear how to annotate
        # the list as a "list of lists with a certain element type" when using
        # operations like `l.pop().pop()`.
        list_node = original_call_node.func.value
        dtype, shape = (
            self.get_definition_directive(
                list_node,
                directives.set_element_type,
                directive_arg,
                default=templates.replace_as_expression('None'))
            for directive_arg in ('dtype', 'shape'))

        template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
        return templates.replace(template,
                                 target=list_node,
                                 pop_var_name=pop_var_name,
                                 element=pop_element,
                                 dtype=dtype,
                                 shape=shape)
 def test_parse_expression(self):
     """parse_expression('a.b') yields an attribute `b` on the name `a`."""
     attr_node = parser.parse_expression('a.b')
     self.assertEqual('a', attr_node.value.id)
     self.assertEqual('b', attr_node.attr)
Example #43
0
def matches(node, pattern):
  """Returns whether `node` matches `pattern` according to PatternMatcher.

  Args:
    node: the AST node to test.
    pattern: an AST node, or source text parsed with parser.parse_expression.

  Returns:
    The `matches` attribute of the PatternMatcher after visiting `node`.
  """
  pattern_node = (parser.parse_expression(pattern)
                  if isinstance(pattern, str) else pattern)
  matcher = PatternMatcher(pattern_node)
  matcher.visit(node)
  return matcher.matches
def from_str(qn_str):
  """Parses `qn_str` and returns the QN annotation of the resolved expression."""
  resolved = resolve(parser.parse_expression(qn_str))
  return anno.getanno(resolved, anno.Basic.QN)
def from_str(qn_str):
    """Returns the QN annotation of the parsed and resolved `qn_str`."""
    parsed = parser.parse_expression(qn_str)
    resolved = resolve(parsed)
    return anno.getanno(resolved, anno.Basic.QN)
 def assertNoMatch(self, target_str, pattern_str):
   """Asserts that expression `target_str` does not match `pattern_str`."""
   target_node = parser.parse_expression(target_str)
   pattern_node = parser.parse_expression(pattern_str)
   matched = ast_util.matches(target_node, pattern_node)
   self.assertFalse(matched)
 def test_parse_expression(self):
   """parse_expression('a.b') yields attribute `b` of the name `a`."""
   attr_node = parser.parse_expression('a.b')
   self.assertEqual('a', attr_node.value.id)
   self.assertEqual('b', attr_node.attr)
 def assertNoMatch(self, target_str, pattern_str):
     """Fails if the parsed `target_str` matches the parsed `pattern_str`."""
     target_node = parser.parse_expression(target_str)
     pattern_node = parser.parse_expression(pattern_str)
     matched = ast_util.matches(target_node, pattern_node)
     self.assertFalse(matched)