def visit_Assert(self, node):  # pylint: disable=invalid-name, no-self-use
    """Rewrite a bare ``assert`` into the matching ``self.assert*`` call.

    A single comparison (``assert a == b``) becomes the specific unittest
    assertion (``self.assertEqual(a, b)``); any other test expression falls
    back to ``self.assertTrue``.  An assert message, when present, is
    forwarded as the final argument.
    """
    comparison_methods = {
        ast.Eq: 'assertEqual',
        ast.NotEq: 'assertNotEqual',
        ast.Lt: 'assertLess',
        ast.LtE: 'assertLessEqual',
        ast.Gt: 'assertGreater',
        ast.GtE: 'assertGreaterEqual',
        ast.Is: 'assertIs',
        ast.IsNot: 'assertIsNot',
        ast.In: 'assertIn',
        ast.NotIn: 'assertNotIn',
    }
    if isinstance(node.test, ast.Compare) and len(node.test.ops) == 1:
        call_args = [node.test.left, node.test.comparators[0]]
        assertion = comparison_methods[node.test.ops[0].__class__]
    else:
        call_args = [node.test]
        assertion = 'assertTrue'
    if node.msg:
        call_args.append(node.msg)
    method = ast.copy_location(
        ast.Attribute(ast.Name('self', ast.Load()), assertion, ast.Load()),
        node)
    invocation = ast.copy_location(call(method, call_args), node)
    return ast.copy_location(ast.Expr(invocation), node)
def visit_Name(self, old_node):
    """Rewrite an ``ast.Name`` into the project's ``nodes.Name``,
    registering assignments/references with the flow graph and recording
    the variable's first definition position in the symbol table.
    """
    node = nodes.Name(old_node.id, old_node.ctx)
    ast.copy_location(node, old_node)
    # Set some defaults: conservative null-tracking flags, presumably
    # refined later by the control-flow analysis -- TODO confirm.
    node.cf_maybe_null = True
    node.cf_is_null = False
    node.allow_null = False
    node.name = node.id
    if isinstance(node.ctx, ast.Param):
        # Function parameter: flag it as an argument and record an implicit
        # assignment at function entry.
        var = self.symtab[node.name]
        var.is_arg = True
        self.flow.mark_assignment(node, None, var, assignment=None)
    elif isinstance(node.ctx, ast.Load):
        var = self.symtab.lookup(node.name)
        if var:
            # Local variable
            self.flow.mark_reference(node, var)
    # Set position of assignment of this definition; -1 acts as the
    # "not yet assigned" sentinel, so only the first write is recorded.
    if isinstance(node.ctx, (ast.Param, ast.Store)):
        var = self.symtab[node.name]
        if var.lineno == -1:
            var.lineno = getattr(node, "lineno", 0)
            var.col_offset = getattr(node, "col_offset", 0)
    return node
def visit_BinOp(self, node):
    """Constant-fold a binary operation on two numeric literals.

    ``2 + 3`` becomes ``5``; operands that are not numeric literals (or an
    operator with no numeric meaning, e.g. ``@``) leave the node unchanged.

    The original implementation round-tripped through
    ``eval(compile(ast.Expression(...)))``; an explicit operator dispatch
    avoids ``eval`` entirely and is cheaper.
    """
    import operator
    node = self.generic_visit(node)
    if not (isinstance(node.left, ast.Num) and isinstance(node.right, ast.Num)):
        return node
    # Map AST operator node types to the equivalent Python-level functions.
    fold = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.LShift: operator.lshift,
        ast.RShift: operator.rshift,
        ast.BitOr: operator.or_,
        ast.BitXor: operator.xor,
        ast.BitAnd: operator.and_,
    }.get(type(node.op))
    if fold is None:
        # e.g. MatMult: numbers do not support it; keep the node as-is.
        return node
    value = fold(node.left.n, node.right.n)
    return ast.copy_location(ast.Num(n=value), node)
def test_load_const(self):
    """Constants written as source literals and constants injected as
    ``ast.Constant`` nodes must yield identical LOAD_CONST sequences."""
    consts = [None, True, False, 124, 2.0, 3j, "unicode", b'bytes', (1, 2, 3)]
    source = '\n'.join('x={!r}'.format(value) for value in consts)
    source += '\nx = ...'
    consts.extend((Ellipsis, None))
    tree = ast.parse(source)
    self.assertEqual(self.get_load_const(tree), consts)
    # Swap each assignment's value for an explicit Constant node and check
    # the compiled constants are unchanged.
    for assign, value in zip(tree.body, consts):
        assert isinstance(assign, ast.Assign), ast.dump(assign)
        assign.value = ast.copy_location(ast.Constant(value=value),
                                         assign.value)
    self.assertEqual(self.get_load_const(tree), consts)
def prepare(self):
    """Prepare children, then split keyword arguments into pre-evaluated
    constants and a single compiled dict expression for the dynamic ones."""
    for child in self.children:
        child.prepare()
    # Compile the keywords: constant expressions are evaluated right now,
    # dynamic ones are gathered into one ast.Dict compiled once.
    constant_values = {}
    dynamic_keys = []
    dynamic_exprs = []
    for name, expr in self.keyword:
        compiled = py_compile(expr, 'eval', ast_node=True)
        if is_constant(compiled):
            constant_values[name] = py_eval_bytecode(compile_expr(compiled))
        else:
            dynamic_keys.append(ast.Str(s=name))
            dynamic_exprs.append(compiled)
    self.keyword_values = constant_values if constant_values else None
    if dynamic_keys:
        dict_node = ast.Dict(keys=dynamic_keys, values=dynamic_exprs)
        ast.copy_location(dict_node, dynamic_exprs[0])
        self.keyword_exprs = compile_expr(dict_node)
    else:
        self.keyword_exprs = None
def prepare(self):
    """Prepare the block, then split positional arguments into
    pre-evaluated constants and a tuple expression of the dynamic ones."""
    SLBlock.prepare(self)
    # Prepare the positional arguments.
    compiled_exprs = []
    constant_values = []
    any_dynamic = False
    any_constant = False
    for argument in self.positional:
        compiled = py_compile(argument, 'eval', ast_node=True)
        if is_constant(compiled):
            constant_values.append(py_eval_bytecode(compile_expr(compiled)))
            compiled_exprs.append(ast.Num(n=0))  # placeholder slot
            any_constant = True
        else:
            constant_values.append(use_expression)
            compiled_exprs.append(compiled)
            any_dynamic = True
    self.positional_values = constant_values if any_constant else None
    if any_dynamic:
        tuple_node = ast.Tuple(elts=compiled_exprs, ctx=ast.Load())
        ast.copy_location(tuple_node, compiled_exprs[0])
        self.positional_exprs = compile_expr(tuple_node)
    else:
        self.positional_exprs = None
def visit_FunctionDef(self, node):
    """Instrument ``test*`` functions by interleaving statements produced
    by ``createNode`` with the original body; other functions only get
    their locations fixed up."""
    if node.name.startswith('test'):
        tracked = [stmt for stmt in node.body
                   if isinstance(stmt, (ast.Assign, ast.Expr))]
        self.statements += tracked
        self.tracking[node.name] = None
        generated = [createNode(node.name, stmt) for stmt in node.body]
        rebuilt = []
        for created, original in zip(generated, node.body):
            if isinstance(created, list):
                # Several generated statements: emit them all (located at
                # the original statement), then the original itself.
                for piece in created:
                    ast.copy_location(piece, original)
                    rebuilt.append(piece)
                rebuilt.append(original)
            elif isinstance(created, ast.TryExcept):
                # A wrapping try/except precedes the original statement.
                ast.copy_location(created, original)
                rebuilt.append(created)
                rebuilt.append(original)
            else:
                # The generated node replaces the original outright.
                rebuilt.append(created)
        node.body = rebuilt
    ast.fix_missing_locations(node)
    return node
def __visit_FunctionDef(self, node):
    """Rebuild a FunctionDef with freshly-visited arguments, body and
    decorator list, keeping the original name and source location."""
    rebuilt = ast.FunctionDef(
        name=node.name,
        args=self.visit_arguments(node.args),
        body=self._visit_list(node.body),
        decorator_list=self._visit_list(node.decorator_list))
    return ast.copy_location(rebuilt, node)
def visit_BinOp(self, node: ast.BinOp):
    """Expand small constant integer powers into repeated multiplication.

    ``x ** n`` becomes ``x * x * ...`` for ``2 <= |n| <= self.MAX_DEGREE``,
    ``1`` for ``n == 0``, ``x`` for ``|n| == 1``, and the expansion is
    wrapped as ``1 / (...)`` for negative ``n``.  Larger exponents are left
    untouched.
    """
    node = self.generic_visit(node)
    if self._is_numeric_pow(node):
        left, right = node.left, node.right
        # Exponent literal: a plain Num, or a unary +/- applied to one
        # (assumes _is_numeric_pow only admits those shapes -- TODO confirm).
        degree = (
            right.n if isinstance(right, ast.Num)
            else -right.operand.n if isinstance(right.op, ast.USub)
            else right.operand.n
        )
        degree = int(degree)
        if abs(degree) == 0:
            # x ** 0 -> 1
            node = ast.copy_location(ast.Num(n=1), node)
        elif abs(degree) == 1:
            # x ** 1 (or x ** -1, before the reciprocal step below) -> x
            node = node.left
        elif 2 <= abs(degree) <= self.MAX_DEGREE:
            # Chain |degree| - 1 multiplications; each extra factor is a
            # fresh copy of the base so the factors don't alias each other.
            for _ in range(1, abs(degree)):
                new_node = ast.BinOp(
                    left=left,
                    op=ast.Mult(),
                    right=copy(node.left)
                )
                left = new_node = ast.copy_location(new_node, node)
            node = new_node
        else:
            # Exponent too large to expand profitably; keep the pow.
            return node
        if degree < 0:
            # Negative exponent: reciprocal of the expansion.
            new_node = ast.BinOp(
                left=ast.Num(n=1),
                op=ast.Div(),
                right=node
            )
            node = ast.copy_location(new_node, node)
    return node
def visit_Name(self, node):
    """Randomly swap the quantifier names ``forall`` and ``exists``;
    every other name passes through untouched."""
    swap = {'forall': 'exists', 'exists': 'forall'}
    if self.randomize() and node.id in swap:
        return ast.copy_location(_ast.Name(id=swap[node.id]), node)
    return node
def test_load_const(self):
    """The LOAD_CONST sequence must be unaffected by replacing literal
    expressions with explicit ``ast.Constant`` nodes."""
    consts = [None, True, False, 124, 2.0, 3j, "unicode", b'bytes', (1, 2, 3)]
    source = '\n'.join(map(repr, consts))
    source += '\n...'
    # Bare str/int/float/complex expression statements (bools excluded) are
    # not expected among the code constants -- presumably optimized away.
    code_consts = [value for value in consts
                   if not isinstance(value, (str, int, float, complex))
                   or isinstance(value, bool)]
    code_consts.append(Ellipsis)
    # the compiler adds a final "LOAD_CONST None"
    code_consts.append(None)
    tree = ast.parse(source)
    self.assertEqual(self.get_load_const(tree), code_consts)
    # Replace expression nodes with constants
    for expr_node, value in zip(tree.body, consts):
        assert isinstance(expr_node, ast.Expr)
        expr_node.value = ast.copy_location(ast.Constant(value=value),
                                            expr_node.value)
    self.assertEqual(self.get_load_const(tree), code_consts)
def visit_If(self, node):
    """Simplify ``if`` statements whose body or else-branch is a bare
    ``pass``: drop empty branches, invert the test when only the else
    branch does work, merge with a nested ``elif``, and delete the whole
    statement (return None) when nothing remains."""
    node = self.generic_visit(node)
    # An else-branch that is just `pass` is dead weight.
    if node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], ast.Pass):
        node.orelse = []
    body_is_pass = len(node.body) == 1 and isinstance(node.body[0], ast.Pass)
    if not body_is_pass:
        return node
    if not node.orelse:
        return None  # the whole statement is a no-op
    inverted = ast.UnaryOp(op=ast.Not(), operand=node.test)
    if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
        # `if c: pass elif d: ...` -> `if not c and d: ...`
        nested = node.orelse[0]
        combined = ast.BoolOp(op=ast.And(), values=[inverted, nested.test])
        node.test = ast.copy_location(combined, nested.test)
        node.body = nested.body
        node.orelse = nested.orelse
    else:
        # `if c: pass else: ...` -> `if not c: ...`
        node.test = ast.copy_location(inverted, node.test)
        node.body = node.orelse
        node.orelse = []
    return node
def visit_If(self, node): exit_block = self.flow.exit_block(label='exit_if', pos=node) # Condition cond_block = self.flow.nextblock(self.flow.block, label='if_cond', is_expr=True, pos=node.test) node.test = self.visit(node.test) # Body if_block = self.flow.nextblock(label='if_body', pos=node.body[0]) self.visitlist(node.body) if self.flow.block: self.flow.block.add_child(exit_block) # Else clause if node.orelse: else_block = self.flow.nextblock(cond_block, label='else_body', pos=node.orelse[0]) self.visitlist(node.orelse) if self.flow.block: self.flow.block.add_child(exit_block) else: cond_block.add_child(exit_block) else_block = None new_node = nodes.build_if(cond_block=cond_block, test=node.test, if_block=if_block, body=node.body, else_block=else_block, orelse=node.orelse, exit_block=exit_block) ast.copy_location(new_node, node) return self.exit_block(exit_block, new_node)
def visit_Name(self, node):
    """Rebuild a Name node from its visited id and context, preserving
    the source location."""
    converted = ast.Name(self._visit(node.id), self._visit(node.ctx))
    return ast.copy_location(converted, node)
def test_ast_2(self):
    """An Expression built from a hand-assembled ast.Expression node must
    evaluate against the namespace."""
    self.ns['FOOBAR'] = ''
    body = ast.parse('FOO+BAR').body[0].value
    expr = ast.Expression()
    expr.body = body
    ast.copy_location(expr, body)
    self.ns['FOOBAR'] = Expression(expr)
    self.assertEqual(self.ns['FOOBAR'].get(), 'foobar')
def visit_Assign(self, node):
    """Expand an array declaration ``name[dims] = Type(args)`` into one
    scalar declaration per flattened element; other assignments pass
    through untouched."""
    value = node.value
    if not (isinstance(value, ast.Subscript) and isinstance(value.value, ast.Call)):
        return node
    declaration_call = value.value
    if len(node.targets) > 1:
        error.error('Cannot use multiple assignment in array declaration.', node)
    array_name = node.targets[0].id
    type_name = declaration_call.func.id
    declaration_args = declaration_call.args
    # Get the indices being accessed.
    shape = slice_node_to_tuple_of_numbers(value.slice)
    expanded = []
    # One scalar assignment per point of the index space.
    for indices in itertools.product(*[range(n) for n in shape]):
        element_name = flattened_array_name(array_name, indices)
        target_node = ast.copy_location(ast.Name(element_name, ast.Store()), node)
        type_node = ast.copy_location(ast.Name(type_name, ast.Load()), node)
        # Deep-copy the args so each element gets its own argument nodes.
        args_copy = [copy.deepcopy(arg) for arg in declaration_args]
        call_node = ast.copy_location(
            ast.Call(type_node, args_copy, [], None, None), node)
        expanded.append(
            ast.copy_location(ast.Assign([target_node], call_node), node))
    return expanded
def CALL(self, tree):
    """Pyflakes hook for Call nodes (Python 2): recognize calls to
    snakeoil's ``demandload`` / ``demand_compile_regexp`` and register the
    bindings they would create, so the checker doesn't flag them as
    undefined/unused.
    """
    is_demandload = is_demandload_regex = False
    if isinstance(tree.func, _ast.Attribute):
        if not isinstance(tree.func.value, _ast.Name):
            # this means it's a multilevel lookup;
            # pkgcore.ebuild.ebuild_src.some_func
            # ignore it; it *could* miss a direct
            # snakeoil.demandload.demandload, but
            # I don't care, bad form of access imo.
            return self.handleChildren(tree)
        src = self.scope.get(tree.func.value.id)
        # Attribute call on a module previously flagged as the demandload
        # module, e.g. `demandload_mod.demandload(...)`.
        if getattr(src, 'is_demandload_module', False):
            if tree.func.attr == 'demandload':
                is_demandload = True
            elif tree.func.attr == 'demand_compile_regexp':
                is_demandload_regex = True
    elif hasattr(tree.func, 'id'):
        # Bare-name call: look the binding up in the current scope.
        is_demandload = getattr(self.scope.get(tree.func.id),
                                'is_demandload_func', False)
        is_demandload_regex = getattr(self.scope.get(tree.func.id),
                                      'is_demandload_regex', False)
    if is_demandload_regex:
        # should do validation here.
        if len(tree.args) < 3:
            self.report(BadDemandloadRegexCall, tree.lineno)
        elif tree.args[1].__class__.__name__.upper() not in ("STR", "UNICODE"):
            self.report(BadDemandloadRegexCall, tree.lineno,
                        "name argument isn't string nor unicode")
        elif tree.args[2].__class__.__name__.upper() not in ("STR", "UNICODE"):
            self.report(BadDemandloadRegexCall, tree.lineno,
                        "regex argument isn't string nor unicode")
        else:
            # Synthesize `<name> = re.compile(<regex>)` and bind it so the
            # lazily-created regex name is known to the checker.
            # NOTE(review): `copy_location` lives in the `ast` module, not
            # `_ast`; `_ast.copy_location` looks like an AttributeError
            # waiting to happen -- verify how `_ast` is imported/aliased.
            code = "%s = re.compile(%r)\n" % (tree.args[1].s, tree.args[2].s)
            fakenode = _ast.copy_location(
                compile(code, self.filename, "exec",
                        _ast.PyCF_ONLY_AST).body[0],
                tree)
            self.addBinding(tree.lineno,
                            _checker.Assignment(tree.args[1].s, fakenode))
    if is_demandload:
        if len(tree.args) < 2:
            self.report(BadDemandloadCall, tree.lineno)
            return self.handleChildren(tree)
        # Each argument after the namespace is a demandload spec string.
        for chunk in tree.args[1:]:
            chunk_cls = chunk.__class__.__name__
            if chunk_cls.upper() not in ('STR', 'UNICODE'):
                self.report(BadDemandloadCall, chunk.lineno,
                            "invoked with non string/unicode arg: %r" % (chunk_cls,))
                continue
            s = chunk.s
            try:
                targets = list(parse_demandload([s]))
            except ValueError, ve:
                self.report(BadDemandloadCall, chunk.lineno, ve)
                continue
            # Bind a fake `import <src> as <asname>` for every module the
            # spec would lazily import.
            for src, asname in targets:
                fakenode = _ast.copy_location(
                    compile("import %s as %s\n" % (src, asname),
                            self.filename, "exec",
                            _ast.PyCF_ONLY_AST).body[0],
                    chunk)
                self.addBinding(chunk.lineno,
                                DemandloadImportation(asname, fakenode))
def search_def_methods(self, class_node): for node in class_node.body: if isinstance(node, ast.FunctionDef): if self.is_test_method(node.name): newnode = ast.arguments(args=[], vararg=None, kwarg=None, defaults=[]) ast.copy_location(newnode, node.args) node.args = newnode self.methods_to_run.append(node)
def visit_With(self, node):
    """Lower a withitem-based With node to the legacy ast.With layout
    (context_expr, optional_vars, body passed positionally)."""
    item = node.items[0]
    converted = ast.With(
        self._visit(item.context_expr),
        self._visit(item.optional_vars),
        self._visit(node.body),
    )
    return ast.copy_location(converted, node)
def visit_Expr(self, node):
    """
    When capturing a call to include, we must grab it here, so we can
    replace the whole Expr(Call('include')).  Also post-processes yielded
    string/call expressions (whitespace stripping, joining locally-defined
    template function results).
    """
    if type(node.value) is Call:
        call = node.value
        if type(call.func) is Name and call.func.id == 'include':
            if len(call.args) < 1:
                raise FormatError("include requires at least a filename as an argument.")
            root = None
            # if the original call to include had an additional argument
            # use that argument as the root
            # print('call',ast.dump(call))
            if len(call.args) > 1:
                root = call.args[1].s
            # or if there was a root= kwarg provided, use that
            elif len(call.keywords) > 0:
                for k in call.keywords:
                    if k.arg == "root":
                        root = k.value.s
            if root is None:
                # if we didn't get one from the call to include
                # look for one that was given as an argument to the template() call
                root = self.root
            if type(root) is str:
                root = root.split(os.path.sep)
            # the first argument to include() is the filename
            template_name = call.args[0].s
            # get the ast tree that comes from this included file
            check, fundef = include_ast(template_name, root)
            # each include produces the code to execute, plus some code to
            # check for freshness; this code absolutely must run first,
            # because we can't restart the generator once it has already
            # yielded
            self.preamble.append(check)
            if fundef is None:
                raise FormatError("include_ast returned None")
            # return a copy of the cached ast tree, because it will be
            # further modified to fit with the including template
            fundef = copy.deepcopy(fundef)
            _yieldall(fundef.body)
            for expr in fundef.body:
                self.visit(expr)
            return fundef.body
    elif type(node.value) is Yield:
        y = node.value
        if type(y.value) == Str:
            if self.stripWhitespace:
                s = strip_whitespace(y.value.s)
                if len(s) == 0:
                    # dont even compile in the Expr(Yield) if it was only
                    # yielding white space
                    return None
                else:
                    y.value.s = s
        elif type(y.value) == Call:
            call = y.value
            if type(call.func) is Name:
                if self.seenFuncs.get(call.func.id, False) is not False:
                    # was defined locally:
                    # replace the Call with one to ''.join(Call)
                    y.value = _call(Attribute(value=Str(s=''), attr='join',
                                              ctx=Load()),
                                    [y.value])
                    ast.copy_location(y.value, node)
    self.generic_visit(node)
    return node
def visit_Str(self, node):
    """For html templates, wrap string literals in a ``_q_htmltext(...)``
    call; other template types keep literals as-is."""
    if self._get_template_type() != "html":
        return node
    func = ast.copy_location(ast.Name(id='_q_htmltext', ctx=ast.Load()), node)
    # Legacy Call fields (starargs/kwargs) are kept for old-Python ASTs.
    wrapped = ast.Call(func=func, args=[node], keywords=[],
                       starargs=None, kwargs=None)
    return ast.copy_location(wrapped, node)
def visit_BinOp(self, node):
    """Rewrite ``a ** b`` into a call to the builtin ``pow(a, b)``."""
    self.generic_visit(node)
    if not isinstance(node.op, Pow):
        return node
    func_ref = copy_location(Name("pow", Load()), node)
    # Legacy ast.Call signature: (func, args, keywords, starargs, kwargs).
    call_node = Call(func_ref, [node.left, node.right], [], None, None)
    return copy_location(call_node, node)
def visit_TryFinally(self, node):
    """Fold a legacy TryFinally node into gast's unified Try (no handlers,
    no orelse)."""
    unified = gast.Try(
        self._visit(node.body),
        [],  # handlers
        [],  # orelse
        self._visit(node.finalbody))
    return ast.copy_location(unified, node)
def visit_TryExcept(self, node):
    """Fold a legacy TryExcept node into gast's unified Try (no
    finalbody)."""
    unified = gast.Try(
        self._visit(node.body),
        self._visit(node.handlers),
        self._visit(node.orelse),
        [])  # finalbody
    return ast.copy_location(unified, node)
def visit_Return(self, node):
    """Replace a return statement with the assignment ``y = 8``.

    The original return's line number is bumped by one before its position
    is copied onto the new assignment.
    """
    target = ast.Name(id='y', ctx=ast.Store())
    replacement = ast.Assign(targets=[target], value=ast.Num(8))
    ast.increment_lineno(node, 1)
    ast.copy_location(replacement, node)
    ast.fix_missing_locations(replacement)
    return replacement
def visit_FunctionDef(self, node):
    """Rebuild a FunctionDef from its visited components, preserving the
    source location."""
    rebuilt = ast.FunctionDef(
        self._visit(node.name),
        self._visit(node.args),
        self._visit(node.body),
        self._visit(node.decorator_list),
    )
    return ast.copy_location(rebuilt, node)
def _seconds_to_mu(ref_period, node):
    """Build an AST expression converting *node* (seconds) to machine
    units: ``round64(node / ref_period)``."""
    ratio = ast.BinOp(left=node, op=ast.Div(), right=value_to_ast(ref_period))
    ast.copy_location(ratio, node)
    rounded = ast.Call(func=ast.Name("round64", ast.Load()),
                       args=[ratio], keywords=[])
    return ast.copy_location(rounded, ratio)
def visit_Assert(self, assert_node):
    """Expand message-less asserts via AssertionChildVisitor; asserts that
    already carry a message are left untouched."""
    if assert_node.msg is not None:
        return assert_node
    expanded = AssertionChildVisitor().visit(assert_node.test)
    # Stamp the original assert's location onto every generated statement.
    for stmt in expanded:
        ast.copy_location(stmt, assert_node)
        ast.fix_missing_locations(stmt)
    return expanded
def visit_ClassDef(self, node):
    """Rebuild a ClassDef from its visited components, preserving the
    source location."""
    rebuilt = ast.ClassDef(
        self._visit(node.name),
        self._visit(node.bases),
        self._visit(node.body),
        self._visit(node.decorator_list),
    )
    return ast.copy_location(rebuilt, node)
def visit_With(self, node):
    """Lift a legacy With node into gast's withitem-based layout."""
    item = gast.withitem(
        self._visit(node.context_expr),
        self._visit(node.optional_vars))
    rebuilt = gast.With([item], self._visit(node.body))
    return ast.copy_location(rebuilt, node)
def visit(self, node):
    """Delegate to the base visitor, stamping the original node's location
    onto any replacement node it produces."""
    result = super().visit(node)
    if result is node:
        return node
    return ast.copy_location(result, node)
def visit_Num(self, node: ast.Num):
    """Hoist a numeric literal into ``self.gvars`` under a fresh ``__uuN``
    name and replace the literal with a reference to that name."""
    alias = f'__uu{self.id}'
    self.gvars[alias] = node.n
    self.id += 1
    replacement = ast.Name(id=alias, ctx=ast.Load())
    return ast.copy_location(replacement, node)
def _aug_assign(self, target, oper, value): # transform a += 1 to a = a + 1, then we can use assign and eval new_value = ast.BinOp(left=target, op=oper, right=value) ast.copy_location(new_value, target) self._assign(target, new_value)
def ast_call(node):
    """Visit and transform an ast Call node for NPU migration.

    Dispatches on the callee's name/attribute and rewrites known TensorFlow
    and Horovod APIs to their NPU equivalents (sessions, estimators,
    strategies, dropout, dataset batching, device configuration, ...),
    setting the global 'need_conver' flag whenever a rewrite happens.
    Returns the (possibly replaced) node.

    Fixes over the previous revision:
    - the RunConfig train_distribute/eval_distribute detection compared
      strings against ast.keyword objects (``node.keywords.count("...")``
      is always 0) and therefore never fired;
    - three ``getattr(node, 'lineno')`` calls were missing the ``'None'``
      default used everywhere else and could raise AttributeError.
    """
    distributed_mode = util_global.get_value("distributed_mode", "")
    is_not_strategy = distributed_mode in ("horovod", "")
    is_not_horovod = distributed_mode in ("tf_strategy", "")
    convert_loss_scale_api(node)
    # --- experimental options / GPU probing -------------------------------
    if _call_name_match(node.func, "set_experimental_options"):
        log_msg(
            getattr(node, 'lineno', 'None'),
            'change set_experimental_options(*) to set_experimental_options(experimental_options)'
        )
        node.args = [ast.Name(id='experimental_options', ctx=ast.Load())]
        node.keywords = []
        util_global.set_value('need_conver', True)
    if isinstance(node.func, ast.Name) and node.func.id == 'check_available_gpus':
        log_msg(getattr(node, 'lineno', 'None'),
                "change check_available_gpus() to ['/device:CPU:0']")
        util_global.set_value('need_conver', True)
        return ast.List(elts=[ast.Str(s="/device:CPU:0")], ctx=ast.Load())
    # --- session/graph option wrappers ------------------------------------
    if ((isinstance(node.func, ast.Name) and node.func.id == 'GraphOptions')
            or (isinstance(node.func, ast.Attribute) and node.func.attr == 'GraphOptions')):
        log_success_report(getattr(node, 'lineno', 'None'), 'GraphOptions()')
        # Wrap the original call as the graph_options argument of the NPU helper.
        src = copy.deepcopy(node)
        node.func = ast.Name(id='npu_graph_options', ctx=ast.Load())
        node.args = []
        node.keywords = []
        node.keywords.append(ast.keyword(arg='graph_options', value=src))
        util_global.set_value('need_conver', True)
        return node
    if (isinstance(node.func, ast.Name) and node.func.id == 'OptimizerOptions') or \
            (isinstance(node.func, ast.Attribute) and node.func.attr == 'OptimizerOptions'):
        log_success_report(getattr(node, 'lineno', 'None'), 'OptimizerOptions()')
        src = copy.deepcopy(node)
        node.func = ast.Name(id='npu_optimizer_options', ctx=ast.Load())
        node.args = []
        node.keywords = []
        node.keywords.append(ast.keyword(arg='optimizer_options', value=src))
        util_global.set_value('need_conver', True)
        return node
    if _call_name_match(node.func, "Session"):
        return convert_origin_func_to_npu(node, tf_func_map["tf.Session"],
                                          "tf.Session", ["config"])
    if _call_name_match(node.func, "InteractiveSession"):
        return convert_origin_func_to_npu(node, tf_func_map["tf.InteractiveSession"],
                                          "tf.InteractiveSession", ["config"])
    # --- Horovod APIs ------------------------------------------------------
    if isinstance(node.func, ast.Attribute) and node.func.attr == "BroadcastGlobalVariablesHook":
        if isinstance(node.func.value, ast.Name) and node.func.value.id == "hvd":
            if is_not_horovod:
                log_strategy_distributed_mode_error(node)
                return node
            log_msg(
                getattr(node, "lineno", "None"),
                'change hvd.BroadcastGlobalVariablesHook to NPUBroadcastGlobalVariablesHook'
            )
            node = pasta.parse(
                "NPUBroadcastGlobalVariablesHook(0, int(os.getenv('RANK_ID', '0')))"
            )
            util_global.set_value('need_conver', True)
            return node
    if isinstance(node.func, ast.Attribute) and node.func.attr == "BroadcastGlobalVariablesCallback":
        if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == "callbacks":
            if is_not_horovod:
                log_strategy_distributed_mode_error(node)
                return node
            log_msg(
                getattr(node, "lineno", "None"),
                'change hvd.callbacks.BroadcastGlobalVariablesCallback to NPUBroadcastGlobalVariablesCallback'
            )
            node = pasta.parse("NPUBroadcastGlobalVariablesCallback(root_rank=0)")
            util_global.set_value('need_conver', True)
            return node
    if isinstance(node.func, ast.Attribute) and node.func.attr == "DistributedOptimizer":
        if isinstance(node.func.value, ast.Name) and node.func.value.id == "hvd":
            if is_not_horovod:
                log_strategy_distributed_mode_error(node)
                return node
            return convert_hvd_distributed_api(node)
    # --- dataset / op rewrites ---------------------------------------------
    if isinstance(node.func, ast.Attribute) and node.func.attr == 'shard':
        log_success_report(getattr(node, "lineno", "None"), 'shard')
        node.args = [
            pasta.parse("int(os.getenv('RANK_SIZE', '1'))"),
            pasta.parse("int(os.getenv('RANK_ID', '0'))")
        ]
        node.keywords.clear()
        util_global.set_value('need_conver', True)
        return node
    if isinstance(node.func, ast.Attribute) and node.func.attr == 'dropout':
        if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == 'nn':
            # A third positional arg or a noise_shape keyword is not
            # supported by npu_ops.dropout; leave the call untouched.
            for index, _ in enumerate(node.args):
                if index == 2:
                    return node
            for keyword in node.keywords:
                if keyword.arg == "noise_shape":
                    return node
            log_success_report(getattr(node, "lineno", "None"), 'dropout')
            node.func = ast.Attribute(value=ast.Name(id='npu_ops', ctx=ast.Load()),
                                      attr='dropout', ctx=ast.Load())
            # tf uses rate (drop prob.); npu_ops uses keep_prob = 1 - rate.
            keywords_new = []
            for keyword in node.keywords:
                if keyword.arg != 'rate':
                    keywords_new.append(keyword)
                else:
                    keywords_new.append(
                        ast.keyword(arg='keep_prob',
                                    value=ast.BinOp(left=ast.Num(n=1), op=ast.Sub(),
                                                    right=keyword.value)))
            node.keywords = keywords_new
            util_global.set_value('need_conver', True)
            return node
    if isinstance(node.func, ast.Attribute) and \
            ((node.func.attr == 'map_and_batch') or
             (node.func.attr == 'batch' and (not isinstance(node.func.value, ast.Attribute) or (
                 isinstance(node.func.value, ast.Attribute) and node.func.value.attr != 'train')))):
        # NPU requires static batch shapes: force drop_remainder=True.
        exist = False
        for keyword in node.keywords:
            if keyword.arg == 'drop_remainder':
                exist = True
                if ((isinstance(keyword.value, ast.NameConstant) and not keyword.value.value) or
                        (not isinstance(keyword.value, ast.NameConstant))):
                    log_success_report(getattr(node, "lineno", "None"), node.func.attr)
                    keyword.value = pasta.parse('True')
                    util_global.set_value('need_conver', True)
        if not exist:
            log_success_report(getattr(node, "lineno", "None"), node.func.attr)
            keyword = ast.keyword(arg='drop_remainder', value=pasta.parse('True'))
            node.keywords.insert(0, keyword)
            util_global.set_value('need_conver', True)
        return node
    if (isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and
            node.func.value.id == 'tf' and node.func.attr == 'device'):
        log_success_report(getattr(node, "lineno", "None"), node.func.attr)
        node.args = [ast.Str(s='/cpu:0')]
        util_global.set_value('need_conver', True)
        return node
    # --- distribution strategies -------------------------------------------
    if isinstance(node.func, ast.Attribute) and \
            (node.func.attr == "get_distribution_strategy" or
             node.func.attr == "MirroredStrategy" or
             node.func.attr == "MultiWorkerMirroredStrategy"):
        if is_not_strategy:
            log_hvd_distributed_mode_error(node)
            return node
        log_success_report(getattr(node, "lineno", "None"), node.func.attr)
        new_func = ast.Attribute(value=ast.Name(id="npu_strategy", ctx=ast.Load()),
                                 attr="NPUStrategy", ctx=ast.Load())
        ast.copy_location(new_func, node.func)
        node.func = new_func
        node.keywords = []
        node.args = []
        util_global.set_value('need_conver', True)
        return node
    if (isinstance(node.func, ast.Attribute) and (node.func.attr == 'RunConfig')) and \
            (_call_name_match(node.func.value, 'estimator') or _call_name_match(node.func.value, 'tpu')):
        # BUG FIX: node.keywords holds ast.keyword nodes, so the old
        # list.count("train_distribute") was always 0 and this check never
        # fired; compare each keyword's .arg instead.
        if any(kw.arg in ("train_distribute", "eval_distribute")
               for kw in node.keywords):
            if is_not_strategy:
                log_hvd_distributed_mode_error(node)
        save_summary_steps = None
        for keyword in node.keywords:
            if keyword.arg == 'save_summary_steps':
                save_summary_steps = keyword
                break
        # Only the first three positional args precede save_summary_steps.
        if len(node.args) < 3 and not save_summary_steps:
            log_msg(getattr(node, 'lineno', 'None'),
                    'RunConfig() add save_summary_steps=0')
            util_global.set_value('need_conver', True)
            node.keywords.append(
                ast.keyword(arg='save_summary_steps', value=pasta.parse('0')))
        return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'TPUEstimator') and \
            ((isinstance(node.func.value, ast.Attribute) and (node.func.value.attr == 'tpu')) or
             (isinstance(node.func.value, ast.Name) and (node.func.value.id == 'tpu'))):
        # Force eval_on_tpu/use_tpu/export_to_tpu to False, adding any of
        # the three keywords that are missing.
        add_eval_on_tpu = True
        add_use_tpu = True
        add_export_to_tpu = True
        for keyword in node.keywords:
            if (keyword.arg == 'eval_on_tpu') or (keyword.arg == 'use_tpu') or \
                    (keyword.arg == 'export_to_tpu'):
                if (not isinstance(keyword.value, ast.NameConstant)) or \
                        (isinstance(keyword.value, ast.NameConstant) and (keyword.value.value)):
                    log_success_report(getattr(node, 'lineno', 'None'),
                                       'TPUEstimator(' + keyword.arg + '=*)')
                    keyword.value = pasta.parse('False')
                    util_global.set_value('need_conver', True)
                if add_eval_on_tpu and (keyword.arg == 'eval_on_tpu'):
                    add_eval_on_tpu = False
                if add_use_tpu and (keyword.arg == 'use_tpu'):
                    add_use_tpu = False
                if add_export_to_tpu and (keyword.arg == 'export_to_tpu'):
                    add_export_to_tpu = False
        if add_eval_on_tpu:
            log_success_report(getattr(node, 'lineno', 'None'),
                               'TPUEstimator(eval_on_tpu=*)')
            node.keywords.append(
                ast.keyword(arg='eval_on_tpu', value=pasta.parse('False')))
            util_global.set_value('need_conver', True)
        if add_use_tpu:
            log_success_report(getattr(node, 'lineno', 'None'),
                               'TPUEstimator(use_tpu=*)')
            node.keywords.append(
                ast.keyword(arg='use_tpu', value=pasta.parse('False')))
            util_global.set_value('need_conver', True)
        if add_export_to_tpu:
            log_success_report(getattr(node, 'lineno', 'None'),
                               'TPUEstimator(export_to_tpu=*)')
            node.keywords.append(
                ast.keyword(arg='export_to_tpu', value=pasta.parse('False')))
            util_global.set_value('need_conver', True)
        # NOTE: deliberately falls through (no return) so a TPUEstimator in
        # the Estimators list below can also receive a config argument.
    # --- device configuration ----------------------------------------------
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'VirtualDeviceConfiguration'):
        log_success_report(getattr(node, 'lineno', 'None'), 'VirtualDeviceConfiguration')
        util_global.set_value('need_conver', True)
        memory_limit = None
        for keyword in node.keywords:
            if keyword.arg == 'memory_limit':
                memory_limit = keyword
                break
        if memory_limit:
            memory_limit.value = ast.NameConstant(value=None)
        else:
            node.keywords.append(
                ast.keyword(arg='memory_limit', value=ast.NameConstant(value=None)))
        return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'set_soft_device_placement'):
        log_success_report(getattr(node, 'lineno', 'None'), 'set_soft_device_placement')
        util_global.set_value('need_conver', True)
        node.args = []
        node.keywords = [
            ast.keyword(arg='enabled', value=ast.NameConstant(value=True))
        ]
        return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'set_memory_growth'):
        log_success_report(getattr(node, 'lineno', 'None'), 'set_memory_growth')
        util_global.set_value('need_conver', True)
        node = ast.NameConstant(value=None)
        return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'set_virtual_device_configuration'):
        log_success_report(getattr(node, 'lineno', 'None'), 'set_virtual_device_configuration')
        util_global.set_value('need_conver', True)
        node = ast.NameConstant(value=None)
        return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'jit_scope'):
        if isinstance(node.func.value, ast.Attribute) and (node.func.value.attr == 'experimental'):
            if isinstance(node.func.value.value, ast.Attribute) and (node.func.value.value.attr == 'xla'):
                log_success_report(getattr(node, 'lineno', 'None'), '*.xla.experimental.jit_scope')
                util_global.set_value('need_conver', True)
                compile_ops = None
                for keyword in node.keywords:
                    if keyword.arg == 'compile_ops':
                        compile_ops = keyword
                        break
                if compile_ops:
                    compile_ops.value = pasta.parse('False')
                else:
                    node.keywords.append(
                        ast.keyword(arg='compile_ops', value=pasta.parse('False')))
                return node
    # --- estimators: inject config=npu_run_config_init() --------------------
    for estimator in util_global.get_value('Estimators', []):
        if (isinstance(node.func, ast.Attribute) and (node.func.attr == estimator)) \
                or (isinstance(node.func, ast.Name) and (node.func.id == estimator)):
            log_msg(
                getattr(node, 'lineno', 'None'),
                "".join([estimator, '() add config=npu_run_config_init()']))
            config = None
            for keyword in node.keywords:
                if keyword.arg == 'config':
                    config = keyword
                    break
            if config:
                # Wrap an existing config as run_config of the NPU helper.
                new_value = ast.Call(func=ast.Name(id='npu_run_config_init', ctx=ast.Load()),
                                     args=[],
                                     keywords=[
                                         ast.keyword(arg='run_config', value=config.value)
                                     ])
                ast.copy_location(new_value, config.value)
                config.value = new_value
            else:
                node.keywords.append(
                    ast.keyword(arg='config', value=pasta.parse('npu_run_config_init()')))
            util_global.set_value('need_conver', True)
            return node
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'clear_session'):
        log_msg(getattr(node, 'lineno', 'None'),
                "change keras.clear_session() to npu_clear_session()")
        node = ast.Call(func=ast.Name(id='npu_clear_session', ctx=ast.Load()),
                        args=[], keywords=[])
        util_global.set_value('need_conver', True)
    if _call_name_match(node.func, "MonitoredTrainingSession"):
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.train.MonitoredTrainingSession"],
            "MonitoredTrainingSession", ["config", "hooks"])
    if isinstance(node.func, ast.Attribute) and node.func.attr == "managed_session":
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.train.Supervisor.managed_session"],
            "managed_session", ["config"], True)
    if distributed_mode == "tf_strategy":
        # this cond should be placed at the end of the Call function.
        return convert_distributed_strategy_apis(node)
    return node
def visit_Assign(self, node):
    """Taichi AST pass for assignments.

    Rewrites ``x = rhs`` into either a variable creation
    (``x = ti.expr_init(rhs)``) when the target is new, or an in-place
    assignment (``x.assign(rhs)``) when it already exists.  Tuple targets
    are unpacked through a temporary ``__tmp_tuple`` so each element gets
    the same treatment.  ``x = ti.static(...)`` is passed through verbatim.
    """
    assert (len(node.targets) == 1)
    self.generic_visit(node)
    # A "decorated" value is a call of the form ti.<something>(...).
    decorated = isinstance(node.value, ast.Call) and isinstance(
        node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name) \
        and node.value.func.value.id == 'ti'
    is_static_assign = False
    if decorated:
        attr = node.value.func
        if attr.attr == 'static':
            is_static_assign = True
        else:
            # eg. x = ti.cast(xx) will reach here, but they're not
            # decorators, so no raising errors here
            pass
    if is_static_assign:
        return node
    if isinstance(node.targets[0], ast.Tuple):
        targets = node.targets[0].elts
        # Create: evaluate the RHS once into __tmp_tuple, then index it per
        # element.
        stmts = []
        holder = self.parse_stmt('__tmp_tuple = ti.expr_init_list(0, '
                                 f'{len(targets)})')
        holder.value.args[0] = node.value
        stmts.append(holder)

        def tuple_indexed(i):
            # Build the expression __tmp_tuple[i].
            indexing = self.parse_stmt('__tmp_tuple[0]')
            indexing.value.slice.value = self.parse_expr("{}".format(i))
            return indexing.value

        for i, target in enumerate(targets):
            is_local = isinstance(target, ast.Name)
            if is_local and self.is_creation(target.id):
                var_name = target.id
                target.ctx = ast.Store()
                # Create: <target> = ti.expr_init(__tmp_tuple[i])
                init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                     attr='expr_init', ctx=ast.Load())
                rhs = ast.Call(
                    func=init,
                    args=[tuple_indexed(i)],
                    keywords=[],
                )
                self.create_variable(var_name)
                stmts.append(ast.Assign(targets=[target], value=rhs))
            else:
                # Assign: <target>.assign(__tmp_tuple[i])
                target.ctx = ast.Load()
                func = ast.Attribute(value=target, attr='assign', ctx=ast.Load())
                call = ast.Call(func=func, args=[tuple_indexed(i)], keywords=[])
                stmts.append(ast.Expr(value=call))
        for stmt in stmts:
            ast.copy_location(stmt, node)
        stmts.append(self.parse_stmt('del __tmp_tuple'))
        return self.make_single_statement(stmts)
    else:
        is_local = isinstance(node.targets[0], ast.Name)
        if is_local and self.is_creation(node.targets[0].id):
            var_name = node.targets[0].id
            # Create: <target> = ti.expr_init(<value>)
            init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                 attr='expr_init', ctx=ast.Load())
            rhs = ast.Call(
                func=init,
                args=[node.value],
                keywords=[],
            )
            self.create_variable(var_name)
            return ast.copy_location(
                ast.Assign(targets=node.targets, value=rhs), node)
        else:
            # Assign: <target>.assign(<value>)
            node.targets[0].ctx = ast.Load()
            func = ast.Attribute(value=node.targets[0], attr='assign', ctx=ast.Load())
            call = ast.Call(func=func, args=[node.value], keywords=[])
            return ast.copy_location(ast.Expr(value=call), node)
def visit_For(self, node):
    """Rewrite a ``for`` loop into Taichi frontend loop constructs.

    Handles four forms: ``ti.static`` (left untouched), ``ti.ndrange``
    (lowered to a flat range-for plus index arithmetic, then re-visited),
    plain ``range(...)`` (lowered to a frontend range-for), and struct-for
    over a field (plain or ``ti.grouped``).
    """
    if node.orelse:
        raise TaichiSyntaxError(
            "'else' clause for 'for' not supported in Taichi kernels")
    # Detect a ``ti.<attr>(...)`` iterator.
    decorated = isinstance(node.iter, ast.Call) and isinstance(
        node.iter.func, ast.Attribute) and isinstance(node.iter.func.value, ast.Name) \
        and node.iter.func.value.id == 'ti'
    is_ndrange_for = False
    is_static_for = False
    is_grouped = False
    if decorated:
        attr = node.iter.func
        if attr.attr == 'static':
            is_static_for = True
        elif attr.attr == 'grouped':
            is_grouped = True
        elif attr.attr == 'ndrange':
            is_ndrange_for = True
        else:
            raise Exception('Not supported')
    is_range_for = isinstance(node.iter, ast.Call) and isinstance(
        node.iter.func, ast.Name) and node.iter.func.id == 'range'
    ast.fix_missing_locations(node)
    if not is_ndrange_for:
        # ndrange bodies are visited after lowering (the loop is re-visited
        # as a range-for below); everything else is visited in place.
        self.generic_visit(node, ['body'])
    if is_ndrange_for:
        template = '''
if ti.static(1):
    __ndrange = 0
    for __ndrange_I in range(0):
        __I = __ndrange_I
'''
        t = ast.parse(template).body[0]
        t.body[0].value = node.iter
        t.body[1].iter.args[0] = self.parse_expr(
            '__ndrange.acc_dimensions[0]')
        targets = node.target
        if isinstance(targets, ast.Tuple):
            targets = [name.id for name in targets.elts]
        else:
            targets = [targets.id]
        loop_body = t.body[1].body
        # Unflatten the linear index __I into one coordinate per target,
        # offsetting each by the ndrange's lower bound.
        for i in range(len(targets)):
            if i + 1 < len(targets):
                stmt = '__{} = __I // __ndrange.acc_dimensions[{}]'.format(
                    targets[i], i + 1)
            else:
                stmt = '__{} = __I'.format(targets[i])
            loop_body.append(self.parse_stmt(stmt))
            stmt = '{} = __{} + __ndrange.bounds[{}][0]'.format(
                targets[i], targets[i], i)
            loop_body.append(self.parse_stmt(stmt))
            if i + 1 < len(targets):
                stmt = '__I = __I - __{} * __ndrange.acc_dimensions[{}]'.format(
                    targets[i], i + 1)
                loop_body.append(self.parse_stmt(stmt))
        loop_body += node.body
        node = ast.copy_location(t, node)
        return self.visit(node)  # further translate as a range for
    elif is_static_for:
        return node
    elif is_range_for:
        loop_var = node.target.id
        self.check_loop_var(loop_var)
        template = '''
if 1:
    {} = ti.Expr(ti.core.make_id_expr(''))
    ___begin = ti.Expr(0)
    ___end = ti.Expr(0)
    ___begin = ti.cast(___begin, ti.i32)
    ___end = ti.cast(___end, ti.i32)
    ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr)
    ti.core.end_frontend_range_for()
'''.format(loop_var, loop_var)
        t = ast.parse(template).body[0]
        assert len(node.iter.args) in [1, 2]
        if len(node.iter.args) == 2:
            bgn = node.iter.args[0]
            end = node.iter.args[1]
        else:
            bgn = self.make_constant(value=0)
            end = node.iter.args[0]
        t.body[1].value.args[0] = bgn
        t.body[2].value.args[0] = end
        # Splice the user body between begin_frontend_range_for and
        # end_frontend_range_for.
        t.body = t.body[:6] + node.body + t.body[6:]
        t.body.append(self.parse_stmt('del {}'.format(loop_var)))
        return ast.copy_location(t, node)
    else:
        # Struct for
        if isinstance(node.target, ast.Name):
            elts = [node.target]
        else:
            elts = node.target.elts
        for loop_var in elts:
            self.check_loop_var(loop_var.id)
        var_decl = ''.join(
            '    {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(ind.id)
            for ind in elts)
        vars = ', '.join(ind.id for ind in elts)
        if is_grouped:
            template = '''
if 1:
    ___loop_var = 0
    {} = ti.make_var_vector(size=___loop_var.loop_range().dim())
    ___expr_group = ti.make_expr_group({})
    ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
    ti.core.end_frontend_range_for()
'''.format(vars, vars)
            t = ast.parse(template).body[0]
            cut = 4
            t.body[0].value = node.iter
            t.body = t.body[:cut] + node.body + t.body[cut:]
        else:
            template = '''
if 1:
{}
    ___loop_var = 0
    ___expr_group = ti.make_expr_group({})
    ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
    ti.core.end_frontend_range_for()
'''.format(var_decl, vars)
            t = ast.parse(template).body[0]
            cut = len(elts) + 3
            t.body[cut - 3].value = node.iter
            t.body = t.body[:cut] + node.body + t.body[cut:]
        for loop_var in reversed(elts):
            t.body.append(self.parse_stmt('del {}'.format(loop_var.id)))
        return ast.copy_location(t, node)
def visit(self, node):
    """Dispatch to the type-specific ``visit_*`` handler (falling back to
    ``generic_visit``) and stamp the result with the benchmark node's
    source location."""
    handler_name = 'visit_' + type(node).__name__
    handler = getattr(self, handler_name, self.generic_visit)
    result = handler(node)
    return ast.copy_location(result, self.__benchmark)
def _unellipsify(self, node, slices, subscript_node):
    """
    Given an array node `node`, process all AST slices and create the
    final type:

        - process newaxes (None or numpy.newaxis)
        - replace Ellipsis with a bunch of ast.Slice objects
        - process integer indices
        - append any missing slices in trailing dimensions

    Returns ``(result_type, node)`` and mutates ``subscript_node.slice``
    into an ``ast.ExtSlice`` of the normalized slice list.
    """
    type = node.variable.type
    if not type.is_array:
        # Non-array subscripting is handled as a generic object operation.
        assert type.is_object
        return minitypes.object_, node
    if (len(slices) == 1 and self._is_constant_index(slices[0]) and
            slices[0].value.pyval is Ellipsis):
        # A[...]
        return type, node
    result = []
    seen_ellipsis = False
    # Filter out newaxes
    newaxes = [newaxis for newaxis in slices if self._is_newaxis(newaxis)]
    n_indices = len(slices) - len(newaxes)
    full_slice = ast.Slice(lower=None, upper=None, step=None)
    full_slice.variable = Variable(numba_types.SliceType())
    ast.copy_location(full_slice, slices[0])
    # process ellipses and count integer indices
    # NOTE(review): iteration is right-to-left; only the FIRST ellipsis
    # encountered (i.e. the last in source order) expands to multiple full
    # slices — confirm this matches the intended indexing semantics.
    indices_seen = 0
    for slice_node in slices[::-1]:
        slice_type = slice_node.variable.type
        if slice_type.is_ellipsis:
            if seen_ellipsis:
                result.append(full_slice)
            else:
                nslices = type.ndim - n_indices + 1
                result.extend([full_slice] * nslices)
            seen_ellipsis = True
        elif (slice_type.is_slice or slice_type.is_int or
              self._is_newaxis(slice_node)):
            # Integer indices consume a dimension; count them for the
            # result rank computation below.
            indices_seen += slice_type.is_int
            result.append(slice_node)
        else:
            # TODO: Coerce all object operands to integer indices?
            # TODO: (This will break indexing with the Ellipsis object or
            # TODO: with slice objects that we couldn't infer)
            return minitypes.object_, nodes.CoercionNode(
                node, minitypes.object_)
    # append any missing slices (e.g. a2d[:]
    result_length = len(result) - len(newaxes)
    if result_length < type.ndim:
        nslices = type.ndim - result_length
        result.extend([full_slice] * nslices)
    # Undo the right-to-left traversal order.
    result.reverse()
    subscript_node.slice = ast.ExtSlice(result)
    ast.copy_location(subscript_node.slice, slices[0])
    # create the final array type and set it in value.variable
    result_dtype = node.variable.type.dtype
    result_ndim = node.variable.type.ndim + len(newaxes) - indices_seen
    if result_ndim > 0:
        result_type = result_dtype[(slice(None), ) * result_ndim]
    elif result_ndim == 0:
        # Fully indexed: scalar result.
        result_type = result_dtype
    else:
        result_type = minitypes.object_
    return result_type, node
def visit_FunctionDef(self, orig_ast: ast.FunctionDef) -> ast.AST:
    """Rewrite the function body into a prediction-cascade pipeline.

    Builds a new body that (1) computes the cheap "selected" features,
    (2) predicts with an approximate model, (3) recomputes the remaining
    features only for rows whose approximate confidence is below the
    cascade threshold, (4) predicts those rows with the full model, and
    (5) combines both prediction sets into the final return value.

    NOTE(review): ``selected_feature_nodes``, ``remaining_feature_nodes``,
    ``selected_feature_names``, ``model_node``, ``predict_function``,
    ``predict_proba_function`` and the ``create_*_ast`` helpers are free
    names from an enclosing scope — confirm where they are bound.
    """
    cascades_body = []
    # Compute selected features.
    for node in selected_feature_nodes:
        cascades_body += node.get_ast()
    # Predict with approximate model.
    approximate_model_ast = create_model_ast(
        function_name=predict_proba_function.__name__,
        output_name="__willump_approximate_preds",
        input_names=selected_feature_names,
        model_param="__willump_approximate_model")
    cascades_body += approximate_model_ast
    # Get indices that can't be approximated.
    unapproximated_indices_ast = create_function_ast(
        function_name="__willump_get_unapproximated_indices",
        output_name="__willump_unapproximated_indices",
        input_names=[
            "__willump_approximate_preds", "__willump_cascade_threshold"
        ])
    cascades_body += unapproximated_indices_ast
    # Only compute remaining features for unapproximated indices.
    shortened_inputs = set()
    for node in remaining_feature_nodes:
        for input_name in node.input_names:
            # Shrink each input at most once, even if shared across nodes.
            if input_name not in shortened_inputs:
                shorten_ast = create_function_ast(
                    function_name=
                    "__willump_select_unapproximated_rows",
                    output_name=input_name,
                    input_names=[
                        input_name, "__willump_unapproximated_indices"
                    ])
                cascades_body += shorten_ast
                shortened_inputs.add(input_name)
        cascades_body += node.get_ast()
    # Shorten the selected features.
    for name in selected_feature_names:
        shorten_ast = create_function_ast(
            function_name="__willump_select_unapproximated_rows",
            output_name=name,
            input_names=[name, "__willump_unapproximated_indices"])
        cascades_body += shorten_ast
    # Predict with full model.
    full_model_ast = create_model_ast(
        function_name=predict_function.__name__,
        output_name="__willump_full_preds",
        input_names=model_node.input_names,
        model_param="__willump_full_model")
    cascades_body += full_model_ast
    # Return combined predictions.
    combine_predictions_ast = create_function_ast(
        function_name="__willump_combine_predictions",
        output_name="__willump_final_predictions",
        input_names=[
            "__willump_approximate_preds", "__willump_full_preds",
            "__willump_cascade_threshold"
        ])
    cascades_body += combine_predictions_ast
    # ast.parse happily parses a bare ``return`` at module level; the
    # statements are spliced into a function body below.
    return_ast = ast.parse("return __willump_final_predictions",
                           "exec").body
    cascades_body += return_ast
    # Finalize AST.
    new_ast = copy.deepcopy(orig_ast)
    new_ast.body = cascades_body
    # No recursion allowed!
    new_ast.decorator_list = []
    return ast.copy_location(new_ast, orig_ast)
def visit_Return(self, node, visit_count):
    """Visit a ``Return`` node.

    Replace the return with the injected statement when no ordinal filter
    is configured, or when this visit's ordinal is selected; otherwise
    recurse into the node unchanged.
    """
    selected = (not self.ordinal) or (visit_count in self.ordinal)
    if not selected:
        self.generic_visit(node)
        return node
    replacement = self.inject(node)
    return ast.copy_location(replacement, node)
def visit_With(self, node):
    """Rebuild a ``With`` node from its first context-manager item.

    NOTE(review): only ``items[0]`` is handled — multiple context managers
    in one ``with`` statement are dropped; confirm callers never produce
    them.  The positional construction assumes a ``With`` signature of
    (context_expr, optional_vars, body); the stdlib py3 ``ast.With`` maps
    these positions to (items, body, type_comment), so this presumably
    targets a py2-style/custom ast module — verify before reuse.
    """
    new_node = ast.With(self._visit(node.items[0].context_expr),
                        self._visit(node.items[0].optional_vars),
                        self._visit(node.body))
    ast.copy_location(new_node, node)
    return new_node
def _copy_location(newnode, node): return ast.fix_missing_locations(ast.copy_location(newnode, node))
def visitName(self, node: ast.Name):
    """Fold references to ``__debug__`` into a constant that reflects the
    transformer's optimize flag; leave every other name untouched."""
    if node.id != "__debug__":
        return self.generic_visit(node)
    folded = Constant(not self.optimize)
    return copy_location(folded, node)
def passer(_: ast3.NodeTransformer, node: cls) -> ast.AST:
    """Replace *node* with an ``ast.Pass`` carrying *node*'s location.

    The transformer argument is unused (hence ``_``).  The original
    signature annotated the return as ``None`` even though the relocated
    ``Pass`` node is returned; the annotation now matches the behavior.
    """
    return ast.copy_location(ast.Pass(), node)
def visit_Assign(self, node):
    """Rewrite a Python assignment into Taichi frontend statements.

    Creation assignments (first use of a local name) become ``ti.expr_init``
    calls; re-assignments become ``target.assign(value)`` calls.  Tuple
    targets are unpacked through a ``__tmp_tuple`` temporary.
    """
    # Chained assignment (a = b = c) is not supported by this transformer.
    assert (len(node.targets) == 1)
    self.generic_visit(node)
    if isinstance(node.targets[0], ast.Tuple):
        targets = node.targets[0].elts
        # Create
        stmts = []
        # Holder statement binds the RHS tuple once; its placeholder value
        # is replaced with the original expression.
        holder = self.parse_stmt('__tmp_tuple = 0')
        holder.value = node.value
        stmts.append(holder)

        def tuple_indexed(i):
            # Build the expression ``__tmp_tuple[i]``.
            indexing = self.parse_stmt('__tmp_tuple[0]')
            indexing.value.slice.value = self.parse_expr("{}".format(i))
            return indexing.value

        for i, target in enumerate(targets):
            is_local = isinstance(target, ast.Name)
            if is_local and self.is_creation(target.id):
                var_name = target.id
                target.ctx = ast.Store()
                # Create
                init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                     attr='expr_init',
                                     ctx=ast.Load())
                rhs = ast.Call(
                    func=init,
                    args=[tuple_indexed(i)],
                    keywords=[],
                )
                self.create_variable(var_name)
                stmts.append(ast.Assign(targets=[target], value=rhs))
            else:
                # Assign
                # Target becomes the receiver of an ``assign`` call, so it
                # is read (Load), not written.
                target.ctx = ast.Load()
                func = ast.Attribute(value=target, attr='assign', ctx=ast.Load())
                call = ast.Call(func=func,
                                args=[tuple_indexed(i)],
                                keywords=[])
                stmts.append(ast.Expr(value=call))
        for stmt in stmts:
            ast.copy_location(stmt, node)
        stmts.append(self.parse_stmt('del __tmp_tuple'))
        return self.make_single_statement(stmts)
    else:
        is_local = isinstance(node.targets[0], ast.Name)
        if is_local and self.is_creation(node.targets[0].id):
            var_name = node.targets[0].id
            # Create
            init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                 attr='expr_init',
                                 ctx=ast.Load())
            rhs = ast.Call(
                func=init,
                args=[node.value],
                keywords=[],
            )
            self.create_variable(var_name)
            return ast.copy_location(
                ast.Assign(targets=node.targets, value=rhs), node)
        else:
            # Assign
            node.targets[0].ctx = ast.Load()
            func = ast.Attribute(value=node.targets[0],
                                 attr='assign',
                                 ctx=ast.Load())
            call = ast.Call(func=func, args=[node.value], keywords=[])
            return ast.copy_location(ast.Expr(value=call), node)
def fix_location(new, old):
    """Copy *old*'s source position onto *new* and recursively fill in any
    missing positions on *new*'s children; return *new*."""
    located = ast.copy_location(new, old)
    ast.fix_missing_locations(located)
    return located
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
    """Return the AST statements to replace the ast.Assert instance.

    This rewrites the test of an assertion to provide intermediate values
    and replace it with an if statement which raises an assertion error
    with a detailed explanation in case the expression is false.
    """
    if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
        # ``assert (x, "msg")`` is a non-empty tuple and therefore always
        # truthy — warn instead of silently rewriting a no-op assertion.
        from _pytest.warning_types import PytestAssertRewriteWarning
        import warnings

        # TODO: This assert should not be needed.
        assert self.module_path is not None
        warnings.warn_explicit(
            PytestAssertRewriteWarning(
                "assertion is always true, perhaps remove parentheses?"
            ),
            category=None,
            filename=self.module_path,
            lineno=assert_.lineno,
        )

    # Per-assert scratch state consumed by the expression visitors.
    self.statements: List[ast.stmt] = []
    self.variables: List[str] = []
    self.variable_counter = itertools.count()

    if self.enable_assertion_pass_hook:
        self.format_variables: List[str] = []

    self.stack: List[Dict[str, ast.expr]] = []
    self.expl_stmts: List[ast.stmt] = []
    self.push_format_context()

    # Rewrite assert into a bunch of statements.
    top_condition, explanation = self.visit(assert_.test)

    negation = ast.UnaryOp(ast.Not(), top_condition)

    if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
        msg = self.pop_format_context(ast.Str(explanation))

        # Failed
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            gluestr = "\n>assert "
        else:
            assertmsg = ast.Str("")
            gluestr = "assert "
        err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
        err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
        err_name = ast.Name("AssertionError", ast.Load())
        fmt = self.helper("_format_explanation", err_msg)
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)
        statements_fail = []
        statements_fail.extend(self.expl_stmts)
        statements_fail.append(raise_)

        # Passed
        fmt_pass = self.helper("_format_explanation", msg)
        orig = _get_assertion_exprs(self.source)[assert_.lineno]
        hook_call_pass = ast.Expr(
            self.helper(
                "_call_assertion_pass",
                ast.Num(assert_.lineno),
                ast.Str(orig),
                fmt_pass,
            )
        )
        # If any hooks implement assert_pass hook
        hook_impl_test = ast.If(
            self.helper("_check_if_assertion_pass_impl"),
            self.expl_stmts + [hook_call_pass],
            [],
        )
        statements_pass = [hook_impl_test]

        # Test for assertion condition
        main_test = ast.If(negation, statements_fail, statements_pass)
        self.statements.append(main_test)
        if self.format_variables:
            variables = [
                ast.Name(name, ast.Store()) for name in self.format_variables
            ]
            clear_format = ast.Assign(variables, ast.NameConstant(None))
            self.statements.append(clear_format)

    else:  # Original assertion rewriting
        # Create failure message.
        body = self.expl_stmts
        self.statements.append(ast.If(negation, body, []))
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            explanation = "\n>assert " + explanation
        else:
            assertmsg = ast.Str("")
            explanation = "assert " + explanation
        template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
        msg = self.pop_format_context(template)
        fmt = self.helper("_format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)
        body.append(raise_)

    # Clear temporary variables by setting them to None.
    if self.variables:
        variables = [ast.Name(name, ast.Store()) for name in self.variables]
        clear = ast.Assign(variables, ast.NameConstant(None))
        self.statements.append(clear)
    # Fix locations (line numbers/column offsets).
    for stmt in self.statements:
        for node in traverse_node(stmt):
            ast.copy_location(node, assert_)
    return self.statements
def dyn(n):
    """Return a fresh ``Name`` node referring to ``Any``, positioned at
    *n*'s source location."""
    any_ref = ast.Name(id='Any', ctx=ast.Load())
    return ast.copy_location(any_ref, n)
def convert_distributed_strategy_apis(node):
    """Convert distributed strategy API

    Inspects a ``Call`` node and, when it targets a TensorFlow training
    API, rewrites it into the NPU-distributed equivalent; otherwise the
    node is returned unchanged.
    """
    # Case 1: attribute call like ``tf.train.XxxOptimizer(...)`` — wrap the
    # optimizer construction, excluding interfaces that must not be wrapped.
    if isinstance(node.func, ast.Attribute) and isinstance(
            node.func.value, ast.Attribute):
        if ("Optimizer" in node.func.attr
                and node.func.attr != "ScipyOptimizerInterface"
                and node.func.attr != "MixedPrecisionLossScaleOptimizer"):
            log_msg(getattr(node, "lineno", "None"),
                    "add npu distribute optimizer to tensorflow optimizer")
            new_node = ast.Call(func=ast.Name(
                id="npu_distributed_optimizer_wrapper", ctx=ast.Load()),
                                args=[node],
                                keywords=[])
            ast.copy_location(new_node, node)
            util_global.set_value('need_conver', True)
            return new_node
    # Case 2: bare-name optimizer constructor (already wrapped NPU loss
    # scale optimizer is skipped).
    if isinstance(
            node.func, ast.Name
    ) and "Optimizer" in node.func.id and node.func.id != "NPULossScaleOptimizer":
        log_msg(getattr(node, "lineno", "None"),
                "add npu distribute optimizer to tensorflow optimizer")
        new_node = ast.Call(func=ast.Name(
            id="npu_distributed_optimizer_wrapper", ctx=ast.Load()),
                            args=[node],
                            keywords=[])
        ast.copy_location(new_node, node)
        util_global.set_value('need_conver', True)
        return new_node
    # Estimator spec constructors: inject NPU hooks.
    if _call_name_match(node.func, "TrainSpec"):
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.estimator.TrainSpec"], "TrainSpec",
            ["hooks"])
    if _call_name_match(node.func, "EvalSpec"):
        return convert_origin_func_to_npu(node,
                                          tf_func_map["tf.estimator.EvalSpec"],
                                          "EvalSpec", ["hooks"])
    # ``estimator.train(...)`` — but not ``*.learning.train(...)``.
    if isinstance(node.func, ast.Attribute) and node.func.attr == "train":
        if isinstance(node.func.value,
                      ast.Attribute) and node.func.value.attr == "learning":
            return node
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.estimator.Estimator.train"],
            "Estimator.train", ["hooks"], True)
    # ``model.compile(...)`` — but not ``re.compile(...)``.
    if isinstance(node.func, ast.Attribute) and (node.func.attr == 'compile'):
        if isinstance(node.func.value,
                      ast.Name) and node.func.value.id == "re":
            return node
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.keras.Model.compile"], "Model.compile",
            ["optimizer"], True)
    # Keras training entry points: inject NPU callbacks.
    if isinstance(node.func, ast.Attribute) and node.func.attr == "fit":
        return convert_origin_func_to_npu(node,
                                          tf_func_map["tf.keras.Model.fit"],
                                          "Model.fit", ["callbacks"], True)
    if isinstance(node.func,
                  ast.Attribute) and node.func.attr == "fit_generator":
        return convert_origin_func_to_npu(
            node, tf_func_map["tf.keras.Model.fit_generator"],
            "Model.fit_generator", ["callbacks"], True)
    # ``tf.gradients(...)`` gets the distributed-gradient rewrite.
    if isinstance(node.func, ast.Attribute) and node.func.attr == "gradients" and \
            isinstance(node.func.value, ast.Name) and node.func.value.id == "tf":
        return convert_tf_gradient_distributed(node)
    return node
def visit_Subscript(self, node: ast.Subscript):
    """For subscripts whose resolved name is one of the tracked keywords,
    drop the subscript and keep only the base value (relocated); recurse
    into everything else."""
    name = rname(node)
    if name not in self.keywords:
        return self.generic_visit(node)
    return ast.copy_location(node.value, node)
def visit_BinOp(self, node):
    """Mutation-testing visitor: when the running BinOp counter hits the
    configured target, return a copy of the node with its operator swapped
    for a randomly chosen alternative; otherwise return the node as-is.
    """
    self.generic_visit(node)
    self.binop_count += 1
    # Check if this is the node we want to alter. We can accomplish this by
    # keeping track of a counter, which we increment every time encounter
    # a BinOp. Since the traversal through the AST is deterministic using the visitor
    # pattern (IT IS NOT DETERMINISTIC IF YOU USE ast.walk), we can identify AST nodes
    # uniquely by the value of the counter
    if (self.binop_count == self.count_of_node_to_mutate):
        # We make sure to use deepcopy so that we preserve all extra
        # information we don't explicitly modify
        new_node = copy.deepcopy(node)
        ast.copy_location(new_node, node)
        # figure out a way to randomize what operator it transforms to based on the current operator
        if isinstance(node.op, ast.Add):
            # randomly generate a number which will associate to a certain type of transformation
            num = random.randint(0, 2)
            print('random number', num)
            if num == 0:
                new_node.op = ast.Mult()
            if num == 1:
                new_node.op = ast.Sub()
            if num == 2:
                new_node.op = ast.Div()
        if isinstance(node.op, ast.Mult):
            num = random.randint(0, 3)
            if num == 0:
                new_node.op = ast.Div()
            if num == 1:
                new_node.op = ast.Add()
            if num == 2:
                new_node.op = ast.FloorDiv()
            if num == 3:
                new_node.op = ast.Sub()
        if isinstance(node.op, ast.Div):
            num = random.randint(0, 2)
            if num == 0:
                new_node.op = ast.FloorDiv()
            if num == 1:
                new_node.op = ast.Mult()
            if num == 2:
                new_node.op = ast.Add()
        if isinstance(node.op, ast.Sub):
            num = random.randint(0, 2)
            if num == 0:
                new_node.op = ast.Add()
            if num == 1:
                new_node.op = ast.Mult()
            if num == 2:
                new_node.op = ast.Div()
        if isinstance(node.op, ast.FloorDiv):
            # Floor division has a single deterministic replacement.
            new_node.op = ast.Div()
        print('I AM CREATING A NEW NODE HERE', self.binop_count)
        return new_node
    else:
        # If we're not looking at an add node we want to change, don't modify
        # this node whatsoever
        return node
def visit_Str(self, node):
    """Replace the string literal with a lower-cased copy that keeps the
    original node's source location."""
    lowered = ast.Str(s=node.s.lower())
    return ast.copy_location(lowered, node)
def visit_For(self, node):
    """Rewrite a ``for`` loop into Taichi frontend loop constructs.

    Handles ``ti.static`` (left untouched), plain ``range(...)`` (lowered
    to a frontend range-for), and struct-for over a field (plain or
    ``ti.grouped``).
    """
    if node.orelse:
        raise TaichiSyntaxError(
            "'else' clause for 'for' not supported in Taichi kernels")
    self.generic_visit(node, ['body'])
    # Detect a ``ti.<attr>(...)`` iterator wrapper.
    decorated = isinstance(node.iter, ast.Call) and isinstance(
        node.iter.func, ast.Attribute)
    is_static_for = False
    is_grouped = False
    if decorated:
        attr = node.iter.func
        if attr.attr == 'static':
            is_static_for = True
        elif attr.attr == 'grouped':
            is_grouped = True
    is_range_for = isinstance(node.iter, ast.Call) and isinstance(
        node.iter.func, ast.Name) and node.iter.func.id == 'range'
    if is_static_for:
        return node
    elif is_range_for:
        loop_var = node.target.id
        self.check_loop_var(loop_var)
        template = '''
if 1:
    {} = ti.Expr(ti.core.make_id_expr(''))
    ___begin = ti.Expr(0)
    ___end = ti.Expr(0)
    ___begin = ti.cast(___begin, ti.i32)
    ___end = ti.cast(___end, ti.i32)
    ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr)
    ti.core.end_frontend_range_for()
'''.format(loop_var, loop_var)
        t = ast.parse(template).body[0]
        assert len(node.iter.args) in [1, 2]
        if len(node.iter.args) == 2:
            bgn = node.iter.args[0]
            end = node.iter.args[1]
        else:
            # One-argument ``range(n)`` starts at 0.
            bgn = self.make_constant(value=0)
            end = node.iter.args[0]
        t.body[1].value.args[0] = bgn
        t.body[2].value.args[0] = end
        # Splice the user body between begin_frontend_range_for and
        # end_frontend_range_for.
        t.body = t.body[:6] + node.body + t.body[6:]
        t.body.append(self.parse_stmt('del {}'.format(loop_var)))
        return ast.copy_location(t, node)
    else:
        # Struct for
        if isinstance(node.target, ast.Name):
            elts = [node.target]
        else:
            elts = node.target.elts
        for loop_var in elts:
            self.check_loop_var(loop_var.id)
        var_decl = ''.join(
            '    {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(ind.id)
            for ind in elts)
        vars = ', '.join(ind.id for ind in elts)
        if is_grouped:
            template = '''
if 1:
    ___loop_var = 0
    {} = ti.make_var_vector(size=___loop_var.loop_range().dim())
    ___expr_group = ti.make_expr_group({})
    ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
    ti.core.end_frontend_range_for()
'''.format(vars, vars)
            t = ast.parse(template).body[0]
            cut = 4
            t.body[0].value = node.iter
            t.body = t.body[:cut] + node.body + t.body[cut:]
        else:
            template = '''
if 1:
{}
    ___loop_var = 0
    ___expr_group = ti.make_expr_group({})
    ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
    ti.core.end_frontend_range_for()
'''.format(var_decl, vars)
            t = ast.parse(template).body[0]
            cut = len(elts) + 3
            t.body[cut - 3].value = node.iter
            t.body = t.body[:cut] + node.body + t.body[cut:]
        for loop_var in reversed(elts):
            t.body.append(self.parse_stmt('del {}'.format(loop_var.id)))
        return ast.copy_location(t, node)
def visit_Assign(self, node):
    """Lower a tasklet assignment to a memlet-tracked connector into a C++
    statement embedded as an ``ast.Name`` id string.

    Non-memlet targets fall through to ``generic_visit``.  Handles four
    write forms: write-conflict resolution (wcr), stream push, scalar
    store, and array store (with or without an explicit subscript).
    """
    target = rname(node.targets[-1])
    if target not in self.memlets:
        return self.generic_visit(node)
    memlet, nc, wcr, dtype = self.memlets[target]
    value = self.visit(node.value)
    if not isinstance(node.targets[-1], ast.Subscript):
        # Dynamic accesses or streams -> every access counts
        try:
            if memlet and memlet.data and (memlet.dynamic or isinstance(
                    self.sdfg.arrays[memlet.data], data.Stream)):
                if wcr is not None:
                    # Conflict-resolved write: delegate expression
                    # generation to the code generator.
                    newnode = ast.Name(
                        id=self.codegen.write_and_resolve_expr(
                            self.sdfg,
                            memlet,
                            nc,
                            target,
                            cppunparse.cppunparse(value,
                                                  expr_semicolon=False),
                            dtype=dtype))
                    node.value = ast.copy_location(newnode, node.value)
                    return node
                elif isinstance(self.sdfg.arrays[memlet.data], data.Stream):
                    newnode = ast.Name(id="%s.push(%s);" % (
                        memlet.data,
                        cppunparse.cppunparse(value, expr_semicolon=False),
                    ))
                else:
                    var_type, ctypedef = self.codegen._dispatcher.defined_vars.get(
                        memlet.data)
                    if var_type == DefinedType.Scalar:
                        newnode = ast.Name(id="%s = %s;" % (
                            memlet.data,
                            cppunparse.cppunparse(value,
                                                  expr_semicolon=False),
                        ))
                    else:
                        newnode = ast.Name(id="%s = %s;" % (
                            cpp_array_expr(self.sdfg, memlet),
                            cppunparse.cppunparse(value,
                                                  expr_semicolon=False),
                        ))
                return self._replace_assignment(newnode, node)
        except TypeError:  # cannot determine truth value of Relational
            # Symbolic condition could not be evaluated; fall through to
            # the generic path.
            pass
        return self.generic_visit(node)
    # Subscripted target: compute the C++ index expression.
    subscript = self._subscript_expr(node.targets[-1].slice, target)
    if wcr is not None:
        newnode = ast.Name(id=self.codegen.write_and_resolve_expr(
            self.sdfg,
            memlet,
            nc,
            target,
            cppunparse.cppunparse(value, expr_semicolon=False),
            indices=sym2cpp(subscript),
            dtype=dtype) + ';')
    else:
        newnode = ast.Name(
            id="%s[%s] = %s;" %
            (target, sym2cpp(subscript),
             cppunparse.cppunparse(value, expr_semicolon=False)))
    return self._replace_assignment(newnode, node)
def visit_Attribute(self, node):
    """Lower-case the attribute name, visiting the value expression first,
    and return a relocated replacement node."""
    visited_value = self.visit(node.value)
    lowered = ast.Attribute(value=visited_value, attr=node.attr.lower())
    return ast.copy_location(lowered, node)
def apply(self, state: SDFGState, sdfg: SDFG):
    """Convert an augmented-assignment-style tasklet (``out = in OP expr``)
    into a write-conflict-resolution (WCR) edge.

    Optionally fissions the state first (so the read of ``input`` does not
    race the WCR write), rewrites the tasklet code in Python or C++ form,
    sets ``wcr`` on the output edge, removes the now-redundant input edge,
    and propagates the WCR outwards through nested SDFGs when the output
    is non-transient.
    """
    input: nodes.AccessNode = self.input
    tasklet: nodes.Tasklet = self.tasklet
    output: nodes.AccessNode = self.output

    # If state fission is necessary to keep semantics, do it first
    if (self.expr_index == 0 and state.in_degree(input) > 0
            and state.out_degree(output) == 0):
        newstate = sdfg.add_state_after(state)
        newstate.add_node(tasklet)
        new_input, new_output = None, None

        # Keep old edges for after we remove tasklet from the original state
        in_edges = list(state.in_edges(tasklet))
        out_edges = list(state.out_edges(tasklet))

        for e in in_edges:
            r = newstate.add_read(e.src.data)
            newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
            if e.src is input:
                new_input = r
        for e in out_edges:
            w = newstate.add_write(e.dst.data)
            newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
            if e.dst is output:
                new_output = w

        # Remove tasklet and resulting isolated nodes
        state.remove_node(tasklet)
        for e in in_edges:
            if state.degree(e.src) == 0:
                state.remove_node(e.src)
        for e in out_edges:
            if state.degree(e.dst) == 0:
                state.remove_node(e.dst)

        # Reset state and nodes for rest of transformation
        input = new_input
        output = new_output
        state = newstate
    # End of state fission

    if self.expr_index == 0:
        inedges = state.edges_between(input, tasklet)
        outedge = state.edges_between(tasklet, output)[0]
    else:
        # Map-scoped variant: edges go through the map entry/exit nodes.
        me = self.map_entry
        mx = self.map_exit
        inedges = state.edges_between(me, tasklet)
        outedge = state.edges_between(tasklet, mx)[0]

    # Get relevant output connector
    outconn = outedge.src_conn

    ops = '[%s]' % ''.join(
        re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

    # Change tasklet code
    if tasklet.language is dtypes.Language.Python:
        # Match a single assignment with a binary operation as RHS
        ast_node: ast.Assign = tasklet.code.code[0]
        lhs: ast.Name = ast_node.targets[0]
        rhs: ast.BinOp = ast_node.value
        op = AugAssignToWCR._PYOP_MAP[type(rhs.op)]
        inconns = list(edge.dst_conn for edge in inedges)
        # One BinOp operand is the input connector (gives ``inedge``); the
        # other is the expression kept as the new RHS.
        for n in (rhs.left, rhs.right):
            if isinstance(n, ast.Name) and n.id in inconns:
                inedge = inedges[inconns.index(n.id)]
            else:
                new_rhs = n

        new_node = ast.copy_location(
            ast.Assign(targets=[lhs], value=new_rhs), ast_node)
        tasklet.code.code = [new_node]
    elif tasklet.language is dtypes.Language.CPP:
        cstr = tasklet.code.as_string.strip()
        for edge in inedges:
            inconn = edge.dst_conn
            # Match ``out = in OP expr;`` on the raw C++ string.
            match = re.match(
                r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' %
                (re.escape(outconn), re.escape(inconn), ops), cstr)
            if match is None:
                # match = re.match(
                #     r'^\s*%s\s*=\s*(.*)\s*(%s)\s*%s;$' %
                #     (re.escape(outconn), ops, re.escape(inconn)), cstr)
                # if match is None:
                continue
                # op = match.group(2)
                # expr = match.group(1)
            else:
                op = match.group(1)
                expr = match.group(2)

            if edge.data.subset != outedge.data.subset:
                continue

            # Map asymmetric WCRs to symmetric ones if possible
            if op in AugAssignToWCR._EXPR_MAP:
                op, newexpr = AugAssignToWCR._EXPR_MAP[op]
                expr = newexpr.format(expr=expr)

            tasklet.code.code = '%s = %s;' % (outconn, expr)
            inedge = edge
            break
    else:
        raise NotImplementedError

    # Change output edge
    outedge.data.wcr = f'lambda a,b: a {op} b'

    if self.expr_index == 0:
        # Remove input node and connector
        state.remove_edge_and_connectors(inedge)
        if state.degree(input) == 0:
            state.remove_node(input)
    else:
        # Remove input edge and dst connector, but not necessarily src
        state.remove_memlet_path(inedge)

    # If outedge leads to non-transient, and this is a nested SDFG,
    # propagate outwards
    sd = sdfg
    while (not sd.arrays[outedge.data.data].transient
           and sd.parent_nsdfg_node is not None):
        nsdfg = sd.parent_nsdfg_node
        nstate = sd.parent
        sd = sd.parent_sdfg
        outedge = next(
            iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
        for outedge in nstate.memlet_path(outedge):
            outedge.data.wcr = f'lambda a,b: a {op} b'
def visit_Name(self, node):
    """Return a replacement ``Name`` whose identifier is lower-cased,
    positioned at the original node's location."""
    lowered = ast.Name(id=node.id.lower())
    return ast.copy_location(lowered, node)
def replace(self, node, new_node):
    """Substitute *new_node* for *node*: copy the source location over,
    visit the replacement's children, and return it."""
    located = copy_location(new_node, node)
    NodeVisitor.generic_visit(self, located)
    return located
def visit_arg(self, node):
    """Convert a py3 ``arg`` node into a py2-style ``Name`` in ``Param``
    context, carrying the original source location."""
    name_node = ast.Name(node.arg, ast.Param())
    return ast.copy_location(name_node, node)
def visit_Raise(self, node):
    """Translate a py3 ``Raise`` into the internal ``Raise`` node.

    A ``raise ... from ...`` cause is rejected, since the target node has
    no slot for it.
    """
    if node.cause:
        raise error.NumbaError(node, "Cause to 'raise' not supported")
    translated = Raise(type=node.exc, inst=None, tback=None)
    return ast.copy_location(translated, node)