예제 #1
0
 def visit_Assert(self, node):                                                # pylint: disable=invalid-name, no-self-use
     """Visit Assert
     Convert BinOp asserts into unittest's assert equals
     """
     args = [node.test]
     assertion = 'assertTrue'
     if isinstance(node.test, ast.Compare) and len(node.test.ops) == 1:
         args = [node.test.left, node.test.comparators[0]]
         assertion = {
             ast.Eq: 'assertEqual',
             ast.NotEq: 'assertNotEqual',
             ast.Lt: 'assertLess',
             ast.LtE: 'assertLessEqual',
             ast.Gt: 'assertGreater',
             ast.GtE: 'assertGreaterEqual',
             ast.Is: 'assertIs',
             ast.IsNot: 'assertIsNot',
             ast.In: 'assertIn',
             ast.NotIn: 'assertNotIn',
         }[node.test.ops[0].__class__]
     if node.msg:
         args.append(node.msg)
     return ast.copy_location(ast.Expr(ast.copy_location(call(
         ast.copy_location(ast.Attribute(
             ast.Name('self', ast.Load()),
             assertion, ast.Load()
         ), node),
         args
     ), node)), node)
예제 #2
0
    def visit_Name(self, old_node):
        """Replace an ``ast.Name`` with a project ``nodes.Name`` and record the
        use in the symbol table and control-flow graph.

        * ``Param`` context: the symbol ``self.symtab[name]`` is flagged as an
          argument and an assignment is recorded with
          ``self.flow.mark_assignment``.
        * ``Load`` context: if ``self.symtab.lookup(name)`` finds a symbol, a
          reference is recorded with ``self.flow.mark_reference``.
        * ``Param``/``Store`` context: the symbol's ``lineno``/``col_offset``
          are filled in from this node while its ``lineno`` is still ``-1``.

        Returns the new ``nodes.Name``, located at ``old_node``.

        NOTE(review): ``ast.Param`` is a Python 2 name context; Python 3's
        parser never produces it -- confirm which AST this visitor receives.
        """
        node = nodes.Name(old_node.id, old_node.ctx)
        ast.copy_location(node, old_node)

        # Set some defaults
        # (null-tracking flags; their meaning is defined by consumers elsewhere
        # in the project, not visible here)
        node.cf_maybe_null = True
        node.cf_is_null = False
        node.allow_null = False

        # Mirror the AST ``id`` under ``name`` for the symtab/flow code below.
        node.name = node.id
        if isinstance(node.ctx, ast.Param):
            # A function parameter: mark it as an argument and as assigned.
            var = self.symtab[node.name]
            var.is_arg = True
            self.flow.mark_assignment(node, None, var, assignment=None)
        elif isinstance(node.ctx, ast.Load):
            var = self.symtab.lookup(node.name)
            if var:
                # Local variable
                self.flow.mark_reference(node, var)

        # Set position of assignment of this definition
        if isinstance(node.ctx, (ast.Param, ast.Store)):
            var = self.symtab[node.name]
            # -1 appears to be the "position not yet recorded" sentinel, so the
            # first definition seen wins.
            if var.lineno == -1:
                var.lineno = getattr(node, "lineno", 0)
                var.col_offset = getattr(node, "col_offset", 0)

        return node
예제 #3
0
파일: unroller.py 프로젝트: ml-lab/TerpreT
 def visit_BinOp(self, node):
     """Fold a binary operation whose two operands are numeric literals.

     Children are visited first, so nested constant expressions fold
     bottom-up (``2 * 3 + 1`` becomes ``7``). The folded value replaces
     ``node`` as an ``ast.Num`` carrying ``node``'s location.

     If evaluating the operation raises -- e.g. ``1 / 0``, ``1 << -1``, an
     overflowing float power, or an unsupported combination such as
     ``1j // 1`` -- the node is returned unfolded. The error then surfaces
     when the program runs instead of crashing this transformer.
     """
     node = self.generic_visit(node)
     if not (isinstance(node.left, ast.Num) and isinstance(node.right, ast.Num)):
         return node
     expression = ast.copy_location(ast.Expression(body=node), node)
     # Compile outside the try so malformed ASTs still fail loudly.
     code = compile(expression, '', 'eval')
     try:
         # Only literal operands reach this point, so eval touches no user state.
         value = eval(code)
     except (ArithmeticError, ValueError, TypeError):
         return node
     return ast.copy_location(ast.Num(n=value), node)
예제 #4
0
파일: test_ast.py 프로젝트: 3lnc/cpython
    def test_load_const(self):
        # Constants must compile to the same LOAD_CONST values whether they
        # come from parsed source or from hand-built ast.Constant nodes.
        consts = [None,
                  True, False,
                  124,
                  2.0,
                  3j,
                  "unicode",
                  b'bytes',
                  (1, 2, 3)]

        # One "x=<repr>" line per constant, followed by "x = ...".
        source_lines = ['x={!r}'.format(value) for value in consts]
        source_lines.append('x = ...')
        code = '\n'.join(source_lines)
        # Expected extras: Ellipsis from the last line, then a trailing None.
        consts += [Ellipsis, None]

        tree = ast.parse(code)
        self.assertEqual(self.get_load_const(tree), consts)

        # Swap each assignment's value for an equivalent ast.Constant; the
        # compiled constants must not change.
        for stmt, value in zip(tree.body, consts):
            assert isinstance(stmt, ast.Assign), ast.dump(stmt)
            replacement = ast.Constant(value=value)
            ast.copy_location(replacement, stmt.value)
            stmt.value = replacement

        self.assertEqual(self.get_load_const(tree), consts)
예제 #5
0
파일: slast.py 프로젝트: duanemoody/renpy
    def prepare(self):
        """Prepare children, then pre-compile this node's keyword arguments.

        Each keyword expression is compiled to an AST. Keywords judged
        constant by ``is_constant`` are evaluated now and stored in
        ``self.keyword_values``. The rest are gathered into one dict display
        (string keys -> expressions) that is compiled into
        ``self.keyword_exprs``. Each attribute is ``None`` when it would
        otherwise be empty.
        """
        for child in self.children:
            child.prepare()

        constant_values = {}
        dynamic_keys = []
        dynamic_exprs = []

        for name, source in self.keyword:
            expr_node = py_compile(source, 'eval', ast_node=True)
            if not is_constant(expr_node):
                dynamic_keys.append(ast.Str(s=name))
                dynamic_exprs.append(expr_node)
                continue
            constant_values[name] = py_eval_bytecode(compile_expr(expr_node))

        self.keyword_values = constant_values or None

        if not dynamic_keys:
            self.keyword_exprs = None
            return

        dict_node = ast.Dict(keys=dynamic_keys, values=dynamic_exprs)
        ast.copy_location(dict_node, dynamic_exprs[0])
        self.keyword_exprs = compile_expr(dict_node)
예제 #6
0
파일: slast.py 프로젝트: duanemoody/renpy
    def prepare(self):
        """Prepare the block, then pre-compile the positional arguments.

        Each positional argument is compiled to an AST.

        * Arguments judged constant by ``is_constant`` are evaluated now.
          Their slot in ``self.positional_values`` holds the value, and their
          slot in the expression tuple holds a placeholder ``0``.
        * Other arguments get ``use_expression`` in
          ``self.positional_values`` and their real AST in the tuple, which is
          compiled into ``self.positional_exprs``.

        Each attribute is ``None`` when no argument of that kind exists.
        """
        SLBlock.prepare(self)

        exprs = []
        values = []
        has_values = False
        # First non-constant argument AST; the compiled tuple is located here.
        first_expr = None

        for a in self.positional:
            node = py_compile(a, 'eval', ast_node=True)

            if is_constant(node):
                values.append(py_eval_bytecode(compile_expr(node)))
                # Placeholder keeps tuple positions aligned with ``values``.
                exprs.append(ast.Num(n=0))
                has_values = True
            else:
                values.append(use_expression)
                exprs.append(node)
                if first_expr is None:
                    first_expr = node

        if has_values:
            self.positional_values = values
        else:
            self.positional_values = None

        if first_expr is not None:
            t = ast.Tuple(elts=exprs, ctx=ast.Load())
            # Take the location from a real argument. The ast.Num placeholders
            # carry no location, so copying from one (as happens when the first
            # argument is constant) would leave the tuple unpositioned.
            ast.copy_location(t, first_expr)
            self.positional_exprs = compile_expr(t)
        else:
            self.positional_exprs = None
예제 #7
0
파일: dump.py 프로젝트: yeukhon/testcapture
    def visit_FunctionDef(self, node):
        """Instrument ``test*`` functions by splicing generated statements into
        their body.

        For a function whose name starts with ``test``:

        * every top-level ``Assign``/``Expr`` statement is appended to
          ``self.statements``, and ``self.tracking[node.name]`` is set to
          ``None``;
        * each body statement is passed to ``createNode(func_name, stmt)``
          (all calls happen before any splicing), and the result is handled
          as follows:
          - a list: each generated node gets the statement's location and is
            emitted before the original statement;
          - a try statement: it gets the statement's location and is emitted
            before the original statement;
          - anything else: it replaces the original statement.

        Missing locations are filled in on ``node`` before it is returned.
        """
        if node.name.startswith('test'):
            self.statements += [stmt for stmt in node.body
                                if isinstance(stmt, (ast.Assign, ast.Expr))]
            self.tracking[node.name] = None

            # Python 2 spells a try/except statement ast.TryExcept; Python 3
            # merged it into ast.Try. Look the class up instead of hard-coding
            # the Python 2 name (which raised AttributeError on Python 3).
            try_types = tuple(getattr(ast, name) for name in ('TryExcept', 'Try')
                              if hasattr(ast, name))

            generated = [createNode(node.name, stmt) for stmt in node.body]
            new_node_body = []
            # zip instead of an xrange index loop: same pairing, and it also
            # runs on Python 3.
            for new_node, old_node in zip(generated, node.body):
                if isinstance(new_node, list):
                    for extra in new_node:
                        ast.copy_location(extra, old_node)
                        new_node_body.append(extra)
                    new_node_body.append(old_node)
                elif isinstance(new_node, try_types):
                    ast.copy_location(new_node, old_node)
                    new_node_body.append(new_node)
                    new_node_body.append(old_node)
                else:
                    new_node_body.append(new_node)
            node.body = new_node_body

        ast.fix_missing_locations(node)
        return node
예제 #8
0
파일: astsix.py 프로젝트: ASPP/numba
 def __visit_FunctionDef(self, node):
     """Rebuild a FunctionDef from its visited arguments, body and decorators.

     Only ``name``, ``args``, ``body`` and ``decorator_list`` are carried
     over; the result takes ``node``'s source location.
     """
     visited_args = self.visit_arguments(node.args)
     visited_body = self._visit_list(node.body)
     visited_decorators = self._visit_list(node.decorator_list)
     rebuilt = ast.FunctionDef(args=visited_args,
                               body=visited_body,
                               decorator_list=visited_decorators,
                               name=node.name)
     return ast.copy_location(rebuilt, node)
예제 #9
0
파일: pow_to_mult.py 프로젝트: Amper/opyum
 def visit_BinOp(self, node: ast.BinOp):
     """Expand ``x ** k`` (``k`` a small integer literal) into multiplications.

     After visiting children, a node accepted by ``self._is_numeric_pow`` is
     rewritten by the magnitude of ``k`` (a literal, optionally under unary
     ``-``/``+``, truncated with ``int``):

     * 0 -> ``1``
     * 1 -> ``x``
     * 2..``self.MAX_DEGREE`` -> ``x * x * ... * x`` (shallow copies of ``x``)
     * larger -> left unchanged

     A negative ``k`` wraps the result as ``1 / result``. Other nodes are
     returned as-is.
     """
     node = self.generic_visit(node)
     if not self._is_numeric_pow(node):
         return node

     base, exponent = node.left, node.right
     if isinstance(exponent, ast.Num):
         degree = exponent.n
     elif isinstance(exponent.op, ast.USub):
         degree = -exponent.operand.n
     else:
         degree = exponent.operand.n
     degree = int(degree)
     magnitude = abs(degree)

     if magnitude == 0:
         result = ast.copy_location(ast.Num(n=1), node)
     elif magnitude == 1:
         result = base
     elif magnitude <= self.MAX_DEGREE:
         product = base
         for _ in range(magnitude - 1):
             step = ast.BinOp(left=product, op=ast.Mult(), right=copy(base))
             product = ast.copy_location(step, node)
         result = product
     else:
         return node

     if degree < 0:
         reciprocal = ast.BinOp(left=ast.Num(n=1), op=ast.Div(), right=result)
         result = ast.copy_location(reciprocal, result)
     return result
예제 #10
0
 def visit_Name(self, node):
     """Randomly swap the quantifier names ``forall`` and ``exists``.

     When ``self.randomize()`` is truthy, a Name whose id is ``forall``
     becomes ``exists`` and vice versa. The replacement keeps the original
     expression context (Load/Store/Del) and location. Otherwise ``node`` is
     returned unchanged.
     """
     swapped = {'forall': 'exists', 'exists': 'forall'}
     if self.randomize():
         if node.id in swapped:
             # Carry over ``ctx``. A Name built without one is rejected by
             # compile() before Python 3.13 and would lose Store/Del context.
             replacement = _ast.Name(id=swapped[node.id], ctx=node.ctx)
             return ast.copy_location(replacement, node)
     return node
예제 #11
0
    def test_load_const(self):
        # Expression statements holding constants, compiled from parsed source
        # and from hand-built ast.Constant nodes, give the same LOAD_CONSTs.
        consts = [None,
                  True, False,
                  124,
                  2.0,
                  3j,
                  "unicode",
                  b'bytes',
                  (1, 2, 3)]

        # One expression statement per constant, plus a trailing "...".
        code = '\n'.join([repr(value) for value in consts] + ['...'])

        # Expected LOAD_CONSTs: non-bool str/int/float/complex values are
        # filtered out; everything else (bools included) is kept.
        expected = []
        for value in consts:
            plain_literal = (isinstance(value, (str, int, float, complex))
                             and not isinstance(value, bool))
            if not plain_literal:
                expected.append(value)
        expected.append(Ellipsis)
        # the compiler adds a final "LOAD_CONST None"
        expected.append(None)

        tree = ast.parse(code)
        self.assertEqual(self.get_load_const(tree), expected)

        # Replace expression nodes with constants
        for stmt, value in zip(tree.body, consts):
            assert isinstance(stmt, ast.Expr)
            replacement = ast.Constant(value=value)
            ast.copy_location(replacement, stmt.value)
            stmt.value = replacement

        self.assertEqual(self.get_load_const(tree), expected)
예제 #12
0
 def visit_If(self, node):
     """Simplify ``if`` statements with empty or ``pass``-only branches.

     * ``else: pass`` is dropped.
     * ``if c: pass`` with an else branch becomes ``if not c: <else>``.
       When that else branch is a lone ``elif`` with no further else, the
       tests are merged instead: ``if not c and c2: <elif body>``.
     * ``if c: pass`` with no else branch is removed (``None`` is returned).

     A body left empty by visiting the children (e.g. a nested pass-only
     ``if`` that was removed) is treated like ``pass``.

     NOTE(review): removing the statement also drops the evaluation of its
     test, so side effects in ``c`` are lost.
     """
     node = self.generic_visit(node)

     def _is_pass_only(stmts):
         # An empty list or a single ``pass``.
         return not stmts or (len(stmts) == 1 and isinstance(stmts[0], ast.Pass))

     if node.orelse and _is_pass_only(node.orelse):
         node.orelse = []

     if not _is_pass_only(node.body):
         return node
     if not node.orelse:
         return None

     negated = ast.copy_location(ast.UnaryOp(op=ast.Not(), operand=node.test), node.test)
     nested = node.orelse[0]
     if len(node.orelse) == 1 and isinstance(nested, ast.If) and not nested.orelse:
         # Merging is only safe without a further else: for
         # ``if c: pass / elif c2: B / else: C``, C must not run when c is
         # true, but ``if not c and c2: B else: C`` would run it.
         merged = ast.BoolOp(op=ast.And(), values=[negated, nested.test])
         node.test = ast.copy_location(merged, nested.test)
         node.body = nested.body
         node.orelse = []
     else:
         node.test = negated
         node.body = node.orelse
         node.orelse = []
     return node
예제 #13
0
파일: control_flow.py 프로젝트: ASPP/numba
    def visit_If(self, node):
        """Build control-flow blocks for an ``if`` statement.

        Creates blocks labelled ``exit_if``, ``if_cond`` (for ``node.test``)
        and ``if_body``. When there is an else clause, it also creates
        ``else_body`` from the condition block. After visiting, the body
        and else blocks are linked to the exit block if ``self.flow.block`` is
        still set. Without an else clause, the condition block itself is
        linked to the exit block.

        Returns ``self.exit_block(exit_block, new_node)``, where ``new_node``
        comes from ``nodes.build_if`` and is located at ``node``.
        """
        exit_block = self.flow.exit_block(label='exit_if', pos=node)

        # Condition
        cond_block = self.flow.nextblock(self.flow.block, label='if_cond',
                                         is_expr=True, pos=node.test)
        node.test = self.visit(node.test)

        # Body
        if_block = self.flow.nextblock(label='if_body', pos=node.body[0])
        self.visitlist(node.body)
        # self.flow.block may be falsy after the body (presumably when it ends
        # in a return/raise) -- only then is there no edge to the exit block.
        if self.flow.block:
            self.flow.block.add_child(exit_block)

        # Else clause
        if node.orelse:
            # Branch from the condition block, not from the end of the body.
            else_block = self.flow.nextblock(cond_block,
                                             label='else_body',
                                             pos=node.orelse[0])
            self.visitlist(node.orelse)
            if self.flow.block:
                self.flow.block.add_child(exit_block)
        else:
            # No else: a false condition falls straight through to the exit.
            cond_block.add_child(exit_block)
            else_block = None

        new_node = nodes.build_if(cond_block=cond_block, test=node.test,
                              if_block=if_block, body=node.body,
                              else_block=else_block, orelse=node.orelse,
                              exit_block=exit_block)
        ast.copy_location(new_node, node)
        return self.exit_block(exit_block, new_node)
예제 #14
0
 def visit_Name(self, node):
     """Convert a Name node, passing its ``id`` and ``ctx`` through
     ``self._visit``; the result keeps ``node``'s location."""
     converted = ast.Name(id=self._visit(node.id), ctx=self._visit(node.ctx))
     return ast.copy_location(converted, node)
예제 #15
0
 def test_ast_2(self):
     # An Expression wrapper built around a hand-assembled ast.Expression
     # evaluates the FOO+BAR sum from self.ns.
     self.ns['FOOBAR'] = ''
     tree = ast.Expression()
     sum_node = ast.parse('FOO+BAR').body[0].value
     tree.body = sum_node
     # ast.Expression has no location attributes, so this copies nothing; it
     # mirrors how the wrapper is normally built.
     ast.copy_location(tree, sum_node)
     wrapped = Expression(tree)
     self.ns['FOOBAR'] = wrapped
     self.assertEqual(self.ns['FOOBAR'].get(), 'foobar')
예제 #16
0
파일: unroller.py 프로젝트: ml-lab/TerpreT
        def visit_Assign(self, node):
            """Expand an array declaration ``x = T(args)[n1, n2, ...]`` into one
            assignment per element.

            For every index tuple of the declared shape this emits
            ``<flattened_array_name(x, indices)> = T(<deep copy of args>)``,
            with every new node located at ``node``. Any other assignment is
            returned unchanged.

            Multiple assignment targets are reported via ``error.error``
            (presumably raising -- TODO confirm).
            """
            if not (isinstance(node.value, ast.Subscript)
                    and isinstance(node.value.value, ast.Call)):
                return node

            subscr = node.value
            call = subscr.value

            if len(node.targets) > 1:
                error.error('Cannot use multiple assignment in array declaration.', node)

            variable_name = node.targets[0].id
            value_type = call.func.id
            declaration_args = call.args

            # Get the indices being accessed.
            shape = slice_node_to_tuple_of_numbers(subscr.slice)

            new_assigns = []
            for indices in itertools.product(*[range(n) for n in shape]):
                index_name = flattened_array_name(variable_name, indices)
                target = ast.copy_location(ast.Name(index_name, ast.Store()), node)
                func = ast.copy_location(ast.Name(value_type, ast.Load()), node)
                element_args = [copy.deepcopy(arg) for arg in declaration_args]
                # Keyword form: Python 3.5+ Call only has (func, args, keywords),
                # so passing five positional arguments raises TypeError there.
                # starargs/kwargs are optional fields in Python 2's Call.
                value = ast.copy_location(
                    ast.Call(func=func, args=element_args, keywords=[]), node)
                new_assigns.append(ast.copy_location(ast.Assign([target], value), node))
            return new_assigns
예제 #17
0
    def CALL(self, tree):
        """Check a call node for snakeoil demandload helpers and bind the names
        they create.

        ``tree.func`` is resolved through ``self.scope``:

        * ``mod.demandload(...)`` (``mod`` flagged ``is_demandload_module``),
          or a bare name whose binding has ``is_demandload_func``: every
          argument after the first must be a string literal. Each is parsed
          with ``parse_demandload``, and every ``(src, asname)`` pair is bound
          as a ``DemandloadImportation``.
        * ``mod.demand_compile_regexp(...)``, or a bare name whose binding has
          ``is_demandload_regex``: arguments 2 and 3 must be string literals.
          The name in argument 2 is bound to a synthetic
          ``<name> = re.compile(<regex>)`` assignment.

        Problems are reported through ``self.report``. Calls through a
        multi-level attribute (``a.b.f(...)``) skip these checks and only have
        their children handled.
        """
        is_demandload = is_demandload_regex = False
        func = tree.func
        if isinstance(func, _ast.Attribute):
            if not isinstance(func.value, _ast.Name):
                # this means it's a multilevel lookup;
                # pkgcore.ebuild.ebuild_src.some_func
                # ignore it; it *could* miss a direct
                # snakeoil.demandload.demandload, but
                # that access style is considered bad form.
                return self.handleChildren(tree)

            src = self.scope.get(func.value.id)

            if getattr(src, 'is_demandload_module', False):
                if func.attr == 'demandload':
                    is_demandload = True
                elif func.attr == 'demand_compile_regexp':
                    is_demandload_regex = True
        elif hasattr(func, 'id'):
            # Look the bound name up once and read both flags from it.
            binding = self.scope.get(func.id)
            is_demandload = getattr(binding, 'is_demandload_func', False)
            is_demandload_regex = getattr(binding, 'is_demandload_regex', False)

        if is_demandload_regex:
            # should do validation here (only argument types are checked).
            if len(tree.args) < 3:
                self.report(BadDemandloadRegexCall, tree.lineno)
            elif tree.args[1].__class__.__name__.upper() not in ("STR", "UNICODE"):
                self.report(BadDemandloadRegexCall, tree.lineno, "name argument isn't string nor unicode")
            elif tree.args[2].__class__.__name__.upper() not in ("STR", "UNICODE"):
                self.report(BadDemandloadRegexCall, tree.lineno, "regex argument isn't string nor unicode")
            else:
                # Bind the name to a synthetic "<name> = re.compile(<regex>)"
                # statement positioned at this call.
                code = "%s = re.compile(%r)\n" % (tree.args[1].s, tree.args[2].s)
                assignment = compile(code, self.filename, "exec", _ast.PyCF_ONLY_AST).body[0]
                # NOTE(review): the stdlib C module ``_ast`` has no copy_location
                # (it lives in ``ast``) -- confirm how ``_ast`` is imported here.
                fakenode = _ast.copy_location(assignment, tree)
                self.addBinding(tree.lineno, _checker.Assignment(tree.args[1].s, fakenode))

        if is_demandload:
            if len(tree.args) < 2:
                self.report(BadDemandloadCall, tree.lineno)
                return self.handleChildren(tree)

            for chunk in tree.args[1:]:
                chunk_cls = chunk.__class__.__name__
                if chunk_cls.upper() not in ('STR', 'UNICODE'):
                    self.report(BadDemandloadCall, chunk.lineno,
                        "invoked with non string/unicode arg: %r" % (chunk_cls,))
                    continue
                s = chunk.s
                try:
                    targets = list(parse_demandload([s]))
                except ValueError as ve:
                    # "except ValueError, ve" is Python-2-only syntax and a
                    # SyntaxError on Python 3; "as" works on 2.6+ and 3.
                    self.report(BadDemandloadCall, chunk.lineno, ve)
                    continue
                for src, asname in targets:
                    import_source = "import %s as %s\n" % (src, asname)
                    import_node = compile(import_source, self.filename, "exec",
                                          _ast.PyCF_ONLY_AST).body[0]
                    fakenode = _ast.copy_location(import_node, chunk)
                    self.addBinding(chunk.lineno,
                        DemandloadImportation(asname, fakenode))
        # NOTE(review): apart from the two early returns above, this method
        # never calls self.handleChildren(tree), so names used inside ordinary
        # calls are not visited here. Confirm whether the real method continues
        # past this (scraped) snippet.
예제 #18
0
 def search_def_methods(self, class_node):
     for node in class_node.body:
         if isinstance(node, ast.FunctionDef):
             if self.is_test_method(node.name):
                 newnode = ast.arguments(args=[], vararg=None, kwarg=None, defaults=[])
                 ast.copy_location(newnode, node.args)
                 node.args = newnode
                 self.methods_to_run.append(node)
예제 #19
0
 def visit_With(self, node):
     """Lower a (possibly multi-item) ``with`` to Python 2's ``ast.With``.

     Python 2's ``ast.With(context_expr, optional_vars, body)`` holds a
     single context manager. ``with a as x, b as y: body`` is therefore
     emitted as nested statements ``with a as x: with b as y: body``, one
     ``With`` per item, so no context manager is dropped. All items and
     then the body go through ``self._visit`` in source order, and every
     generated node takes ``node``'s location.
     """
     managers = [(self._visit(item.context_expr), self._visit(item.optional_vars))
                 for item in node.items]
     body = self._visit(node.body)
     # Wrap from the innermost manager outwards.
     for context_expr, optional_vars in reversed(managers):
         new_node = ast.With(context_expr, optional_vars, body)
         ast.copy_location(new_node, node)
         body = [new_node]
     return new_node
예제 #20
0
파일: suba.py 프로젝트: pcdinh/suba
	def visit_Expr(self, node):
		"""Rewrite ``include(...)`` statements and ``yield`` expression statements.

		A call to ``include`` must be caught here (not in a Call visitor) so the
		whole ``Expr(Call('include'))`` statement can be replaced.

		* ``include(filename[, root])`` / ``include(filename, root=...)``: the
		  template AST comes from ``include_ast(filename, root)``. Its check code
		  is appended to ``self.preamble``, and the body of a deep copy of its
		  function definition (after ``_yieldall``, with each statement visited)
		  is returned in place of this statement.
		* ``yield "<literal>"`` with ``self.stripWhitespace``: the literal is
		  passed through ``strip_whitespace``; if nothing is left the statement
		  is dropped (``None`` is returned).
		* ``yield f(...)`` where ``f`` is recorded in ``self.seenFuncs``: the
		  call is wrapped as ``''.join(f(...))``.

		Raises FormatError when include has no filename argument or when
		``include_ast`` returns ``None`` for the function definition.
		"""
		if type(node.value) is Call:
			call = node.value
			if type(call.func) is Name and call.func.id == 'include': 
				if len(call.args) < 1:
					raise FormatError("include requires at least a filename as an argument.")
				root = None
				# if the original call to include had an additional argument
				# use that argument as the root
				# print('call',ast.dump(call))
				if len(call.args) > 1:
					root = call.args[1].s
				# or if there was a root= kwarg provided, use that
				# (if several root= keywords appear, the last one wins)
				elif len(call.keywords) > 0:
					for k in call.keywords:
						if k.arg == "root":
							root = k.value.s
				if root is None:
					# if we didn't get one from the call to include
					# look for one that was given as an argument to the template() call
					root = self.root
				# a string root is split into a list of path components
				if type(root) is str:
					root = root.split(os.path.sep)
				# the first argument to include() is the filename
				template_name = call.args[0].s
				# get the ast tree that comes from this included file
				check, fundef = include_ast(template_name, root)
				# each include produces the code to execute, plus some code to check for freshness
				# this code absolutely must run first, because we can't restart the generator once it has already yielded
				# (note: the check is appended before fundef is validated below)
				self.preamble.append(check)
				if fundef is None:
					raise FormatError("include_ast returned None")
				# return a copy of the cached ast tree, because it will be further modified to fit with the including template
				fundef = copy.deepcopy(fundef)
				_yieldall(fundef.body)
				# NOTE(review): the return value of self.visit(expr) is discarded,
				# so a replacement it produces (e.g. the list returned for a nested
				# include, or None for a whitespace-only yield) is not spliced into
				# fundef.body; only in-place mutations take effect. Confirm nested
				# includes are already expanded by include_ast.
				for expr in fundef.body:
					self.visit(expr)
				return fundef.body
		elif type(node.value) is Yield:
			y = node.value
			if type(y.value) == Str:
				if self.stripWhitespace:
					s = strip_whitespace(y.value.s)
					if len(s) == 0:
						return None # don't even compile in the Expr(Yield) if it was only yielding white space
					else:
						y.value.s = s
			elif type(y.value) == Call:
				call = y.value
				if type(call.func) is Name:
					if self.seenFuncs.get(call.func.id, False) is not False: # was defined locally
						# replace the Call with one to ''.join(Call)
						y.value = _call(Attribute(value=Str(s=''), attr='join', ctx=Load()), [y.value])
						ast.copy_location(y.value, node)
		self.generic_visit(node)
		return node
예제 #21
0
 def visit_Str(self, node):
     """In ``html`` templates, wrap each string literal as
     ``_q_htmltext(<literal>)``.

     The new Call and its ``_q_htmltext`` Name take ``node``'s location. For
     any other template type, ``node`` is returned unchanged.
     """
     if self._get_template_type() != "html":
         return node
     func = ast.copy_location(ast.Name(id='_q_htmltext', ctx=ast.Load()), node)
     # starargs/kwargs are omitted: they are optional fields in Python 2's
     # Call and do not exist on Python 3.5+, where recent versions warn
     # about unknown keyword arguments.
     wrapped = ast.Call(func=func, args=[node], keywords=[])
     return ast.copy_location(wrapped, node)
예제 #22
0
파일: calc.py 프로젝트: GertBurger/ibid
 def visit_BinOp(self, node):
     """Replace ``a ** b`` with the call ``pow(a, b)``.

     Children are visited first. The new Call and its ``pow`` Name take
     ``node``'s location. Other operators are returned unchanged.
     """
     self.generic_visit(node)
     if not isinstance(node.op, Pow):
         return node
     func = copy_location(Name("pow", Load()), node)
     # Keyword form: Python 3.5+ Call only has (func, args, keywords), so the
     # Python 2 style Call(f, args, [], None, None) raises TypeError there.
     replacement = Call(func=func, args=[node.left, node.right], keywords=[])
     return copy_location(replacement, node)
예제 #23
0
 def visit_TryFinally(self, node):
     """Convert a Python 2 ``TryFinally`` into a gast ``Try`` that has no
     handlers and no ``else`` branch; the result keeps ``node``'s location."""
     visited_body = self._visit(node.body)
     visited_final = self._visit(node.finalbody)
     no_handlers, no_orelse = [], []
     converted = gast.Try(visited_body, no_handlers, no_orelse, visited_final)
     return ast.copy_location(converted, node)
예제 #24
0
 def visit_TryExcept(self, node):
     """Convert a Python 2 ``TryExcept`` into a gast ``Try`` with an empty
     ``finalbody``; the result keeps ``node``'s location."""
     visited_body = self._visit(node.body)
     visited_handlers = self._visit(node.handlers)
     visited_orelse = self._visit(node.orelse)
     no_finalbody = []
     converted = gast.Try(visited_body, visited_handlers, visited_orelse, no_finalbody)
     return ast.copy_location(converted, node)
예제 #25
0
 def visit_Return(self, node):
     """Replace a ``return`` statement with the assignment ``y = 8``.

     ``ast.increment_lineno`` first shifts ``node`` and all of its children
     down one line (mutating them). The new ``Assign`` then copies
     ``node``'s (shifted) location, so it sits one line below the original
     ``return``. Its children get locations from ``ast.fix_missing_locations``.
     """
     replacement = ast.Assign(
         targets=[ast.Name(id='y', ctx=ast.Store())],
         value=ast.Num(8),
     )
     ast.increment_lineno(node, 1)
     ast.copy_location(replacement, node)
     ast.fix_missing_locations(replacement)
     return replacement
예제 #26
0
 def visit_FunctionDef(self, node):
     """Rebuild a FunctionDef from its visited name, arguments, body and
     decorators.

     The four values are passed positionally in that order; other fields of
     ``node`` are not copied. The result takes ``node``'s location.
     """
     parts = [
         self._visit(node.name),
         self._visit(node.args),
         self._visit(node.body),
         self._visit(node.decorator_list),
     ]
     rebuilt = ast.FunctionDef(*parts)
     return ast.copy_location(rebuilt, node)
예제 #27
0
def _seconds_to_mu(ref_period, node):
    """Build the AST ``round64(node / ref_period)``.

    ``ref_period`` is converted with ``value_to_ast``. The division takes
    ``node``'s location, and the ``round64`` call takes the division's.
    Judging by the name, this converts a duration in seconds to machine
    units -- TODO confirm.
    """
    period_ast = value_to_ast(ref_period)
    quotient = ast.BinOp(left=node, op=ast.Div(), right=period_ast)
    ast.copy_location(quotient, node)
    rounding_call = ast.Call(func=ast.Name("round64", ast.Load()),
                             args=[quotient], keywords=[])
    return ast.copy_location(rounding_call, quotient)
    def visit_Assert(self, assert_node):
        """Expand a message-less ``assert`` into the statements produced by
        ``AssertionChildVisitor`` for its test expression.

        Asserts that carry a message are returned unchanged. Otherwise each
        produced statement takes the assert's location, its children get any
        missing locations filled in, and the list of statements is returned.
        """
        if assert_node.msg is None:
            statements = AssertionChildVisitor().visit(assert_node.test)
            for stmt in statements:
                ast.copy_location(stmt, assert_node)
                ast.fix_missing_locations(stmt)
            return statements
        return assert_node
예제 #29
0
    def visit_ClassDef(self, node):
        """Rebuild a ClassDef from its visited name, bases, body and decorators.

        The four values are passed positionally (name, bases, body,
        decorator_list), which is Python 2's ``ast.ClassDef`` field order.
        Other fields of ``node`` are not copied. The result takes ``node``'s
        location.

        NOTE(review): on Python 3 the third positional field of
        ``ast.ClassDef`` is ``keywords``, so this only lines up with Python 2's
        AST.
        """
        visited = (
            self._visit(node.name),
            self._visit(node.bases),
            self._visit(node.body),
            self._visit(node.decorator_list),
        )
        rebuilt = ast.ClassDef(*visited)
        return ast.copy_location(rebuilt, node)
예제 #30
0
 def visit_With(self, node):
     """Convert a Python 2 single-manager ``With`` into a gast ``With`` holding
     one ``withitem``; the result keeps ``node``'s location."""
     item = gast.withitem(
         self._visit(node.context_expr),
         self._visit(node.optional_vars),
     )
     visited_body = self._visit(node.body)
     converted = gast.With([item], visited_body)
     return ast.copy_location(converted, node)
예제 #31
0
 def visit(self, node):
     """Visit ``node`` and give any replacement node the original's location.

     ``super().visit`` may return the same node, a different node, ``None``
     (delete the node) or a list of nodes (replace with several). A
     different AST node gets ``node``'s location, as does every AST node in
     a returned list. ``None`` and other values pass through unchanged.
     Previously ``None`` or a list crashed inside ``ast.copy_location``.
     """
     new_node = super().visit(node)
     if new_node is node:
         return node
     if isinstance(new_node, ast.AST):
         return ast.copy_location(new_node, node)
     if isinstance(new_node, list):
         for item in new_node:
             if isinstance(item, ast.AST):
                 ast.copy_location(item, node)
     return new_node
예제 #32
0
파일: astutils.py 프로젝트: mfkiwl/dace
 def visit_Num(self, node: ast.Num):
     """Hoist a numeric literal into a generated global name.

     The literal's value is stored in ``self.gvars`` under ``__uu<self.id>``,
     ``self.id`` is incremented, and the literal is replaced by a ``Load`` of
     that name at the same location.
     """
     replacement_id = f'__uu{self.id}'
     self.gvars[replacement_id] = node.n
     self.id = self.id + 1
     load = ast.Name(id=replacement_id, ctx=ast.Load())
     return ast.copy_location(load, node)
예제 #33
0
 def _aug_assign(self, target, oper, value):
     # transform a += 1 to a = a + 1, then we can use assign and eval
     new_value = ast.BinOp(left=target, op=oper, right=value)
     ast.copy_location(new_value, target)
     self._assign(target, new_value)
예제 #34
0
def ast_call(node):
    """Visit and transform ast call node

    Rewrites one ``ast.Call`` so the converted script can run on NPU. The
    call is checked against a sequence of TensorFlow / Horovod / Keras API
    patterns. When a pattern matches, three things happen:

    - the node is edited in place, or a replacement node is built;
    - the change is logged;
    - the global ``need_conver`` flag is set.

    Most rules return immediately. The ``set_experimental_options``,
    ``TPUEstimator`` and ``clear_session`` rules fall through to the later
    checks instead. ``convert_loss_scale_api`` is always applied first. In
    ``"tf_strategy"`` mode, any call that reaches the end is handed to
    ``convert_distributed_strategy_apis``.

    Args:
        node: The ``ast.Call`` node to convert. It may be mutated in place.

    Returns:
        The (possibly mutated) node, or the node that replaces it.
    """
    distributed_mode = util_global.get_value("distributed_mode", "")
    # tf.distribute strategy APIs are converted only in "tf_strategy" mode and
    # Horovod APIs only in "horovod" mode. In the other (or an unset) mode a
    # distributed-mode error is logged instead.
    is_not_strategy = distributed_mode in ("horovod", "")
    is_not_horovod = distributed_mode in ("tf_strategy", "")
    convert_loss_scale_api(node)
    if _call_name_match(node.func, "set_experimental_options"):
        log_msg(
            getattr(node, 'lineno', 'None'),
            'change set_experimental_options(*) to set_experimental_options(experimental_options)'
        )
        node.args = [ast.Name(id='experimental_options', ctx=ast.Load())]
        node.keywords = []
        util_global.set_value('need_conver', True)
    if isinstance(node.func,
                  ast.Name) and node.func.id == 'check_available_gpus':
        log_msg(getattr(node, 'lineno', 'None'),
                "change check_available_gpus() to ['/device:CPU:0']")
        util_global.set_value('need_conver', True)
        return ast.List(elts=[ast.Str(s="/device:CPU:0")], ctx=ast.Load())
    if ((isinstance(node.func, ast.Name) and node.func.id == 'GraphOptions')
            or (isinstance(node.func, ast.Attribute)
                and node.func.attr == 'GraphOptions')):
        log_success_report(getattr(node, 'lineno', 'None'), 'GraphOptions()')
        # GraphOptions(...) -> npu_graph_options(graph_options=GraphOptions(...))
        src = copy.deepcopy(node)
        node.func = ast.Name(id='npu_graph_options', ctx=ast.Load())
        node.args = []
        node.keywords = []
        node.keywords.append(ast.keyword(arg='graph_options', value=src))
        util_global.set_value('need_conver', True)
        return node
    if (isinstance(node.func, ast.Name) and node.func.id == 'OptimizerOptions') or \
            (isinstance(node.func, ast.Attribute) and node.func.attr == 'OptimizerOptions'):
        log_success_report(getattr(node, 'lineno', 'None'),
                           'OptimizerOptions()')
        # OptimizerOptions(...) ->
        # npu_optimizer_options(optimizer_options=OptimizerOptions(...))
        src = copy.deepcopy(node)
        node.func = ast.Name(id='npu_optimizer_options', ctx=ast.Load())
        node.args = []
        node.keywords = []
        node.keywords.append(ast.keyword(arg='optimizer_options', value=src))
        util_global.set_value('need_conver', True)
        return node
    if _call_name_match(node.func, "Session"):
        return convert_origin_func_to_npu(node, tf_func_map["tf.Session"],
                                          "tf.Session", ["config"])
    if _call_name_match(node.func, "InteractiveSession"):
        return convert_origin_func_to_npu(node,
                                          tf_func_map["tf.InteractiveSession"],
                                          "tf.InteractiveSession", ["config"])
    if isinstance(node.func, ast.Attribute
                  ) and node.func.attr == "BroadcastGlobalVariablesHook":
        if isinstance(node.func.value,
                      ast.Name) and node.func.value.id == "hvd":
            if is_not_horovod:
                log_strategy_distributed_mode_error(node)
                return node
            log_msg(
                getattr(node, "lineno", "None"),
                'change hvd.BroadcastGlobalVariablesHook to NPUBroadcastGlobalVariablesHook'
            )
            node = pasta.parse(
                "NPUBroadcastGlobalVariablesHook(0, int(os.getenv('RANK_ID', '0')))"
            )
            util_global.set_value('need_conver', True)
            return node
    if isinstance(node.func, ast.Attribute
                  ) and node.func.attr == "BroadcastGlobalVariablesCallback":
        if isinstance(node.func.value,
                      ast.Attribute) and node.func.value.attr == "callbacks":
            if is_not_horovod:
                log_strategy_distributed_mode_error(node)
                return node
            log_msg(
                getattr(node, "lineno", "None"),
                'change hvd.callbacks.BroadcastGlobalVariablesCallback to NPUBroadcastGlobalVariablesCallback'
            )
            node = pasta.parse(
                "NPUBroadcastGlobalVariablesCallback(root_rank=0)")
            util_global.set_value('need_conver', True)
            return node
    if isinstance(node.func,
                  ast.Attribute) and node.func.attr == "DistributedOptimizer":
        if isinstance(node.func.value,
                      ast.Name) and node.func.value.id == "hvd":
            if is_not_horovod:
                log_strategy_distributed_mode_error(node)
                return node
            return convert_hvd_distributed_api(node)
    if isinstance(node.func, ast.Attribute) and node.func.attr == 'shard':
        log_success_report(getattr(node, "lineno", "None"), 'shard')
        # Shard by the RANK_SIZE / RANK_ID environment variables at runtime.
        node.args = [
            pasta.parse("int(os.getenv('RANK_SIZE', '1'))"),
            pasta.parse("int(os.getenv('RANK_ID', '0'))")
        ]
        node.keywords.clear()
        util_global.set_value('need_conver', True)
        return node
    if isinstance(node.func, ast.Attribute) and node.func.attr == 'dropout':
        if isinstance(node.func.value,
                      ast.Attribute) and node.func.value.attr == 'nn':
            # Calls with a third positional argument (presumably noise_shape)
            # or an explicit noise_shape keyword are left unchanged.
            if len(node.args) > 2:
                return node
            if any(keyword.arg == "noise_shape" for keyword in node.keywords):
                return node
            log_success_report(getattr(node, "lineno", "None"), 'dropout')
            node.func = ast.Attribute(value=ast.Name(id='npu_ops',
                                                     ctx=ast.Load()),
                                      attr='dropout',
                                      ctx=ast.Load())
            # npu_ops.dropout takes keep_prob, so rate=r becomes keep_prob=1 - r.
            keywords_new = []
            for keyword in node.keywords:
                if keyword.arg != 'rate':
                    keywords_new.append(keyword)
                else:
                    keywords_new.append(
                        ast.keyword(arg='keep_prob',
                                    value=ast.BinOp(left=ast.Num(n=1),
                                                    op=ast.Sub(),
                                                    right=keyword.value)))
            node.keywords = keywords_new
            util_global.set_value('need_conver', True)
        return node
    if isinstance(node.func, ast.Attribute) and \
            ((node.func.attr == 'map_and_batch') or
             (node.func.attr == 'batch' and (not isinstance(node.func.value, ast.Attribute) or (
                     isinstance(node.func.value, ast.Attribute) and node.func.value.attr != 'train')))):
        # Force drop_remainder=True unless it is already the literal True.
        exist = False
        for keyword in node.keywords:
            if keyword.arg == 'drop_remainder':
                exist = True
                if not (isinstance(keyword.value, ast.NameConstant)
                        and keyword.value.value):
                    log_success_report(getattr(node, "lineno", "None"),
                                       node.func.attr)
                    keyword.value = pasta.parse('True')
                    util_global.set_value('need_conver', True)
        if not exist:
            log_success_report(getattr(node, "lineno", "None"), node.func.attr)
            keyword = ast.keyword(arg='drop_remainder',
                                  value=pasta.parse('True'))
            node.keywords.insert(0, keyword)
            util_global.set_value('need_conver', True)
        return node
    if (isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == 'tf' and node.func.attr == 'device'):
        log_success_report(getattr(node, "lineno", "None"), node.func.attr)
        node.args = [ast.Str(s='/cpu:0')]
        util_global.set_value('need_conver', True)
        return node
    if isinstance(node.func, ast.Attribute) and \
            (node.func.attr == "get_distribution_strategy" or
             node.func.attr == "MirroredStrategy" or
             node.func.attr == "MultiWorkerMirroredStrategy"):
        if is_not_strategy:
            log_hvd_distributed_mode_error(node)
            return node
        log_success_report(getattr(node, "lineno", "None"), node.func.attr)
        new_func = ast.Attribute(value=ast.Name(id="npu_strategy",
                                                ctx=ast.Load()),
                                 attr="NPUStrategy",
                                 ctx=ast.Load())
        ast.copy_location(new_func, node.func)
        node.func = new_func
        node.keywords = []
        node.args = []
        util_global.set_value('need_conver', True)
        return node
    if (isinstance(node.func, ast.Attribute) and (node.func.attr == 'RunConfig')) and \
            (_call_name_match(node.func.value, 'estimator') or _call_name_match(node.func.value, 'tpu')):
        # node.keywords holds ast.keyword objects, so match on their ``arg``
        # names. Counting a plain string in that list was always 0.
        if any(keyword.arg in ("train_distribute", "eval_distribute")
               for keyword in node.keywords):
            if is_not_strategy:
                log_hvd_distributed_mode_error(node)
        save_summary_steps = None
        for keyword in node.keywords:
            if keyword.arg == 'save_summary_steps':
                save_summary_steps = keyword
                break
        if len(node.args) < 3 and not save_summary_steps:
            log_msg(getattr(node, 'lineno', 'None'),
                    'RunConfig() add save_summary_steps=0')
            util_global.set_value('need_conver', True)
            node.keywords.append(
                ast.keyword(arg='save_summary_steps', value=pasta.parse('0')))
        return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'TPUEstimator') and \
            ((isinstance(node.func.value, ast.Attribute) and (node.func.value.attr == 'tpu')) or
             (isinstance(node.func.value, ast.Name) and (node.func.value.id == 'tpu'))):
        # Force eval_on_tpu / use_tpu / export_to_tpu to False, adding any
        # that are missing. No return here: later rules (e.g. the Estimators
        # loop) may still apply.
        add_eval_on_tpu = True
        add_use_tpu = True
        add_export_to_tpu = True
        for keyword in node.keywords:
            if (keyword.arg == 'eval_on_tpu') or (
                    keyword.arg == 'use_tpu') or (keyword.arg
                                                  == 'export_to_tpu'):
                if not isinstance(keyword.value, ast.NameConstant) or \
                        keyword.value.value:
                    log_success_report(getattr(node, 'lineno', 'None'),
                                       'TPUEstimator(' + keyword.arg + '=*)')
                    keyword.value = pasta.parse('False')
                    util_global.set_value('need_conver', True)
                if add_eval_on_tpu and (keyword.arg == 'eval_on_tpu'):
                    add_eval_on_tpu = False
                if add_use_tpu and (keyword.arg == 'use_tpu'):
                    add_use_tpu = False
                if add_export_to_tpu and (keyword.arg == 'export_to_tpu'):
                    add_export_to_tpu = False
        if add_eval_on_tpu:
            log_success_report(getattr(node, 'lineno', 'None'),
                               'TPUEstimator(eval_on_tpu=*)')
            node.keywords.append(
                ast.keyword(arg='eval_on_tpu', value=pasta.parse('False')))
            util_global.set_value('need_conver', True)
        if add_use_tpu:
            log_success_report(getattr(node, 'lineno', 'None'),
                               'TPUEstimator(use_tpu=*)')
            node.keywords.append(
                ast.keyword(arg='use_tpu', value=pasta.parse('False')))
            util_global.set_value('need_conver', True)
        if add_export_to_tpu:
            log_success_report(getattr(node, 'lineno', 'None'),
                               'TPUEstimator(export_to_tpu=*)')
            node.keywords.append(
                ast.keyword(arg='export_to_tpu', value=pasta.parse('False')))
            util_global.set_value('need_conver', True)
    if isinstance(node.func,
                  ast.Attribute) and (node.func.attr
                                      == 'VirtualDeviceConfiguration'):
        log_success_report(getattr(node, 'lineno', 'None'),
                           'VirtualDeviceConfiguration')
        util_global.set_value('need_conver', True)
        # Set (or add) memory_limit=None.
        memory_limit = None
        for keyword in node.keywords:
            if keyword.arg == 'memory_limit':
                memory_limit = keyword
                break
        if memory_limit:
            memory_limit.value = ast.NameConstant(value=None)
        else:
            node.keywords.append(
                ast.keyword(arg='memory_limit',
                            value=ast.NameConstant(value=None)))
        return node
    if isinstance(node.func,
                  ast.Attribute) and (node.func.attr
                                      == 'set_soft_device_placement'):
        log_success_report(getattr(node, 'lineno', 'None'),
                           'set_soft_device_placement')
        util_global.set_value('need_conver', True)
        node.args = []
        node.keywords = [
            ast.keyword(arg='enabled', value=ast.NameConstant(value=True))
        ]
        return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr
                                                 == 'set_memory_growth'):
        log_success_report(getattr(node, 'lineno', 'None'),
                           'set_memory_growth')
        util_global.set_value('need_conver', True)
        # The call is replaced by the constant None.
        node = ast.NameConstant(value=None)
        return node
    if isinstance(node.func,
                  ast.Attribute) and (node.func.attr
                                      == 'set_virtual_device_configuration'):
        log_success_report(getattr(node, 'lineno', 'None'),
                           'set_virtual_device_configuration')
        util_global.set_value('need_conver', True)
        # The call is replaced by the constant None.
        node = ast.NameConstant(value=None)
        return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr
                                                 == 'jit_scope'):
        if isinstance(node.func.value, ast.Attribute) and (node.func.value.attr
                                                           == 'experimental'):
            if isinstance(node.func.value.value,
                          ast.Attribute) and (node.func.value.value.attr
                                              == 'xla'):
                log_success_report(getattr(node, 'lineno', 'None'),
                                   '*.xla.experimental.jit_scope')
                util_global.set_value('need_conver', True)
                # Set (or add) compile_ops=False.
                compile_ops = None
                for keyword in node.keywords:
                    if keyword.arg == 'compile_ops':
                        compile_ops = keyword
                        break
                if compile_ops:
                    compile_ops.value = pasta.parse('False')
                else:
                    node.keywords.append(
                        ast.keyword(arg='compile_ops',
                                    value=pasta.parse('False')))
                return node
    for estimator in util_global.get_value('Estimators', []):
        if (isinstance(node.func, ast.Attribute) and (node.func.attr == estimator)) \
                or (isinstance(node.func, ast.Name) and (node.func.id == estimator)):
            log_msg(
                getattr(node, 'lineno', 'None'),
                "".join([estimator, '() add config=npu_run_config_init()']))
            # Wrap an existing config as npu_run_config_init(run_config=<cfg>),
            # or add config=npu_run_config_init() when none is given.
            config = None
            for keyword in node.keywords:
                if keyword.arg == 'config':
                    config = keyword
                    break
            if config:
                new_value = ast.Call(func=ast.Name(id='npu_run_config_init',
                                                   ctx=ast.Load()),
                                     args=[],
                                     keywords=[
                                         ast.keyword(arg='run_config',
                                                     value=config.value)
                                     ])
                ast.copy_location(new_value, config.value)
                config.value = new_value
            else:
                node.keywords.append(
                    ast.keyword(arg='config',
                                value=pasta.parse('npu_run_config_init()')))
            util_global.set_value('need_conver', True)
            return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr
                                                 == 'clear_session'):
        log_msg(getattr(node, 'lineno', 'None'),
                "change keras.clear_session() to npu_clear_session()")
        node = ast.Call(func=ast.Name(id='npu_clear_session', ctx=ast.Load()),
                        args=[],
                        keywords=[])
        util_global.set_value('need_conver', True)
    if _call_name_match(node.func, "MonitoredTrainingSession"):
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.train.MonitoredTrainingSession"],
            "MonitoredTrainingSession", ["config", "hooks"])
    if isinstance(node.func,
                  ast.Attribute) and node.func.attr == "managed_session":
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.train.Supervisor.managed_session"],
            "managed_session", ["config"], True)
    if distributed_mode == "tf_strategy":  # this cond should be placed at the end of the Call function.
        return convert_distributed_strategy_apis(node)
    return node
예제 #35
0
    def visit_Assign(self, node):
        """Rewrite a single-target assignment into Taichi frontend calls.

        Behaviour visible in this method:

        * ``x = ti.static(...)``: returned unchanged, after its children are
          visited.
        * Tuple target ``a, b = rhs``:
          - ``rhs`` is bound to a temporary ``__tmp_tuple`` via
            ``ti.expr_init_list(rhs, <number of targets>)``.
          - Each element is either created with
            ``name = ti.expr_init(__tmp_tuple[i])`` (a ``Name`` target for
            which ``self.is_creation`` is true) or updated in place with
            ``target.assign(__tmp_tuple[i])``.
          - The temporary is then deleted, and all statements are merged by
            ``self.make_single_statement``.
        * Single target: ``name = ti.expr_init(rhs)`` for a newly created
          ``Name``, otherwise ``target.assign(rhs)``.

        Args:
            node: An ``ast.Assign`` with exactly one target (asserted).

        Returns:
            ``node`` itself (static assignment), a replacement statement
            located at ``node``, or whatever ``self.make_single_statement``
            returns (tuple unpacking).
        """
        assert (len(node.targets) == 1)
        # Visit the child nodes (targets and value) before rewriting this one.
        self.generic_visit(node)

        # True when the RHS is a call of the form ``ti.<attr>(...)``.
        decorated = isinstance(node.value, ast.Call) and isinstance(
            node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name) \
                    and node.value.func.value.id == 'ti'

        is_static_assign = False

        if decorated:
            attr = node.value.func
            if attr.attr == 'static':
                is_static_assign = True
            else:
                pass  # eg. x = ti.cast(xx) will reach here, but they're not decorators, so no raising errors here

        if is_static_assign:
            return node

        if isinstance(node.targets[0], ast.Tuple):
            targets = node.targets[0].elts

            # Statements that replace the tuple assignment.
            stmts = []

            # ``__tmp_tuple = ti.expr_init_list(<rhs>, <len(targets)>)``:
            # the parsed ``0`` placeholder argument is swapped for the real RHS.
            holder = self.parse_stmt('__tmp_tuple = ti.expr_init_list(0, '
                                     f'{len(targets)})')
            holder.value.args[0] = node.value

            stmts.append(holder)

            # Build the expression ``__tmp_tuple[i]`` by parsing a template and
            # patching the index.
            # NOTE(review): if ``self.parse_stmt`` is backed by ``ast.parse``,
            # ``slice.value`` assumes the pre-3.9 AST where the index is wrapped
            # in ``ast.Index``; on Python 3.9+ ``slice`` is the constant itself.
            # Verify against the supported Python versions.
            def tuple_indexed(i):
                indexing = self.parse_stmt('__tmp_tuple[0]')
                indexing.value.slice.value = self.parse_expr("{}".format(i))
                return indexing.value

            for i, target in enumerate(targets):
                is_local = isinstance(target, ast.Name)
                # ``is_creation`` presumably reports that the name is not yet
                # declared in the current scope; confirm in its definition.
                if is_local and self.is_creation(target.id):
                    var_name = target.id
                    target.ctx = ast.Store()
                    # Create: ``name = ti.expr_init(__tmp_tuple[i])``
                    init = ast.Attribute(value=ast.Name(id='ti',
                                                        ctx=ast.Load()),
                                         attr='expr_init',
                                         ctx=ast.Load())
                    rhs = ast.Call(
                        func=init,
                        args=[tuple_indexed(i)],
                        keywords=[],
                    )
                    # Registered immediately, so a repeated name later in the
                    # same tuple takes the Assign branch.
                    self.create_variable(var_name)
                    stmts.append(ast.Assign(targets=[target], value=rhs))
                else:
                    # Assign: ``target.assign(__tmp_tuple[i])``
                    target.ctx = ast.Load()
                    func = ast.Attribute(value=target,
                                         attr='assign',
                                         ctx=ast.Load())
                    call = ast.Call(func=func,
                                    args=[tuple_indexed(i)],
                                    keywords=[])
                    stmts.append(ast.Expr(value=call))

            # Locate the generated statements at the original assignment. The
            # trailing ``del`` is appended afterwards and is not located here.
            for stmt in stmts:
                ast.copy_location(stmt, node)
            stmts.append(self.parse_stmt('del __tmp_tuple'))
            return self.make_single_statement(stmts)
        else:
            is_local = isinstance(node.targets[0], ast.Name)
            if is_local and self.is_creation(node.targets[0].id):
                var_name = node.targets[0].id
                # Create: ``name = ti.expr_init(rhs)``
                init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                     attr='expr_init',
                                     ctx=ast.Load())
                rhs = ast.Call(
                    func=init,
                    args=[node.value],
                    keywords=[],
                )
                self.create_variable(var_name)
                return ast.copy_location(
                    ast.Assign(targets=node.targets, value=rhs), node)
            else:
                # Assign: ``target.assign(rhs)`` as an expression statement
                node.targets[0].ctx = ast.Load()
                func = ast.Attribute(value=node.targets[0],
                                     attr='assign',
                                     ctx=ast.Load())
                call = ast.Call(func=func, args=[node.value], keywords=[])
                return ast.copy_location(ast.Expr(value=call), node)
예제 #36
0
    def visit_For(self, node):
        """Translate a ``for`` loop inside a Taichi kernel.

        The loop kind is chosen from ``node.iter``:

        * ``ti.static(...)``: the body's children are visited and the loop
          is returned as is.
        * ``ti.ndrange(...)``: rewritten into a ``range`` loop over the
          flattened index space, with per-dimension index reconstruction.
          The result is visited again, so it is then handled as a range-for.
        * ``range(...)`` with 1 or 2 arguments: lowered to
          ``ti.core.begin_frontend_range_for`` / ``end_frontend_range_for``
          around the original body.
        * anything else, including ``ti.grouped(...)``: lowered to a
          struct-for via ``ti.core.begin_frontend_struct_for``.

        Loop variables are checked with ``self.check_loop_var`` and deleted
        with ``del`` after the range- and struct-for bodies.

        Raises:
            TaichiSyntaxError: if the loop has an ``else`` clause.
            Exception: if ``node.iter`` is a ``ti.<attr>(...)`` call other
                than ``static``, ``grouped`` or ``ndrange``.
        """
        if node.orelse:
            raise TaichiSyntaxError(
                "'else' clause for 'for' not supported in Taichi kernels")
        # True when the iterable is a call of the form ``ti.<attr>(...)``.
        decorated = isinstance(node.iter, ast.Call) and isinstance(
            node.iter.func, ast.Attribute) and isinstance(node.iter.func.value, ast.Name) \
                    and node.iter.func.value.id == 'ti'
        is_ndrange_for = False
        is_static_for = False
        is_grouped = False

        if decorated:
            attr = node.iter.func
            if attr.attr == 'static':
                is_static_for = True
            elif attr.attr == 'grouped':
                is_grouped = True
            elif attr.attr == 'ndrange':
                is_ndrange_for = True
            else:
                raise Exception('Not supported')
        # A plain ``range(...)`` call, referenced by the bare name ``range``.
        is_range_for = isinstance(node.iter, ast.Call) and isinstance(
            node.iter.func, ast.Name) and node.iter.func.id == 'range'
        ast.fix_missing_locations(node)
        # ndrange loops are visited again after rewriting (below), so their
        # body is not visited here.
        if not is_ndrange_for:
            self.generic_visit(node, ['body'])
        if is_ndrange_for:
            # Rewrite the loop into the following, wrapped in
            # ``if ti.static(1):``:
            #   __ndrange = <node.iter>
            #   for __ndrange_I in range(__ndrange.acc_dimensions[0]):
            #       __I = __ndrange_I
            #       <per-dimension index reconstruction>
            #       <original body>
            template = '''
if ti.static(1):
  __ndrange = 0
  for __ndrange_I in range(0):
    __I = __ndrange_I
      '''
            t = ast.parse(template).body[0]
            t.body[0].value = node.iter
            t.body[1].iter.args[0] = self.parse_expr(
                '__ndrange.acc_dimensions[0]')
            targets = node.target
            if isinstance(targets, ast.Tuple):
                targets = [name.id for name in targets.elts]
            else:
                targets = [targets.id]
            loop_body = t.body[1].body
            # For each target i:
            #   __<t> = __I // acc_dimensions[i + 1]   (just __I for the last)
            #   <t> = __<t> + bounds[i][0]
            #   __I -= __<t> * acc_dimensions[i + 1]   (all but the last)
            for i in range(len(targets)):
                if i + 1 < len(targets):
                    stmt = '__{} = __I // __ndrange.acc_dimensions[{}]'.format(
                        targets[i], i + 1)
                else:
                    stmt = '__{} = __I'.format(targets[i])
                loop_body.append(self.parse_stmt(stmt))
                stmt = '{} = __{} + __ndrange.bounds[{}][0]'.format(
                    targets[i], targets[i], i)
                loop_body.append(self.parse_stmt(stmt))
                if i + 1 < len(targets):
                    stmt = '__I = __I - __{} * __ndrange.acc_dimensions[{}]'.format(
                        targets[i], i + 1)
                    loop_body.append(self.parse_stmt(stmt))
            loop_body += node.body

            node = ast.copy_location(t, node)
            return self.visit(node)  # further translate as a range for
        elif is_static_for:
            return node
        elif is_range_for:
            loop_var = node.target.id
            self.check_loop_var(loop_var)
            # Template statements, by index into t.body:
            #   0     loop-variable id expression
            #   1, 2  begin/end placeholders (their 0 argument is patched below)
            #   3, 4  casts to ti.i32
            #   5     begin_frontend_range_for
            #   6     end_frontend_range_for
            # The user body is spliced in between statements 5 and 6.
            template = ''' 
if 1:
  {} = ti.Expr(ti.core.make_id_expr(''))
  ___begin = ti.Expr(0) 
  ___end = ti.Expr(0)
  ___begin = ti.cast(___begin, ti.i32)
  ___end = ti.cast(___end, ti.i32)
  ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr)
  ti.core.end_frontend_range_for()
      '''.format(loop_var, loop_var)
            t = ast.parse(template).body[0]

            assert len(node.iter.args) in [1, 2]
            if len(node.iter.args) == 2:
                bgn = node.iter.args[0]
                end = node.iter.args[1]
            else:
                # range(end) starts at 0.
                bgn = self.make_constant(value=0)
                end = node.iter.args[0]

            t.body[1].value.args[0] = bgn
            t.body[2].value.args[0] = end
            t.body = t.body[:6] + node.body + t.body[6:]
            t.body.append(self.parse_stmt('del {}'.format(loop_var)))
            return ast.copy_location(t, node)
        else:  # Struct for
            if isinstance(node.target, ast.Name):
                elts = [node.target]
            else:
                elts = node.target.elts

            for loop_var in elts:
                self.check_loop_var(loop_var.id)

            # One ``<name> = ti.Expr(ti.core.make_id_expr(""))`` line per target.
            var_decl = ''.join(
                '  {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(ind.id)
                for ind in elts)
            # NOTE: ``vars`` shadows the builtin within this method.
            vars = ', '.join(ind.id for ind in elts)
            if is_grouped:
                # Statement 0 (``___loop_var = 0``) is replaced by node.iter.
                # The body is spliced in after begin_frontend_struct_for
                # (index 3).
                template = ''' 
if 1:
  ___loop_var = 0
  {} = ti.make_var_vector(size=___loop_var.loop_range().dim())
  ___expr_group = ti.make_expr_group({})
  ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
  ti.core.end_frontend_range_for()
        '''.format(vars, vars)
                t = ast.parse(template).body[0]
                cut = 4
                t.body[0].value = node.iter
                t.body = t.body[:cut] + node.body + t.body[cut:]
            else:
                # After the len(elts) declarations come ``___loop_var``
                # (cut - 3, replaced by node.iter), ``___expr_group``
                # (cut - 2) and begin_frontend_struct_for (cut - 1). The body
                # is spliced in at ``cut``.
                template = ''' 
if 1:
{}
  ___loop_var = 0
  ___expr_group = ti.make_expr_group({})
  ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
  ti.core.end_frontend_range_for()
        '''.format(var_decl, vars)
                t = ast.parse(template).body[0]
                cut = len(elts) + 3
                t.body[cut - 3].value = node.iter
                t.body = t.body[:cut] + node.body + t.body[cut:]
            for loop_var in reversed(elts):
                t.body.append(self.parse_stmt('del {}'.format(loop_var.id)))
            return ast.copy_location(t, node)
예제 #37
0
 def visit(self, node):
     """Dispatch *node* to its ``visit_<ClassName>`` method.

     Falls back to ``self.generic_visit`` when no specific method exists.
     When the visitor returns an AST node, that node is stamped with the
     location of ``self.__benchmark``; this applies to reused nodes as well
     as new ones. Non-AST results (``None`` for a removed node, or a list of
     replacement nodes) are returned untouched, because
     ``ast.copy_location`` would raise on them.
     """
     method = 'visit_' + node.__class__.__name__
     visitor = getattr(self, method, self.generic_visit)
     result = visitor(node)
     if isinstance(result, ast.AST):
         ast.copy_location(result, self.__benchmark)
     return result
예제 #38
0
    def _unellipsify(self, node, slices, subscript_node):
        """
        Given an array node `node`, process all AST slices and create the
        final type:

            - process newaxes (None or numpy.newaxis)
            - replace Ellipsis with a bunch of ast.Slice objects
            - process integer indices
            - append any missing slices in trailing dimensions

        Returns a ``(result_type, node)`` pair:

            - non-array (object) operand: ``(minitypes.object_, node)``
            - ``A[...]``: the array's own type, with ``node`` unchanged
            - an index whose type is not ellipsis/slice/int/newaxis:
              ``(minitypes.object_, CoercionNode(node, object_))``
            - otherwise: the sliced array type. This is the dtype for a 0-d
              result, or ``object_`` if the dimension count goes negative.
              ``subscript_node.slice`` is first replaced with an
              ``ast.ExtSlice`` of the normalized slices.
        """
        # Named ``array_type`` so the ``type`` builtin is not shadowed.
        array_type = node.variable.type

        if not array_type.is_array:
            assert array_type.is_object
            return minitypes.object_, node

        if (len(slices) == 1 and self._is_constant_index(slices[0])
                and slices[0].value.pyval is Ellipsis):
            # A[...]
            return array_type, node

        result = []
        seen_ellipsis = False

        # Filter out newaxes
        newaxes = [newaxis for newaxis in slices if self._is_newaxis(newaxis)]
        n_indices = len(slices) - len(newaxes)

        # One ``:`` slice node, reused (the same object) for every padded
        # dimension.
        full_slice = ast.Slice(lower=None, upper=None, step=None)
        full_slice.variable = Variable(numba_types.SliceType())
        ast.copy_location(full_slice, slices[0])

        # Process ellipses and count integer indices. Slices are walked
        # right-to-left; ``result`` is reversed afterwards.
        indices_seen = 0
        for slice_node in slices[::-1]:
            slice_type = slice_node.variable.type
            if slice_type.is_ellipsis:
                if seen_ellipsis:
                    result.append(full_slice)
                else:
                    # The ellipsis expands to every dimension not consumed by
                    # the other (non-newaxis) indices.
                    nslices = array_type.ndim - n_indices + 1
                    result.extend([full_slice] * nslices)
                    seen_ellipsis = True
            elif (slice_type.is_slice or slice_type.is_int
                  or self._is_newaxis(slice_node)):
                indices_seen += slice_type.is_int
                result.append(slice_node)
            else:
                # TODO: Coerce all object operands to integer indices?
                # TODO: (This will break indexing with the Ellipsis object or
                # TODO:  with slice objects that we couldn't infer)
                return minitypes.object_, nodes.CoercionNode(
                    node, minitypes.object_)

        # append any missing slices (e.g. a2d[:]
        result_length = len(result) - len(newaxes)
        if result_length < array_type.ndim:
            nslices = array_type.ndim - result_length
            result.extend([full_slice] * nslices)

        result.reverse()
        subscript_node.slice = ast.ExtSlice(result)
        ast.copy_location(subscript_node.slice, slices[0])

        # create the final array type and set it in value.variable
        result_dtype = array_type.dtype
        result_ndim = array_type.ndim + len(newaxes) - indices_seen
        if result_ndim > 0:
            result_type = result_dtype[(slice(None), ) * result_ndim]
        elif result_ndim == 0:
            result_type = result_dtype
        else:
            result_type = minitypes.object_

        return result_type, node
예제 #39
0
 def visit_FunctionDef(self, orig_ast: ast.FunctionDef) -> ast.AST:
     """Replace the visited function's body with a two-model cascade.

     The new body is built from names captured from the enclosing scope
     (``selected_feature_nodes``, ``remaining_feature_nodes``,
     ``model_node``, ...). It performs these steps:

     1. Compute the selected features and run the approximate model on them.
     2. Call ``__willump_get_unapproximated_indices`` (with
        ``__willump_cascade_threshold``) to find rows that cannot be
        approximated.
     3. Restrict each input of the remaining feature nodes to those rows
        (each variable once), then compute the remaining features.
     4. Restrict the selected features to those rows, skipping any
        variable already restricted in step 3.
     5. Run the full model, then combine and return both predictions.

     Args:
         orig_ast: The function definition being rewritten.

     Returns:
         A deep copy of ``orig_ast`` with the cascade body, no decorators,
         and ``orig_ast``'s location.
     """
     cascades_body = []
     # Compute selected features.
     for node in selected_feature_nodes:
         cascades_body += node.get_ast()
     # Predict with approximate model.
     approximate_model_ast = create_model_ast(
         function_name=predict_proba_function.__name__,
         output_name="__willump_approximate_preds",
         input_names=selected_feature_names,
         model_param="__willump_approximate_model")
     cascades_body += approximate_model_ast
     # Get indices that can't be approximated.
     unapproximated_indices_ast = create_function_ast(
         function_name="__willump_get_unapproximated_indices",
         output_name="__willump_unapproximated_indices",
         input_names=[
             "__willump_approximate_preds",
             "__willump_cascade_threshold"
         ])
     cascades_body += unapproximated_indices_ast
     # Only compute remaining features for unapproximated indices. Each
     # variable is shortened in place (output_name == input name), so record
     # it to avoid selecting rows from it a second time.
     shortened_inputs = set()
     for node in remaining_feature_nodes:
         for input_name in node.input_names:
             if input_name not in shortened_inputs:
                 shorten_ast = create_function_ast(
                     function_name=
                     "__willump_select_unapproximated_rows",
                     output_name=input_name,
                     input_names=[
                         input_name, "__willump_unapproximated_indices"
                     ])
                 cascades_body += shorten_ast
                 shortened_inputs.add(input_name)
         cascades_body += node.get_ast()
     # Shorten the selected features, unless a remaining feature node already
     # consumed (and therefore shortened) that variable above. Selecting the
     # unapproximated rows twice would index an already-shortened array.
     for name in selected_feature_names:
         if name in shortened_inputs:
             continue
         shorten_ast = create_function_ast(
             function_name="__willump_select_unapproximated_rows",
             output_name=name,
             input_names=[name, "__willump_unapproximated_indices"])
         cascades_body += shorten_ast
         shortened_inputs.add(name)
     # Predict with full model.
     full_model_ast = create_model_ast(
         function_name=predict_function.__name__,
         output_name="__willump_full_preds",
         input_names=model_node.input_names,
         model_param="__willump_full_model")
     cascades_body += full_model_ast
     # Return combined predictions.
     combine_predictions_ast = create_function_ast(
         function_name="__willump_combine_predictions",
         output_name="__willump_final_predictions",
         input_names=[
             "__willump_approximate_preds", "__willump_full_preds",
             "__willump_cascade_threshold"
         ])
     cascades_body += combine_predictions_ast
     # ``mode`` must be passed by keyword: the second positional parameter of
     # ``ast.parse`` is ``filename``.
     return_ast = ast.parse("return __willump_final_predictions",
                            mode="exec").body
     cascades_body += return_ast
     # Finalize AST.
     new_ast = copy.deepcopy(orig_ast)
     new_ast.body = cascades_body
     # No recursion allowed!
     new_ast.decorator_list = []
     return ast.copy_location(new_ast, orig_ast)
예제 #40
0
 def visit_Return(self, node, visit_count):
     """Replace a ``Return`` node when its visit count is selected.

     If ``self.ordinal`` is empty, every return is selected; otherwise only
     returns whose ``visit_count`` is in ``self.ordinal`` are. A selected
     node is replaced by ``self.inject(node)`` placed at the original
     location. Unselected nodes have their children visited and are kept.
     """
     selected = (not self.ordinal) or (visit_count in self.ordinal)
     if not selected:
         self.generic_visit(node)
         return node
     return ast.copy_location(self.inject(node), node)
예제 #41
0
 def visit_With(self, node):
     """Convert a multi-item ``With`` into nested single-item ``With`` nodes.

     The source node stores its context managers in ``node.items``. The
     target ``ast.With`` takes a single ``context_expr`` / ``optional_vars``
     pair (the Python 2 layout this converter builds). ``with a as x, b:``
     is therefore lowered to ``with a as x:`` wrapping ``with b:``.
     Previously every item after the first was silently dropped. Fields are
     passed by keyword, and every generated node takes ``node``'s location.

     Returns:
         The outermost generated ``ast.With``.
     """
     body = self._visit(node.body)
     # Build from the innermost (last) item outwards, so the first item ends
     # up as the outermost statement.
     for item in reversed(node.items):
         new_node = ast.With(
             context_expr=self._visit(item.context_expr),
             optional_vars=self._visit(item.optional_vars),
             body=body,
         )
         ast.copy_location(new_node, node)
         body = [new_node]
     return new_node
예제 #42
0
def _copy_location(newnode, node):
    return ast.fix_missing_locations(ast.copy_location(newnode, node))
예제 #43
0
    def visitName(self, node: ast.Name):
        """Fold the ``__debug__`` builtin into a constant.

        ``__debug__`` is replaced by ``Constant(not self.optimize)`` at the
        same source location. Any other name is handed to
        ``self.generic_visit``.
        """
        if node.id != "__debug__":
            return self.generic_visit(node)
        folded = Constant(not self.optimize)
        return copy_location(folded, node)
예제 #44
0
 def passer(_: ast3.NodeTransformer, node: cls) -> ast.Pass:
     """Replace ``node`` with a ``pass`` statement at the same location.

     The transformer argument is unused. Returns the new ``ast.Pass`` node,
     which carries ``node``'s source location.
     """
     return ast.copy_location(ast.Pass(), node)
예제 #45
0
    def visit_Assign(self, node):
        """Rewrite an assignment into Taichi frontend calls.

        Only single-target assignments are supported. The children are
        visited first, then one of three rewrites applies:

        * ``a, b = rhs`` becomes a statement list. It stores ``rhs`` in the
          temporary ``__tmp_tuple``, assigns each element by index and
          finally deletes the temporary.
        * ``x = rhs``, where ``self.is_creation('x')`` holds, becomes
          ``x = ti.expr_init(rhs)`` and the variable is registered via
          ``self.create_variable``.
        * Any other target ``t`` becomes ``t.assign(rhs)``.

        The generated statements receive ``node``'s source location.
        """
        # Chained assignments (``a = b = 1``) are not supported.
        assert (len(node.targets) == 1)
        self.generic_visit(node)

        if isinstance(node.targets[0], ast.Tuple):
            targets = node.targets[0].elts

            # Create
            stmts = []

            holder = self.parse_stmt('__tmp_tuple = 0')
            holder.value = node.value

            stmts.append(holder)

            def tuple_indexed(i):
                """Return the expression ``__tmp_tuple[i]``."""
                indexing = self.parse_stmt('__tmp_tuple[0]')
                subscript = indexing.value
                index_expr = self.parse_expr("{}".format(i))
                # Before Python 3.9 the index is wrapped in ``ast.Index``.
                # From 3.9 on, ``slice`` is the index expression itself, and
                # assigning ``slice.value`` would store an AST node inside
                # the constant.
                index_cls = getattr(ast, 'Index', None)
                if index_cls is not None and isinstance(subscript.slice,
                                                        index_cls):
                    subscript.slice.value = index_expr
                else:
                    subscript.slice = index_expr
                return subscript

            for i, target in enumerate(targets):
                is_local = isinstance(target, ast.Name)
                if is_local and self.is_creation(target.id):
                    var_name = target.id
                    target.ctx = ast.Store()
                    # Create
                    init = ast.Attribute(value=ast.Name(id='ti',
                                                        ctx=ast.Load()),
                                         attr='expr_init',
                                         ctx=ast.Load())
                    rhs = ast.Call(
                        func=init,
                        args=[tuple_indexed(i)],
                        keywords=[],
                    )
                    self.create_variable(var_name)
                    stmts.append(ast.Assign(targets=[target], value=rhs))
                else:
                    # Assign
                    target.ctx = ast.Load()
                    func = ast.Attribute(value=target,
                                         attr='assign',
                                         ctx=ast.Load())
                    call = ast.Call(func=func,
                                    args=[tuple_indexed(i)],
                                    keywords=[])
                    stmts.append(ast.Expr(value=call))

            for stmt in stmts:
                ast.copy_location(stmt, node)
            stmts.append(self.parse_stmt('del __tmp_tuple'))
            return self.make_single_statement(stmts)
        else:
            is_local = isinstance(node.targets[0], ast.Name)
            if is_local and self.is_creation(node.targets[0].id):
                var_name = node.targets[0].id
                # Create
                init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                     attr='expr_init',
                                     ctx=ast.Load())
                rhs = ast.Call(
                    func=init,
                    args=[node.value],
                    keywords=[],
                )
                self.create_variable(var_name)
                return ast.copy_location(
                    ast.Assign(targets=node.targets, value=rhs), node)
            else:
                # Assign
                node.targets[0].ctx = ast.Load()
                func = ast.Attribute(value=node.targets[0],
                                     attr='assign',
                                     ctx=ast.Load())
                call = ast.Call(func=func, args=[node.value], keywords=[])
                return ast.copy_location(ast.Expr(value=call), node)
예제 #46
0
def fix_location(new, old):
    """Position ``new`` at ``old``'s source location and return ``new``.

    Location attributes are copied from ``old`` to ``new``. Descendants of
    ``new`` that still have no location then inherit one from their parent.
    """
    ast.fix_missing_locations(ast.copy_location(new, old))
    return new
예제 #47
0
파일: rewrite.py 프로젝트: symonk/pytest
    def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.

        Literal nodes are built with ``ast.Constant``. The ``ast.Str``,
        ``ast.Num`` and ``ast.NameConstant`` aliases have been deprecated
        since Python 3.8 and were removed in Python 3.14.
        """
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            # TODO: This assert should not be needed.
            assert self.module_path is not None
            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=self.module_path,
                lineno=assert_.lineno,
            )

        # Per-assert rewriting state, consumed by the other visit_* methods.
        self.statements: List[ast.stmt] = []
        self.variables: List[str] = []
        self.variable_counter = itertools.count()

        if self.enable_assertion_pass_hook:
            self.format_variables: List[str] = []

        self.stack: List[Dict[str, ast.expr]] = []
        self.expl_stmts: List[ast.stmt] = []
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)

        negation = ast.UnaryOp(ast.Not(), top_condition)

        if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
            msg = self.pop_format_context(ast.Constant(explanation))

            # Failed
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                gluestr = "\n>assert "
            else:
                assertmsg = ast.Constant("")
                gluestr = "assert "
            err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg)
            err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
            err_name = ast.Name("AssertionError", ast.Load())
            fmt = self.helper("_format_explanation", err_msg)
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)
            statements_fail = []
            statements_fail.extend(self.expl_stmts)
            statements_fail.append(raise_)

            # Passed
            fmt_pass = self.helper("_format_explanation", msg)
            orig = _get_assertion_exprs(self.source)[assert_.lineno]
            hook_call_pass = ast.Expr(
                self.helper(
                    "_call_assertion_pass",
                    ast.Constant(assert_.lineno),
                    ast.Constant(orig),
                    fmt_pass,
                )
            )
            # If any hooks implement assert_pass hook
            hook_impl_test = ast.If(
                self.helper("_check_if_assertion_pass_impl"),
                self.expl_stmts + [hook_call_pass],
                [],
            )
            statements_pass = [hook_impl_test]

            # Test for assertion condition
            main_test = ast.If(negation, statements_fail, statements_pass)
            self.statements.append(main_test)
            if self.format_variables:
                variables = [
                    ast.Name(name, ast.Store()) for name in self.format_variables
                ]
                clear_format = ast.Assign(variables, ast.Constant(None))
                self.statements.append(clear_format)

        else:  # Original assertion rewriting
            # Create failure message.
            body = self.expl_stmts
            self.statements.append(ast.If(negation, body, []))
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                explanation = "\n>assert " + explanation
            else:
                assertmsg = ast.Constant("")
                explanation = "assert " + explanation
            template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation))
            msg = self.pop_format_context(template)
            fmt = self.helper("_format_explanation", msg)
            err_name = ast.Name("AssertionError", ast.Load())
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)

            # ``body`` is the If's body list, so appending here places the
            # raise after the explanation statements.
            body.append(raise_)

        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, ast.Constant(None))
            self.statements.append(clear)
        # Fix locations (line numbers/column offsets).
        for stmt in self.statements:
            for node in traverse_node(stmt):
                ast.copy_location(node, assert_)
        return self.statements
예제 #48
0
def dyn(n):
    """Return a load of the name ``Any`` placed at ``n``'s source location."""
    any_name = ast.Name(id='Any', ctx=ast.Load())
    return ast.copy_location(any_name, n)
예제 #49
0
def _wrap_npu_distributed_optimizer(node):
    """Wrap ``node`` as ``npu_distributed_optimizer_wrapper(node)``.

    Logs the rewrite, copies ``node``'s location onto the new call, sets the
    ``need_conver`` flag through ``util_global`` and returns the new call.
    """
    log_msg(getattr(node, "lineno", "None"),
            "add npu distribute optimizer to tensorflow optimizer")
    new_node = ast.Call(
        func=ast.Name(id="npu_distributed_optimizer_wrapper", ctx=ast.Load()),
        args=[node],
        keywords=[])
    ast.copy_location(new_node, node)
    util_global.set_value('need_conver', True)
    return new_node


def convert_distributed_strategy_apis(node):
    """Convert distributed strategy API.

    Rewrites a call node for NPU distributed execution. ``node`` is expected
    to be an ``ast.Call``, since its ``func`` is inspected. The checks run
    in order and the first match wins:

    * Optimizer constructors are wrapped in
      ``npu_distributed_optimizer_wrapper``. This covers
      ``<x>.<y>.<Name>(...)`` where ``<Name>`` contains ``"Optimizer"``
      (except ``ScipyOptimizerInterface`` and
      ``MixedPrecisionLossScaleOptimizer``). It also covers a bare
      ``<Name>(...)`` containing ``"Optimizer"`` (except
      ``NPULossScaleOptimizer``).
    * Calls to ``TrainSpec``/``EvalSpec``, ``.train`` (but not
      ``<x>.learning.train``), ``.compile`` (but not ``re.compile``),
      ``.fit`` and ``.fit_generator`` are rewritten by
      ``convert_origin_func_to_npu``.
    * ``tf.gradients(...)`` is rewritten by
      ``convert_tf_gradient_distributed``.

    Any other node is returned unchanged.
    """
    if (isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Attribute)
            and "Optimizer" in node.func.attr
            and node.func.attr not in ("ScipyOptimizerInterface",
                                       "MixedPrecisionLossScaleOptimizer")):
        return _wrap_npu_distributed_optimizer(node)
    if (isinstance(node.func, ast.Name) and "Optimizer" in node.func.id
            and node.func.id != "NPULossScaleOptimizer"):
        return _wrap_npu_distributed_optimizer(node)
    if _call_name_match(node.func, "TrainSpec"):
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.estimator.TrainSpec"], "TrainSpec",
            ["hooks"])
    if _call_name_match(node.func, "EvalSpec"):
        return convert_origin_func_to_npu(node,
                                          tf_func_map["tf.estimator.EvalSpec"],
                                          "EvalSpec", ["hooks"])
    if isinstance(node.func, ast.Attribute) and node.func.attr == "train":
        # ``<x>.learning.train`` is left alone.
        if isinstance(node.func.value,
                      ast.Attribute) and node.func.value.attr == "learning":
            return node
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.estimator.Estimator.train"],
            "Estimator.train", ["hooks"], True)
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'compile'):
        # ``re.compile`` is a regex call, not ``Model.compile``.
        if isinstance(node.func.value,
                      ast.Name) and node.func.value.id == "re":
            return node
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.keras.Model.compile"], "Model.compile",
            ["optimizer"], True)
    if isinstance(node.func, ast.Attribute) and node.func.attr == "fit":
        return convert_origin_func_to_npu(node,
                                          tf_func_map["tf.keras.Model.fit"],
                                          "Model.fit", ["callbacks"], True)
    if isinstance(node.func,
                  ast.Attribute) and node.func.attr == "fit_generator":
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.keras.Model.fit_generator"],
            "Model.fit_generator", ["callbacks"], True)
    if isinstance(node.func, ast.Attribute) and node.func.attr == "gradients" and \
            isinstance(node.func.value, ast.Name) and node.func.value.id == "tf":
        return convert_tf_gradient_distributed(node)
    return node
예제 #50
0
파일: astutils.py 프로젝트: mfkiwl/dace
    def visit_Subscript(self, node: ast.Subscript):
        """Strip subscripts whose ``rname`` is one of ``self.keywords``.

        When ``rname(node)`` is in ``self.keywords``, the subscript is
        replaced by its base expression ``node.value``, which takes over the
        subscript's source location (``x[i]`` -> ``x``). Any other subscript
        is visited generically.
        """
        if rname(node) not in self.keywords:
            return self.generic_visit(node)
        base = node.value
        return ast.copy_location(base, node)
예제 #51
0
    def visit_BinOp(self, node):
        """Mutate the operator of one selected ``BinOp`` node.

        Children are visited first, then ``self.binop_count`` is incremented.
        Only the node whose updated count equals
        ``self.count_of_node_to_mutate`` is changed. That node is
        deep-copied and its operator is replaced by one drawn with
        ``random.randint`` from a fixed list of alternatives. ``//`` always
        becomes ``/`` with no random draw, and operators without
        alternatives are kept as-is in the copy. Every other node is
        returned unchanged (same object).
        """
        self.generic_visit(node)
        self.binop_count += 1

        # Visiting order is deterministic for the visitor pattern (unlike
        # ``ast.walk``), so the counter identifies exactly one BinOp.
        if self.binop_count != self.count_of_node_to_mutate:
            # Not the node we want to alter: leave it untouched.
            return node

        # deepcopy preserves every field (including location) we don't modify.
        new_node = copy.deepcopy(node)
        ast.copy_location(new_node, node)

        # Replacement candidates per operator. The tuple order matches the
        # index drawn by ``random.randint``.
        replacements = (
            (ast.Add, (ast.Mult, ast.Sub, ast.Div)),
            (ast.Mult, (ast.Div, ast.Add, ast.FloorDiv, ast.Sub)),
            (ast.Div, (ast.FloorDiv, ast.Mult, ast.Add)),
            (ast.Sub, (ast.Add, ast.Mult, ast.Div)),
            (ast.FloorDiv, (ast.Div,)),
        )
        candidates = next((options for op_type, options in replacements
                           if isinstance(node.op, op_type)), ())
        if len(candidates) == 1:
            # A single option needs no random draw.
            new_node.op = candidates[0]()
        elif candidates:
            index = random.randint(0, len(candidates) - 1)
            new_node.op = candidates[index]()
        return new_node
예제 #52
0
 def visit_Str(self, node):
     """Replace a string literal with a lower-cased copy.

     The result is an ``ast.Constant`` carrying ``node``'s source location.
     The ``ast.Str`` alias is deprecated since Python 3.8 and removed in
     3.14. The text is read from ``value`` when ``node`` is an
     ``ast.Constant`` and from the legacy ``s`` field otherwise.
     """
     text = node.value if isinstance(node, ast.Constant) else node.s
     return ast.copy_location(ast.Constant(value=text.lower()), node)
예제 #53
0
    def visit_For(self, node):
        """Lower a Python ``for`` loop into Taichi frontend calls.

        Three loop shapes are handled:

        * ``for ... in <x>.static(...)``: returned unchanged.
        * ``for i in range(...)`` with 1 or 2 arguments: rewritten into a
          ``ti.core.begin_frontend_range_for`` /
          ``ti.core.end_frontend_range_for`` pair around the original body.
        * Anything else, a "struct for": rewritten into a
          ``ti.core.begin_frontend_struct_for`` /
          ``ti.core.end_frontend_range_for`` pair. When the iterator call's
          attribute is ``grouped``, the loop variables are bound through
          ``ti.make_var_vector``.

        The replacement is an ``if 1:`` statement whose body holds the
        generated statements. Each loop variable is ``del``-ed at the end
        of that body.

        Raises:
            TaichiSyntaxError: if the loop has an ``else`` clause.
        """
        if node.orelse:
            raise TaichiSyntaxError(
                "'else' clause for 'for' not supported in Taichi kernels")
        # NOTE(review): custom generic_visit taking a field list; presumably
        # it visits only ``node.body`` here — confirm.
        self.generic_visit(node, ['body'])
        # An iterator of the form ``<x>.<attr>(...)`` may mark a static or
        # grouped loop.
        decorated = isinstance(node.iter, ast.Call) and isinstance(
            node.iter.func, ast.Attribute)
        is_static_for = False
        is_grouped = False
        if decorated:
            attr = node.iter.func
            if attr.attr == 'static':
                is_static_for = True
            elif attr.attr == 'grouped':
                is_grouped = True
        is_range_for = isinstance(node.iter, ast.Call) and isinstance(
            node.iter.func, ast.Name) and node.iter.func.id == 'range'
        if is_static_for:
            return node
        elif is_range_for:
            # Range-for requires a single plain-name target (``.id``).
            loop_var = node.target.id
            self.check_loop_var(loop_var)
            # Template statement indices inside ``if 1:``:
            #   0 loop var, 1 ___begin, 2 ___end, 3-4 casts,
            #   5 begin_frontend_range_for, 6 end_frontend_range_for.
            template = ''' 
if 1:
  {} = ti.Expr(ti.core.make_id_expr(''))
  ___begin = ti.Expr(0) 
  ___end = ti.Expr(0)
  ___begin = ti.cast(___begin, ti.i32)
  ___end = ti.cast(___end, ti.i32)
  ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr)
  ti.core.end_frontend_range_for()
      '''.format(loop_var, loop_var)
            t = ast.parse(template).body[0]

            assert len(node.iter.args) in [1, 2]
            if len(node.iter.args) == 2:
                bgn = node.iter.args[0]
                end = node.iter.args[1]
            else:
                # range(n): the start defaults to 0.
                bgn = self.make_constant(value=0)
                end = node.iter.args[0]

            # Replace the placeholder 0 in ``ti.Expr(0)`` with the bounds.
            t.body[1].value.args[0] = bgn
            t.body[2].value.args[0] = end
            # Splice the loop body between begin (index 5) and end (index 6).
            t.body = t.body[:6] + node.body + t.body[6:]
            t.body.append(self.parse_stmt('del {}'.format(loop_var)))
            return ast.copy_location(t, node)
        else:  # Struct for
            if isinstance(node.target, ast.Name):
                elts = [node.target]
            else:
                elts = node.target.elts

            for loop_var in elts:
                self.check_loop_var(loop_var.id)

            # One ``<name> = ti.Expr(...)`` declaration line per loop variable.
            var_decl = ''.join(
                '  {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(ind.id)
                for ind in elts)
            # Comma-separated loop variable names. NOTE: shadows builtin ``vars``.
            vars = ', '.join(ind.id for ind in elts)
            if is_grouped:
                # Template statement indices inside ``if 1:``:
                #   0 ___loop_var, 1 make_var_vector, 2 ___expr_group,
                #   3 begin_frontend_struct_for, 4 end_frontend_range_for.
                template = ''' 
if 1:
  ___loop_var = 0
  {} = ti.make_var_vector(size=___loop_var.loop_range().dim())
  ___expr_group = ti.make_expr_group({})
  ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
  ti.core.end_frontend_range_for()
        '''.format(vars, vars)
                t = ast.parse(template).body[0]
                cut = 4
                # ``___loop_var = <iterator expression>``
                t.body[0].value = node.iter
                # Splice the loop body just before end_frontend_range_for.
                t.body = t.body[:cut] + node.body + t.body[cut:]
            else:
                # Template: len(elts) declarations, then ___loop_var,
                # ___expr_group, begin_frontend_struct_for and
                # end_frontend_range_for (at index len(elts) + 3).
                template = ''' 
if 1:
{}
  ___loop_var = 0
  ___expr_group = ti.make_expr_group({})
  ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
  ti.core.end_frontend_range_for()
        '''.format(var_decl, vars)
                t = ast.parse(template).body[0]
                cut = len(elts) + 3
                # t.body[len(elts)] is ``___loop_var = 0``; bind the iterator.
                t.body[cut - 3].value = node.iter
                # Splice the loop body just before end_frontend_range_for.
                t.body = t.body[:cut] + node.body + t.body[cut:]
            for loop_var in reversed(elts):
                t.body.append(self.parse_stmt('del {}'.format(loop_var.id)))
            return ast.copy_location(t, node)
예제 #54
0
파일: cpp.py 프로젝트: 1C4nfaN/dace
    def visit_Assign(self, node):
        """Lower an assignment to a memlet-backed name into C++ code.

        Only the last target (``node.targets[-1]``) is considered. If its
        ``rname`` is not in ``self.memlets``, the node is visited
        generically. Otherwise the assignment is replaced by an ``ast.Name``
        whose ``id`` holds the generated C++ statement text:

        * Non-subscript target with a dynamic memlet or a ``data.Stream``
          array:
          - with a write-conflict resolution (``wcr``), only the value is
            replaced by the ``write_and_resolve_expr`` result and the
            assignment node is kept;
          - a stream becomes ``<data>.push(<value>);``;
          - a scalar becomes ``<data> = <value>;``;
          - anything else becomes ``<cpp_array_expr> = <value>;``.
        * Subscript target: ``write_and_resolve_expr(...) + ';'`` when there
          is a wcr, otherwise ``<target>[<subscript>] = <value>;``.

        Non-subscript targets not covered above are visited generically.
        """
        target = rname(node.targets[-1])
        if target not in self.memlets:
            return self.generic_visit(node)

        memlet, nc, wcr, dtype = self.memlets[target]
        value = self.visit(node.value)

        if not isinstance(node.targets[-1], ast.Subscript):
            # Dynamic accesses or streams -> every access counts
            try:
                if memlet and memlet.data and (memlet.dynamic or isinstance(
                        self.sdfg.arrays[memlet.data], data.Stream)):
                    if wcr is not None:
                        # Keep the Assign; only its value becomes the
                        # generated write-and-resolve expression.
                        newnode = ast.Name(
                            id=self.codegen.write_and_resolve_expr(
                                self.sdfg,
                                memlet,
                                nc,
                                target,
                                cppunparse.cppunparse(value,
                                                      expr_semicolon=False),
                                dtype=dtype))
                        node.value = ast.copy_location(newnode, node.value)
                        return node
                    elif isinstance(self.sdfg.arrays[memlet.data], data.Stream):
                        newnode = ast.Name(id="%s.push(%s);" % (
                            memlet.data,
                            cppunparse.cppunparse(value, expr_semicolon=False),
                        ))
                    else:
                        var_type, ctypedef = self.codegen._dispatcher.defined_vars.get(
                            memlet.data)
                        if var_type == DefinedType.Scalar:
                            newnode = ast.Name(id="%s = %s;" % (
                                memlet.data,
                                cppunparse.cppunparse(value,
                                                      expr_semicolon=False),
                            ))
                        else:
                            newnode = ast.Name(id="%s = %s;" % (
                                cpp_array_expr(self.sdfg, memlet),
                                cppunparse.cppunparse(value,
                                                      expr_semicolon=False),
                            ))

                    return self._replace_assignment(newnode, node)
            # NOTE(review): this handler wraps the whole branch, so a
            # TypeError raised by the code generators above is also swallowed.
            except TypeError:  # cannot determine truth value of Relational
                pass

            # NOTE(review): ``value`` computed above is unused on this path;
            # generic_visit visits ``node.value`` again.
            return self.generic_visit(node)

        subscript = self._subscript_expr(node.targets[-1].slice, target)

        if wcr is not None:
            newnode = ast.Name(id=self.codegen.write_and_resolve_expr(
                self.sdfg,
                memlet,
                nc,
                target,
                cppunparse.cppunparse(value, expr_semicolon=False),
                indices=sym2cpp(subscript),
                dtype=dtype) + ';')
        else:
            newnode = ast.Name(
                id="%s[%s] = %s;" %
                (target, sym2cpp(subscript),
                 cppunparse.cppunparse(value, expr_semicolon=False)))

        return self._replace_assignment(newnode, node)
예제 #55
0
 def visit_Attribute(self, node):
     """Return a copy of ``node`` with its attribute name lower-cased.

     The base expression is rewritten through ``self.visit``. The
     expression context (``Load``/``Store``/``Del``) and source location of
     ``node`` are preserved, so lowered assignment and ``del`` targets stay
     valid. A node without a context is given ``Load``.
     """
     ctx = getattr(node, "ctx", None) or ast.Load()
     return ast.copy_location(
         ast.Attribute(value=self.visit(node.value),
                       attr=node.attr.lower(),
                       ctx=ctx), node)
예제 #56
0
    def apply(self, state: SDFGState, sdfg: SDFG):
        """Turn a tasklet computing ``out = in <op> expr`` into a WCR write.

        The tasklet code is reduced to ``out = expr``. The output memlet gets
        the write-conflict resolution ``lambda a,b: a <op> b``, and the
        matched input edge is removed. For ``expr_index == 0`` the input
        access node is also removed if it becomes isolated. For
        ``expr_index == 0`` the pattern is input -> tasklet -> output;
        otherwise it is map entry -> tasklet -> map exit.

        If the output data is non-transient inside a nested SDFG, the WCR is
        also set on the memlet paths of the enclosing SDFGs.

        Args:
            state: the state containing the matched subgraph.
            sdfg: the SDFG owning ``state``.

        Raises:
            NotImplementedError: if the tasklet language is neither Python
                nor C++.
        """
        input: nodes.AccessNode = self.input
        tasklet: nodes.Tasklet = self.tasklet
        output: nodes.AccessNode = self.output

        # If state fission is necessary to keep semantics, do it first
        # (the input is written earlier in this state and the output is not
        # read afterwards): move the tasklet into a new state after ``state``.
        if (self.expr_index == 0 and state.in_degree(input) > 0
                and state.out_degree(output) == 0):
            newstate = sdfg.add_state_after(state)
            newstate.add_node(tasklet)
            new_input, new_output = None, None

            # Keep old edges for after we remove tasklet from the original state
            in_edges = list(state.in_edges(tasklet))
            out_edges = list(state.out_edges(tasklet))

            # Recreate the tasklet's edges in the new state with fresh
            # read/write access nodes.
            for e in in_edges:
                r = newstate.add_read(e.src.data)
                newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
                if e.src is input:
                    new_input = r
            for e in out_edges:
                w = newstate.add_write(e.dst.data)
                newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
                if e.dst is output:
                    new_output = w

            # Remove tasklet and resulting isolated nodes
            state.remove_node(tasklet)
            for e in in_edges:
                if state.degree(e.src) == 0:
                    state.remove_node(e.src)
            for e in out_edges:
                if state.degree(e.dst) == 0:
                    state.remove_node(e.dst)

            # Reset state and nodes for rest of transformation
            input = new_input
            output = new_output
            state = newstate
        # End of state fission

        if self.expr_index == 0:
            inedges = state.edges_between(input, tasklet)
            outedge = state.edges_between(tasklet, output)[0]
        else:
            me = self.map_entry
            mx = self.map_exit

            inedges = state.edges_between(me, tasklet)
            outedge = state.edges_between(tasklet, mx)[0]

        # Get relevant output connector
        outconn = outedge.src_conn

        # Regex character class built from the characters of _EXPRESSIONS
        # (matches a single character).
        ops = '[%s]' % ''.join(
            re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

        # Change tasklet code
        if tasklet.language is dtypes.Language.Python:
            # Match a single assignment with a binary operation as RHS
            ast_node: ast.Assign = tasklet.code.code[0]
            lhs: ast.Name = ast_node.targets[0]
            rhs: ast.BinOp = ast_node.value
            op = AugAssignToWCR._PYOP_MAP[type(rhs.op)]
            inconns = list(edge.dst_conn for edge in inedges)
            # The operand naming an input connector identifies the edge to
            # remove; the other operand becomes the new right-hand side.
            # NOTE(review): if both operands are input connectors ``new_rhs``
            # is never bound, and if neither is ``inedge`` is never bound.
            # Presumably the pattern matching rules these out — confirm.
            for n in (rhs.left, rhs.right):
                if isinstance(n, ast.Name) and n.id in inconns:
                    inedge = inedges[inconns.index(n.id)]
                else:
                    new_rhs = n
            new_node = ast.copy_location(
                ast.Assign(targets=[lhs], value=new_rhs), ast_node)
            tasklet.code.code = [new_node]

        elif tasklet.language is dtypes.Language.CPP:
            cstr = tasklet.code.as_string.strip()
            for edge in inedges:
                inconn = edge.dst_conn
                # Match ``<outconn> = <inconn> <op> <expr>;``
                match = re.match(
                    r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' %
                    (re.escape(outconn), re.escape(inconn), ops), cstr)
                if match is None:
                    # match = re.match(
                    #     r'^\s*%s\s*=\s*(.*)\s*(%s)\s*%s;$' %
                    #     (re.escape(outconn), ops, re.escape(inconn)), cstr)
                    # if match is None:
                    continue
                    # op = match.group(2)
                    # expr = match.group(1)
                else:
                    op = match.group(1)
                    expr = match.group(2)

                # Only an input reading the same subset as the output can be
                # folded into the WCR.
                if edge.data.subset != outedge.data.subset:
                    continue

                # Map asymmetric WCRs to symmetric ones if possible
                if op in AugAssignToWCR._EXPR_MAP:
                    op, newexpr = AugAssignToWCR._EXPR_MAP[op]
                    expr = newexpr.format(expr=expr)

                tasklet.code.code = '%s = %s;' % (outconn, expr)
                inedge = edge
                break
            # NOTE(review): if no edge matched, ``inedge`` (and possibly
            # ``op``) is unbound below — presumably excluded by matching.
        else:
            raise NotImplementedError

        # Change output edge
        outedge.data.wcr = f'lambda a,b: a {op} b'

        if self.expr_index == 0:
            # Remove input node and connector
            state.remove_edge_and_connectors(inedge)
            if state.degree(input) == 0:
                state.remove_node(input)
        else:
            # Remove input edge and dst connector, but not necessarily src
            state.remove_memlet_path(inedge)

        # If outedge leads to non-transient, and this is a nested SDFG,
        # propagate outwards
        sd = sdfg
        while (not sd.arrays[outedge.data.data].transient
               and sd.parent_nsdfg_node is not None):
            nsdfg = sd.parent_nsdfg_node
            nstate = sd.parent
            sd = sd.parent_sdfg
            outedge = next(
                iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
            # The loop variable rebinds ``outedge``; after the loop it is the
            # last edge of the memlet path, which the next ``while`` check uses.
            for outedge in nstate.memlet_path(outedge):
                outedge.data.wcr = f'lambda a,b: a {op} b'
예제 #57
0
 def visit_Name(self, node):
     """Return a copy of ``node`` with its identifier lower-cased.

     The expression context (``Load``/``Store``/``Del``) and source location
     of ``node`` are preserved, so lowered assignment and ``del`` targets
     stay valid. A node without a context is given ``Load``.
     """
     ctx = getattr(node, "ctx", None) or ast.Load()
     return ast.copy_location(ast.Name(id=node.id.lower(), ctx=ctx), node)
예제 #58
0
 def replace(self, node, new_node):
     """Substitute ``new_node`` for ``node`` and walk its children.

     ``new_node`` takes over ``node``'s source location. Its children are
     then visited with the base-class ``NodeVisitor.generic_visit``, called
     explicitly rather than through ``self``. Returns ``new_node``.
     """
     located = copy_location(new_node, node)
     NodeVisitor.generic_visit(self, located)
     return located
예제 #59
0
 def visit_arg(self, node):
     """Turn a function argument into a ``Name`` with ``Param`` context.

     The argument's identifier becomes the ``Name`` id, and the argument's
     source location is copied onto the new node, which is returned.
     """
     param_name = ast.Name(node.arg, ast.Param())
     return ast.copy_location(param_name, node)
예제 #60
0
 def visit_Raise(self, node):
     """Translate a ``raise`` statement into ``Raise(type, inst, tback)``.

     The raised expression becomes ``type``. ``inst`` and ``tback`` are left
     as ``None``, and the new node takes ``node``'s source location.

     Raises:
         error.NumbaError: if the statement has a ``from`` cause.
     """
     if not node.cause:
         translated = Raise(type=node.exc, inst=None, tback=None)
         return ast.copy_location(translated, node)
     raise error.NumbaError(node, "Cause to 'raise' not supported")