def match(self, func_ir, block, typemap, calltypes):
    """Match ``dict(zip(a, b))`` call patterns in `block`.

    Records the inner ``zip`` call expressions in ``self._calls_to_rewrite``
    so a later apply step can rewrite them.  Returns True if at least one
    candidate was found.
    """
    self._block = block
    self._func_ir = func_ir
    self._calls_to_rewrite = set()

    # Find all assignments with a RHS expr being a call to dict, and where arg
    # is a call to zip and store these ir.Expr for further modification
    for inst in find_operations(block=block, op_name='call'):
        expr = inst.value
        try:
            callee = func_ir.infer_constant(expr.func)
        except errors.ConstantInferenceError:
            continue
        if callee is dict and len(expr.args) == 1:
            dict_arg_expr = guard(get_definition, func_ir, expr.args[0])
            if getattr(dict_arg_expr, 'op', None) == 'call':
                called_func = guard(get_definition, func_ir, dict_arg_expr.func)
                # BUG FIX: `called_func` may be None (guard failure) or a
                # definition without a `.value` attribute (e.g. another
                # ir.Expr); accessing `.value` directly raised AttributeError.
                if (getattr(called_func, 'value', None) is zip
                        and len(dict_arg_expr.args) == 2):
                    self._calls_to_rewrite.add(dict_arg_expr)

    return len(self._calls_to_rewrite) > 0
def test_find_const_global(self):
    """
    Test find_const() for values in globals (ir.Global) and freevars
    (ir.FreeVar) that are considered constants for compilation.
    """
    FREEVAR_C = 12

    def foo(a):
        b = GLOBAL_B    # captured as ir.Global in the compiled IR
        c = FREEVAR_C   # captured as ir.FreeVar (closure cell)
        return a + b + c

    f_ir = compiler.run_frontend(foo)
    block = f_ir.blocks[0]
    const_b = None
    const_c = None

    # Scan the (single) block for the assignments to `b` and `c` and try
    # to infer their values; guard() returns None instead of raising
    # GuardException when inference fails.
    for inst in block.body:
        if isinstance(inst, ir.Assign) and inst.target.name == 'b':
            const_b = ir_utils.guard(
                ir_utils.find_const, f_ir, inst.target)
        if isinstance(inst, ir.Assign) and inst.target.name == 'c':
            const_c = ir_utils.guard(
                ir_utils.find_const, f_ir, inst.target)

    self.assertEqual(const_b, GLOBAL_B)
    self.assertEqual(const_c, FREEVAR_C)
def _get_const_two_irs(ir1, ir2, var):
    """Look up `var` as a constant in either of the two IRs.

    Returns the constant value from whichever IR resolves it first;
    raises GuardException when neither IR yields a constant.
    """
    for candidate_ir in (ir1, ir2):
        const_val = guard(find_const, candidate_ir, var)
        if const_val is not None:
            return const_val
    raise GuardException
def rewrite_tuple_len(val, func_ir, called_args):
    # rewrite len(tuple) as const(len(tuple))
    # `val` is an IR expression; only 'call' expressions are candidates.
    if getattr(val, 'op', None) == 'call':
        func = guard(get_definition, func_ir, val.func)
        # Only direct calls to the builtin `len` bound as an ir.Global.
        if (func is not None and isinstance(func, ir.Global)
                and getattr(func, 'value', None) is len):
            (arg, ) = val.args
            arg_def = guard(get_definition, func_ir, arg)
            # Only when the argument is a function argument whose static
            # type (a BaseTuple) fixes the length at compile time.
            if isinstance(arg_def, ir.Arg):
                argty = called_args[arg_def.index]
                if isinstance(argty, types.BaseTuple):
                    # NOTE(review): `stmt` is not defined in this function --
                    # it presumably comes from an enclosing scope in the
                    # original file; verify before reusing this code.
                    rewrite_statement(func_ir, stmt, argty.count)
def get_tuple_items(var, block, func_ir):
    """
    Returns tuple items. If tuple is constant creates and returns constants

    Returns a list of ir.Var (one per tuple element) or None when `var`
    cannot be recognized as a tuple.
    """
    def wrap_into_var(value, block, func_ir, loc):
        # Materialize a constant as a fresh IR variable declared in `block`.
        stmt = declare_constant(value, block, func_ir, loc)
        return stmt.target

    # Constant tuple: wrap every element into its own constant variable.
    val = guard(find_const, func_ir, var)
    if val is not None:
        if isinstance(val, tuple):
            return [wrap_into_var(v, block, func_ir, var.loc) for v in val]
        return None

    # Non-constant case: only a build_tuple expression exposes its items.
    try:
        rhs = func_ir.get_definition(var)
    except KeyError:
        # BUG FIX: was a bare `except Exception: pass` that silently
        # swallowed every error; get_definition raises KeyError when the
        # variable has no (single) definition, which is the expected miss.
        return None
    if isinstance(rhs, Expr) and rhs.op == 'build_tuple':
        return list(rhs.items)
    return None
def test_inline_update_target_def(self):
    """Check that replacing an assignment RHS with a call and inlining a
    closure in its place keeps func_ir._definitions consistent (still two
    definitions for `b`)."""

    def test_impl(a):
        if a == 1:
            b = 2
        else:
            b = 3
        return b

    func_ir = compiler.run_frontend(test_impl)
    blocks = list(func_ir.blocks.values())
    for block in blocks:
        for i, stmt in enumerate(block.body):
            # match b = 2 and replace with lambda: 2
            if (isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Var)
                    and guard(find_const, func_ir, stmt.value) == 2):
                # replace expr with a dummy call
                # keep _definitions in sync: remove the old RHS, add the new
                func_ir._definitions[stmt.target.name].remove(stmt.value)
                stmt.value = ir.Expr.call(
                    ir.Var(block.scope, "myvar", loc=stmt.loc),
                    (), (), stmt.loc)
                func_ir._definitions[stmt.target.name].append(stmt.value)
                #func = g.py_func#
                inline_closure_call(func_ir, {}, block, i, lambda: 2)
                break

    self.assertEqual(len(func_ir._definitions['b']), 2)
def test_inline_var_dict_ret(self):
    # make sure inline_closure_call returns the variable replacement dict
    # and it contains the original variable name used in locals
    @njit(locals={'b': types.float64})
    def g(a):
        b = a + 1
        return b

    def test_impl():
        return g(1)

    func_ir = compiler.run_frontend(test_impl)
    blocks = list(func_ir.blocks.values())
    for block in blocks:
        for i, stmt in enumerate(block.body):
            # find the call to the dispatcher `g` and inline its Python
            # source in place of the call
            if (isinstance(stmt, ir.Assign)
                    and isinstance(stmt.value, ir.Expr)
                    and stmt.value.op == 'call'):
                func_def = guard(get_definition, func_ir, stmt.value.func)
                if (isinstance(func_def, (ir.Global, ir.FreeVar))
                        and isinstance(func_def.value, CPUDispatcher)):
                    py_func = func_def.value.py_func
                    _, var_map = inline_closure_call(
                        func_ir, py_func.__globals__, block, i, py_func)
                    break

    # the replacement dict must mention the original local name `b`
    self.assertTrue('b' in var_map)
def match(self, func_ir, block, typemap, calltypes):
    """Match pd.read_csv() calls whose constant-able keyword arguments are
    built as list/set/dict literals; collect those argument definitions in
    ``self.args`` for rewriting into tuples.  Returns True on any match.
    """
    # TODO: check that vars are used only in read_csv
    self.block = block
    self.args = args = []

    # Find all assignments with a right-hand read_csv() call
    for inst in find_operations(block=block, op_name='call'):
        expr = inst.value
        try:
            callee = func_ir.infer_constant(expr.func)
        except errors.ConstantInferenceError:
            continue
        if callee is not pd.read_csv:
            continue

        # collect arguments with list, set and dict
        # in order to replace with tuple
        for key, var in expr.kws:
            if key in self._read_csv_const_args:
                arg_def = guard(get_definition, func_ir, var)
                ops = ('build_list', 'build_set', 'build_map')
                # BUG FIX: `arg_def` may be None (guard failure) or a node
                # with no `op` attribute (e.g. ir.Const); the direct
                # `arg_def.op` access raised AttributeError.
                if getattr(arg_def, 'op', None) in ops:
                    args.append(arg_def)

    return len(args) > 0
def match(self, func_ir, block, typemap, calltypes):
    """Match calls to ``self.callee`` whose first positional argument (or
    the keyword argument named ``self.arg``) is built as a list/set literal;
    collect those definitions in ``self.args``.  Returns True on any match.
    """
    self.args = args = []
    self.block = block
    for inst in block.find_insts(ir.Assign):
        if isinstance(inst.value, ir.Expr) and inst.value.op == 'call':
            expr = inst.value
            try:
                callee = func_ir.infer_constant(expr.func)
            except errors.ConstantInferenceError:
                continue
            if callee is self.callee:
                # subclasses can veto the match via match_expr
                if not self.match_expr(expr, func_ir, block, typemap,
                                       calltypes):
                    continue
                arg_var = None
                if len(expr.args):
                    arg_var = expr.args[0]
                elif len(expr.kws) and expr.kws[0][0] == self.arg:
                    arg_var = expr.kws[0][1]
                if arg_var:
                    arg_var_def = guard(get_definition, func_ir, arg_var)
                    # BUG FIX: the definition may have no `op` attribute
                    # (e.g. ir.Const); `arg_var_def.op` raised
                    # AttributeError in that case.
                    if getattr(arg_var_def, 'op', None) in ('build_list',
                                                            'build_set'):
                        args.append(arg_var_def)
    return len(args) > 0
def get_ctxmgr_obj(var_ref): """Return the context-manager object and extra info. The extra contains the arguments if the context-manager is used as a call. """ # If the contextmanager used as a Call dfn = func_ir.get_definition(var_ref) if isinstance(dfn, ir.Expr) and dfn.op == 'call': args = [get_var_dfn(x) for x in dfn.args] kws = {k: get_var_dfn(v) for k, v in dfn.kws} extra = {'args': args, 'kwargs': kws} var_ref = dfn.func else: extra = None ctxobj = ir_utils.guard(ir_utils.find_global_value, func_ir, var_ref) # check the contextmanager object if ctxobj is ir.UNDEFINED: raise errors.CompilerError( "Undefined variable used as context manager", loc=blocks[blk_start].loc, ) if ctxobj is None: raise errors.CompilerError(_illegal_cm_msg, loc=dfn.loc) return ctxobj, extra
def run(self):
    """ Finds all calls to StencilFuncs in the IR and converts them to parfor.
    """
    from numba.stencils.stencil import StencilFunc

    # Get all the calls in the function IR.
    call_table, _ = get_call_table(self.func_ir.blocks)
    stencil_calls = []
    stencil_dict = {}
    for call_varname, call_list in call_table.items():
        for one_call in call_list:
            if isinstance(one_call, StencilFunc):
                # Remember all calls to StencilFuncs.
                stencil_calls.append(call_varname)
                stencil_dict[call_varname] = one_call
    if not stencil_calls:
        return  # return early if no stencil calls found

    # find and transform stencil calls
    for label, block in self.func_ir.blocks.items():
        # iterate in reverse so replacing block.body[i] with the generated
        # nodes does not invalidate the indices of statements still to visit
        for i, stmt in reversed(list(enumerate(block.body))):
            # Found a call to a StencilFunc.
            if (isinstance(stmt, ir.Assign)
                    and isinstance(stmt.value, ir.Expr)
                    and stmt.value.op == 'call'
                    and stmt.value.func.name in stencil_calls):
                kws = dict(stmt.value.kws)
                # Create dictionary of input argument number to
                # the argument itself.
                input_dict = {i: stmt.value.args[i]
                              for i in range(len(stmt.value.args))}
                in_args = stmt.value.args
                arg_typemap = tuple(self.typemap[i.name] for i in in_args)
                for arg_type in arg_typemap:
                    if isinstance(arg_type, types.BaseTuple):
                        raise ValueError("Tuple parameters not supported " \
                                         "for stencil kernels in parallel=True mode.")

                out_arr = kws.get('out')

                # Get the StencilFunc object corresponding to this call.
                sf = stencil_dict[stmt.value.func.name]
                stencil_ir, rt, arg_to_arr_dict = get_stencil_ir(
                    sf, self.typingctx, arg_typemap, block.scope, block.loc,
                    input_dict, self.typemap, self.calltypes)
                index_offsets = sf.options.get('index_offsets', None)
                gen_nodes = self._mk_stencil_parfor(
                    label, in_args, out_arr, stencil_ir, index_offsets,
                    stmt.target, rt, sf, arg_to_arr_dict)
                # splice the generated parfor nodes in place of the call
                block.body = block.body[:i] + gen_nodes + block.body[i+1:]
            # Found a call to a stencil via numba.stencil().
            elif (isinstance(stmt, ir.Assign)
                    and isinstance(stmt.value, ir.Expr)
                    and stmt.value.op == 'call'
                    and guard(find_callname, self.func_ir, stmt.value)
                    == ('stencil', 'numba')):
                # remove dummy stencil() call
                stmt.value = ir.Const(0, stmt.loc)
def find_literally_calls(func_ir, argtypes):
    """An analysis to find `numba.literally` call inside the given IR.
    When an unsatisfied literal typing request is found, a
    `ForceLiteralArg` exception is raised.

    Parameters
    ----------
    func_ir : numba.ir.FunctionIR

    argtypes : Sequence[numba.types.Type]
        The argument types.
    """
    from numba.core import ir_utils

    marked_args = set()
    first_loc = {}
    # Scan for literally calls
    for blk in func_ir.blocks.values():
        for assign in blk.find_exprs(op='call'):
            # resolve the callee: direct global/freevar binding, or a
            # module-attribute access (e.g. `numba.literally`)
            var = ir_utils.guard(ir_utils.get_definition, func_ir,
                                 assign.func)
            if isinstance(var, (ir.Global, ir.FreeVar)):
                fnobj = var.value
            else:
                fnobj = ir_utils.guard(ir_utils.resolve_func_from_module,
                                       func_ir, var)
            if fnobj is special.literally:
                # Found
                [arg] = assign.args
                defarg = func_ir.get_definition(arg)
                if isinstance(defarg, ir.Arg):
                    argindex = defarg.index
                    marked_args.add(argindex)
                    # remember only the first use for error reporting
                    first_loc.setdefault(argindex, assign.loc)
    # Signal the dispatcher to force literal typing
    for pos in marked_args:
        query_arg = argtypes[pos]
        do_raise = (isinstance(query_arg, types.InitialValue)
                    and query_arg.initial_value is None)
        if do_raise:
            loc = first_loc[pos]
            raise errors.ForceLiteralArg(marked_args, loc=loc)

        if not isinstance(query_arg, (types.Literal, types.InitialValue)):
            loc = first_loc[pos]
            raise errors.ForceLiteralArg(marked_args, loc=loc)
def find_branches(func_ir):
    """Find all branches whose condition is a call to the builtin ``bool``.

    Returns a list of (branch, condition-definition, block) tuples where
    the condition is the definition of ``bool``'s argument.
    """
    # find *all* branches
    branches = []
    for blk in func_ir.blocks.values():
        branch_or_jump = blk.body[-1]
        if isinstance(branch_or_jump, ir.Branch):
            branch = branch_or_jump
            pred = guard(get_definition, func_ir, branch.cond.name)
            # BUG FIX: `pred` may be a non-Expr definition (e.g. ir.Const,
            # ir.Global) with no `op` attribute; the previous
            # `pred is not None and pred.op == "call"` raised AttributeError.
            if getattr(pred, 'op', None) == "call":
                function = guard(get_definition, func_ir, pred.func)
                if (function is not None
                        and isinstance(function, ir.Global)
                        and function.value is bool):
                    condition = guard(get_definition, func_ir, pred.args[0])
                    if condition is not None:
                        branches.append((branch, condition, blk))
    return branches
def rewrite_array_ndim(val, func_ir, called_args):
    # rewrite Array.ndim as const(ndim)
    # `val` is an IR expression; only attribute reads are candidates.
    if getattr(val, 'op', None) == 'getattr':
        if val.attr == 'ndim':
            arg_def = guard(get_definition, func_ir, val.value)
            # Only when `.ndim` is read off a function argument whose static
            # Array type fixes the dimensionality at compile time.
            if isinstance(arg_def, ir.Arg):
                argty = called_args[arg_def.index]
                if isinstance(argty, types.Array):
                    # NOTE(review): `stmt` is not defined in this function --
                    # it presumably comes from an enclosing scope in the
                    # original file; verify before reusing this code.
                    rewrite_statement(func_ir, stmt, argty.ndim)
def get_constant(func_ir, var, default=NOT_CONSTANT):
    """Resolve `var` to a compile-time constant value.

    Follows chains of simple variable-to-variable assignments; returns
    `default` when no constant definition can be found.
    """
    current = var
    while True:
        definition = guard(get_definition, func_ir, current)
        if isinstance(definition, ir.Const):
            return definition.value
        if isinstance(definition, ir.Var):
            # plain alias: keep following the assignment chain
            current = definition
            continue
        # no definition found, or a non-constant expression
        return default
def _get_const_index_expr_inner(stencil_ir, func_ir, index_var):
    """Infer a constant for `index_var`, trying in order: a direct constant
    in either IR, a unary expression of a constant, then a binary expression
    of constants.  Raises GuardException when no case applies.
    """
    require(isinstance(index_var, ir.Var))

    # case where the index is a const itself in outer function
    const_val = guard(_get_const_two_irs, stencil_ir, func_ir, index_var)
    if const_val is not None:
        return const_val

    # otherwise inspect the index's defining expression
    index_def = ir_utils.get_definition(stencil_ir, index_var)
    for infer_case in (_get_const_unary_expr, _get_const_binary_expr):
        const_val = guard(infer_case, stencil_ir, func_ir, index_def)
        if const_val is not None:
            return const_val

    raise GuardException
def find_branches(func_ir):
    """Collect every branch in the IR together with the definition of its
    condition variable and its enclosing block.

    Returns a list of (branch, condition-definition, block) tuples; branches
    whose condition has no resolvable definition are skipped.
    """
    found = []
    for block in func_ir.blocks.values():
        terminator = block.body[-1]
        if not isinstance(terminator, ir.Branch):
            continue
        cond_def = guard(get_definition, func_ir, terminator.cond.name)
        if cond_def is not None:
            found.append((terminator, cond_def, block))
    return found
def _get_const_index_expr(stencil_ir, func_ir, index_var):
    """Try to infer `index_var` as a constant (e.g. an expression like
    ``c - 1`` where ``c`` is constant in the outer function); `index_var`
    is assumed to live inside the stencil kernel.  Falls back to returning
    `index_var` itself when no constant can be inferred.
    """
    const_val = guard(
        _get_const_index_expr_inner, stencil_ir, func_ir, index_var)
    return index_var if const_val is None else const_val
def run_pass(self, state):
    # assuming the function has one block with one call inside
    assert len(state.func_ir.blocks) == 1
    block = list(state.func_ir.blocks.values())[0]
    for i, stmt in enumerate(block.body):
        # NOTE(review): `stmt.value` is read unconditionally; this relies on
        # every statement in the block having a `value` attribute -- confirm
        # against the test function this pass is applied to.
        if (guard(find_callname, state.func_ir, stmt.value) is not None):
            # `foo` is defined elsewhere in the file; its Python source is
            # inlined at the call site with the argument's inferred type.
            inline_closure_call(state.func_ir, {}, block, i, foo.py_func,
                                state.typingctx,
                                (state.type_annotation.typemap[
                                    stmt.value.args[0].name],),
                                state.type_annotation.typemap,
                                state.calltypes)
            break
    return True
def _analyze_call_array(self, lhs, arr, func_name, args, array_dists):
    """analyze distributions of array functions (arr.func_name)
    """
    if func_name == 'transpose':
        if len(args) == 0:
            raise ValueError("Transpose with no arguments is not"
                             " supported")
        in_arr_name = arr.name
        # first transpose argument may be an int or a tuple of axes;
        # only a leading axis of 0 keeps the distribution meaningful
        arg0 = guard(get_constant, self.func_ir, args[0])
        if isinstance(arg0, tuple):
            arg0 = arg0[0]
        if arg0 != 0:
            raise ValueError("Transpose with non-zero first argument"
                             " is not supported")
        # output shares the input's distribution
        self._meet_array_dists(lhs, in_arr_name, array_dists)
        return

    if func_name in ('astype', 'reshape', 'copy'):
        in_arr_name = arr.name
        self._meet_array_dists(lhs, in_arr_name, array_dists)
        # TODO: support 1D_Var reshape
        if func_name == 'reshape' and array_dists[
                lhs] == Distribution.OneD_Var:
            # HACK support A.reshape(n, 1) for 1D_Var
            if len(args) == 2 and guard(find_const, self.func_ir,
                                        args[1]) == 1:
                return
            # otherwise 1D_Var reshape is unsupported: force REP
            self._analyze_call_set_REP(lhs, args, array_dists,
                                       'array.' + func_name)
        return

    # Array.tofile() is supported for all distributions
    if func_name == 'tofile':
        return

    # set REP if not found
    self._analyze_call_set_REP(lhs, args, array_dists,
                               'array.' + func_name)
def _analyze_setitem(self, inst, array_dists):
    """Analyze distribution effects of a SetItem/StaticSetItem statement."""
    # StaticSetItem stores its index variable under `index_var`
    if isinstance(inst, ir.SetItem):
        index_var = inst.index
    else:
        index_var = inst.index_var

    if ((inst.target.name, index_var.name) in self._parallel_accesses):
        # no parallel to parallel array set (TODO)
        return

    # multi-dimensional index built as a tuple: analyze the first index,
    # the rest must be replicated
    tup_list = guard(find_build_tuple, self.func_ir, index_var)
    if tup_list is not None:
        index_var = tup_list[0]
        # rest of indices should be replicated if array
        self._set_REP(tup_list[1:], array_dists)

    if guard(is_whole_slice, self.typemap, self.func_ir, index_var):
        # for example: X[:,3] = A
        self._meet_array_dists(inst.target.name, inst.value.name,
                               array_dists)
        return

    # fallback: the assigned value must be replicated
    self._set_REP([inst.value], array_dists)
def run_pass(self, state):
    """Inline a trivial closure at the first recognizable call site, then
    normalize the IR with a post-processing pass."""
    # assuming the function has one block with one call inside
    assert len(state.func_ir.blocks) == 1
    block = list(state.func_ir.blocks.values())[0]
    for i, stmt in enumerate(block.body):
        # NOTE(review): `stmt.value` is read unconditionally; this relies on
        # every statement in the block having a `value` attribute.
        if guard(find_callname, state.func_ir, stmt.value) is not None:
            inline_closure_call(state.func_ir, {}, block, i, lambda: None,
                                state.typingctx, (),
                                state.type_annotation.typemap,
                                state.type_annotation.calltypes)
            break

    # also fix up the IR
    post_proc = postproc.PostProcessor(state.func_ir)
    post_proc.run()
    post_proc.remove_dels()
    return True
def check_dtype_is_categorical(self, expr, func_ir, block, typemap,
                               calltypes):
    """Return True if the call `expr` has a ``dtype`` keyword that denotes
    a categorical dtype: either the string literal ``'category'`` or a call
    to ``pd.CategoricalDtype``.
    """
    dtype_var = None
    for name, var in expr.kws:
        if name == 'dtype':
            dtype_var = var
    if not dtype_var:
        return False

    dtype_var_def = guard(get_definition, func_ir, dtype_var)
    # dtype='category' string alias
    is_alias = (isinstance(dtype_var_def, ir.Const)
                and dtype_var_def.value == 'category')
    # dtype=pd.CategoricalDtype(...) call
    is_categoricaldtype = False
    if hasattr(dtype_var_def, 'func'):
        try:
            is_categoricaldtype = (
                func_ir.infer_constant(dtype_var_def.func)
                == pd.CategoricalDtype)
        except errors.ConstantInferenceError:
            # BUG FIX: infer_constant raises when the callee cannot be
            # resolved to a constant; previously this crashed the match
            # instead of simply rejecting it.
            is_categoricaldtype = False

    return is_alias or is_categoricaldtype
def run_pass(self, state):
    """Run inlining of overloads
    """
    if self._DEBUG:
        print('before overload inline'.center(80, '-'))
        print(state.func_ir.dump())
        print(''.center(80, '-'))

    modified = False
    work_list = list(state.func_ir.blocks.items())

    # use a work list, look for call sites via `ir.Expr.op == call` and
    # then pass these to `self._do_work` to make decisions about inlining.
    while work_list:
        label, block = work_list.pop()
        for i, instr in enumerate(block.body):
            if isinstance(instr, ir.Assign):
                expr = instr.value
                if isinstance(expr, ir.Expr):
                    if expr.op == 'call':
                        workfn = self._do_work_call
                    elif expr.op == 'getattr':
                        workfn = self._do_work_getattr
                    else:
                        continue

                    # workfn may push new blocks onto work_list
                    if guard(workfn, state, work_list, block, i, expr):
                        modified = True
                        break  # because block structure changed

    if self._DEBUG:
        print('after overload inline'.center(80, '-'))
        print(state.func_ir.dump())
        print(''.center(80, '-'))

    if modified:
        # clean up blocks
        dead_code_elimination(state.func_ir,
                              typemap=state.type_annotation.typemap)
        # clean up unconditional branches that appear due to inlined
        # functions introducing blocks
        state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

    if self._DEBUG:
        print('after overload inline DCE'.center(80, '-'))
        print(state.func_ir.dump())
        print(''.center(80, '-'))

    return True
def is_tuple(var, func_ir):
    """
    Checks if variable is either constant or non-constant tuple

    Returns True when `var` resolves to a constant tuple value or to a
    ``build_tuple`` expression; False otherwise.
    """
    val = guard(find_const, func_ir, var)
    if val is not None:
        return isinstance(val, tuple)

    try:
        rhs = func_ir.get_definition(var)
    except KeyError:
        # BUG FIX: was a bare `except Exception: pass` that silently
        # swallowed every error; get_definition raises KeyError when the
        # variable has no (single) definition, which is the expected miss.
        return False
    return isinstance(rhs, Expr) and rhs.op == 'build_tuple'
def _set_REP(self, var_list, array_dists):
    """Force the REP (replicated) distribution on every distributable
    variable in `var_list`, recursing into tuples of arrays."""
    for var in var_list:
        varname = var.name
        # Handle SeriesType since it comes from Arg node and it could
        # have user-defined distribution
        if (is_array(self.typemap, varname)
                or is_array_container(self.typemap, varname)
                or isinstance(self.typemap[varname],
                              (SeriesType, DataFrameType))):
            dprint("dist setting REP {}".format(varname))
            array_dists[varname] = Distribution.REP
        # handle tuples of arrays
        var_def = guard(get_definition, self.func_ir, var)
        if (var_def is not None and isinstance(var_def, ir.Expr)
                and var_def.op == 'build_tuple'):
            tuple_vars = var_def.items
            self._set_REP(tuple_vars, array_dists)
def match(self, func_ir, block, typemap, calltypes):
    """Find pd.DataFrame() calls in `block` whose `data` argument matches
    the dict case; remember them in ``self._calls_to_rewrite``.

    Returns True when at least one call needs rewriting.  Calls that do not
    match the dict case are left for pd_dataframe_overload to handle.
    """
    self._reset()
    self._block = block
    self._func_ir = func_ir
    self._calls_to_rewrite = set()

    for stmt in find_operations(block=block, op_name='call'):
        call_expr = stmt.value
        if guard(find_callname, func_ir, call_expr) != self._pandas_dataframe:
            continue
        call_args = get_call_parameters(call=call_expr,
                                        arg_names=self._df_arg_list)
        # only the dict-argument form is rewritten here; every other form
        # is forwarded to pd_dataframe_overload which will handle it
        if self._match_dict_case(call_args, func_ir):
            self._calls_to_rewrite.add(stmt)

    return len(self._calls_to_rewrite) > 0
def _analyze_call_np_concatenate(self, lhs, args, array_dists):
    """Propagate distributions through np.concatenate(tuple_of_arrays).

    The output is at best OneD_Var (block sizes may not sum to an exact 1D
    layout); a more restrictive output distribution is pushed back onto all
    inputs.
    """
    assert len(args) == 1
    tup_def = guard(get_definition, self.func_ir, args[0])
    assert isinstance(tup_def, ir.Expr) and tup_def.op == 'build_tuple'
    in_arrs = tup_def.items

    # meet: all input arrays share the most restrictive input distribution
    in_dist = Distribution.OneD
    for arr in in_arrs:
        in_dist = Distribution(
            min(in_dist.value, array_dists[arr.name].value))

    # OneD_Var cap: sum of block sizes might not be exactly 1D
    out_dist = Distribution(
        min(Distribution.OneD_Var.value, in_dist.value))
    array_dists[lhs] = out_dist

    # a stricter output (e.g. REP) forces the same distribution on inputs
    if out_dist != Distribution.OneD_Var:
        for arr in in_arrs:
            array_dists[arr.name] = out_dist
    return
def pre_block(self, block):
    """Per-block lowering hook: flag blocks that sit inside a TRY region."""
    from numba.core.unsafe import eh

    super(Lower, self).pre_block(block)

    # Detect if we are in a TRY block by looking for a call to
    # `eh.exception_check`.
    for call in block.find_exprs(op='call'):
        defn = ir_utils.guard(
            ir_utils.get_definition, self.func_ir, call.func,
        )
        if defn is not None and isinstance(defn, ir.Global):
            if defn.value is eh.exception_check:
                if isinstance(block.terminator, ir.Branch):
                    targetblk = self.blkmap[block.terminator.truebr]
                    # NOTE: This hacks in an attribute for call_conv to
                    # pick up. This hack is no longer needed when
                    # all old-style implementations are gone.
                    self.builder._in_try_block = {'target': targetblk}
                    break
def _get_array_accesses(blocks, func_ir, typemap, accesses=None):
    """returns a set of arrays accessed and their indices.

    Each entry is an (array-name, index) pair; `index` is a variable name
    except for static_getitem with a literal index, where the literal is
    kept when hashable.  `accesses` allows accumulating across calls.
    """
    if accesses is None:
        accesses = set()

    for block in blocks.values():
        for inst in block.body:
            # direct stores
            if isinstance(inst, ir.SetItem):
                accesses.add((inst.target.name, inst.index.name))
            if isinstance(inst, ir.StaticSetItem):
                accesses.add((inst.target.name, inst.index_var.name))
            # loads appear as RHS expressions of assignments
            if isinstance(inst, ir.Assign):
                lhs = inst.target.name
                rhs = inst.value
                if isinstance(rhs, ir.Expr) and rhs.op == 'getitem':
                    accesses.add((rhs.value.name, rhs.index.name))
                if isinstance(rhs, ir.Expr) and rhs.op == 'static_getitem':
                    index = rhs.index
                    # slice is unhashable, so just keep the variable
                    if index is None or ir_utils.is_slice_index(index):
                        index = rhs.index_var.name
                    accesses.add((rhs.value.name, index))
                # known SDC helper calls that read/write array elements
                if isinstance(rhs, ir.Expr) and rhs.op == 'call':
                    fdef = guard(find_callname, func_ir, rhs, typemap)
                    if fdef is not None:
                        if fdef == ('get_split_view_index',
                                    'sdc.hiframes.split_impl'):
                            accesses.add((rhs.args[0].name, rhs.args[1].name))
                        if fdef == ('setitem_str_arr_ptr', 'sdc.str_arr_ext'):
                            accesses.add((rhs.args[0].name, rhs.args[1].name))
                        if fdef == ('str_arr_item_to_numeric',
                                    'sdc.str_arr_ext'):
                            accesses.add((rhs.args[0].name, rhs.args[1].name))
                            accesses.add((rhs.args[2].name, rhs.args[3].name))
            # extension points registered by other modules
            for T, f in array_accesses_extensions.items():
                if isinstance(inst, T):
                    f(inst, func_ir, typemap, accesses)
    return accesses