def test_inline_call_branch_pruning(self):
    """Branch pruning must run when ``inline_closure_call`` inlines a callee.

    ``foo`` branches on ``A is None``. After it is inlined into ``test_impl``
    inside a custom pipeline, the specialisation compiled for an omitted
    argument must collapse to a single block once the CFG is simplified,
    i.e. the type-check branch was pruned.
    """
    @njit
    def foo(A=None):
        if A is None:
            return 2
        else:
            return A

    def test_impl(A=None):
        return foo(A)

    @register_pass(analysis_only=False, mutates_CFG=True)
    class PruningInlineTestPass(FunctionPass):
        _name = "pruning_inline_test_pass"

        def __init__(self):
            FunctionPass.__init__(self)

        def run_pass(self, state):
            func_ir = state.func_ir
            # test_impl is expected to lower to a single block that holds
            # exactly one call site.
            assert len(func_ir.blocks) == 1
            only_block = list(func_ir.blocks.values())[0]
            for idx, inst in enumerate(only_block.body):
                if guard(find_callname, func_ir, inst.value) is None:
                    continue
                typemap = state.type_annotation.typemap
                arg_types = (typemap[inst.value.args[0].name], )
                inline_closure_call(
                    func_ir,
                    {},
                    only_block,
                    idx,
                    foo.py_func,
                    state.typingctx,
                    arg_types,
                    typemap,
                    state.calltypes,
                )
                # the block body was rewritten; stop scanning it
                break
            return True

    class InlineTestPipelinePrune(compiler.CompilerBase):

        def define_pipelines(self):
            pipeline = gen_pipeline(self.state, PruningInlineTestPass)
            pipeline.finalize()
            return [pipeline]

    # make sure inline_closure_call runs in full pipeline
    j_func = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)
    concrete_arg = 3
    self.assertEqual(test_impl(concrete_arg), j_func(concrete_arg))
    self.assertEqual(test_impl(), j_func())

    # The omitted-argument specialisation must be branch free once the CFG
    # is simplified.
    omitted_sig = (types.Omitted(None), )
    fir = j_func.overloads[omitted_sig].metadata["preserved_ir"]
    fir.blocks = simplify_CFG(fir.blocks)
    self.assertEqual(len(fir.blocks), 1)
def dufunc_inliner(func_ir, calltypes, typemap, typingctx, targetctx):
    """Inline DUFunc call sites found in ``func_ir``.

    Blocks are processed from a work list. Each ``ir.Assign`` whose value is
    an ``ir.Expr`` is checked with ``_is_dufunc_callsite``. For a match, the
    dispatcher's ``py_func`` is handed to ``_inline``, which also receives
    the work list. When ``_inline`` reports success the block structure has
    changed, so scanning of that block stops.

    If anything was inlined, leftover dead code is removed (using
    ``typemap``) and the CFG is simplified.

    Returns:
        True, unconditionally.
    """
    _DEBUG = False
    modified = False

    if _DEBUG:
        print("GUFunc before inlining DUFunc".center(80, "-"))
        # FunctionIR.dump() prints the IR itself (and returns None), so it
        # must not be wrapped in print().
        func_ir.dump()

    work_list = list(func_ir.blocks.items())
    # use a work list, look for DUFunc call sites via `_is_dufunc_callsite`
    # and then pass these to `_inline` to do the actual inlining.
    while work_list:
        _label, block = work_list.pop()
        for i, instr in enumerate(block.body):
            if not isinstance(instr, ir.Assign):
                continue
            expr = instr.value
            if not isinstance(expr, ir.Expr):
                continue
            call_node = _is_dufunc_callsite(expr, block)
            if not call_node:
                continue
            py_func = call_node.value._dispatcher.py_func
            inlined = _inline(
                func_ir,
                work_list,
                block,
                i,
                expr,
                py_func,
                typemap,
                calltypes,
                typingctx,
                targetctx,
            )
            if inlined:
                modified = True
                break  # because block structure changed

    if _DEBUG:
        print("GUFunc after inlining DUFunc".center(80, "-"))
        func_ir.dump()
        print("".center(80, "-"))

    if modified:
        # clean up leftover load instructions. This step is needed or else
        # SpirvLowerer would complain
        dead_code_elimination(func_ir, typemap=typemap)
        # clean up unconditional branches that appear due to inlined
        # functions introducing blocks
        func_ir.blocks = simplify_CFG(func_ir.blocks)

    if _DEBUG:
        print("GUFunc after inlining DUFunc, DCE, SimplyCFG".center(80, "-"))
        func_ir.dump()
        print("".center(80, "-"))
    return True
def run_pass(self, state):
    """Rewrite ndarray function names using ``rewrite_function_name_map``.

    When the rewrite changed anything, dead code is removed from
    ``state.func_ir`` and its CFG is simplified.

    Returns:
        The value reported by ``RewriteNdarrayFunctions.run()`` (whether the
        IR was mutated).
    """
    renamer = RewriteNdarrayFunctions(state, rewrite_function_name_map)
    mutated = renamer.run()
    if not mutated:
        return mutated

    func_ir = state.func_ir
    remove_dead(func_ir.blocks, func_ir.arg_names, func_ir)
    func_ir.blocks = simplify_CFG(func_ir.blocks)
    return mutated
def run_pass(self, state):
    """Run inlining of overloads.

    Blocks of ``state.func_ir`` are processed from a work list. Every
    ``ir.Assign`` whose value is an ``ir.Expr`` with op ``'call'`` or
    ``'getattr'`` is handed to ``self._do_work_call`` or
    ``self._do_work_getattr`` respectively (under ``guard``, which also
    receives the work list). A truthy result means the block structure has
    changed, so scanning of that block stops.

    If anything was inlined, dead code is eliminated using
    ``state.type_annotation.typemap`` and the CFG is simplified.

    Returns:
        True, unconditionally.
    """
    if self._DEBUG:
        print('before overload inline'.center(80, '-'))
        # FunctionIR.dump() prints the IR itself and returns None; wrapping
        # it in print() only emitted a stray "None".
        state.func_ir.dump()
        print(''.center(80, '-'))

    modified = False
    work_list = list(state.func_ir.blocks.items())
    # use a work list, look for call sites via `ir.Expr.op == call` and
    # then pass these to `self._do_work` to make decisions about inlining.
    while work_list:
        _label, block = work_list.pop()
        for i, instr in enumerate(block.body):
            if not isinstance(instr, ir.Assign):
                continue
            expr = instr.value
            if not isinstance(expr, ir.Expr):
                continue
            if expr.op == 'call':
                workfn = self._do_work_call
            elif expr.op == 'getattr':
                workfn = self._do_work_getattr
            else:
                continue
            if guard(workfn, state, work_list, block, i, expr):
                modified = True
                break  # because block structure changed

    if self._DEBUG:
        print('after overload inline'.center(80, '-'))
        state.func_ir.dump()
        print(''.center(80, '-'))

    if modified:
        # clean up blocks
        dead_code_elimination(state.func_ir,
                              typemap=state.type_annotation.typemap)
        # clean up unconditional branches that appear due to inlined
        # functions introducing blocks
        state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

    if self._DEBUG:
        print('after overload inline DCE'.center(80, '-'))
        state.func_ir.dump()
        print(''.center(80, '-'))
    return True
def inline_new_blocks(func_ir, block, i, callee_blocks, work_list=None):
    """Splice ``callee_blocks`` into ``func_ir`` in place of ``block.body[i]``.

    Adapted from ``inline_closure_call``. ``block.body[i]`` must be an
    assignment (its ``.target`` receives the callee's return value).

    Steps:
      1. Relabel the callee blocks past ``ir_utils._max_label``, simplify
         their CFG and bump the global label counter.
      2. Split the caller block: statements after ``i`` move to a fresh
         block, and the caller block now ends with a jump to the callee
         entry (first label in topological order).
      3. Turn callee ``Return``s into assignments to ``instr.target``
         followed by a jump to the continuation block, and drop the stale
         definition of ``instr.target``.
      4. Insert the callee blocks into ``func_ir`` (scope set to the
         caller's), registering their definitions.

    Args:
        func_ir: the IR being modified in place.
        block: the caller block containing the statement to replace.
        i: index of that statement in ``block.body``.
        callee_blocks: mapping of label -> block to inline.
        work_list: optional list; if given, the continuation block and all
            inserted callee blocks are appended as ``(label, block)`` pairs.

    Returns:
        The relabelled and CFG-simplified callee blocks.
    """
    # adopted from inline_closure_call
    scope = block.scope
    instr = block.body[i]

    # 1. relabel callee blocks by adding an offset past the current max label
    callee_blocks = add_offset_to_labels(callee_blocks, ir_utils._max_label + 1)
    callee_blocks = ir_utils.simplify_CFG(callee_blocks)
    max_label = max(callee_blocks.keys())
    # reset globals in ir_utils before we use it
    # NOTE(review): this writes a private module-level counter in ir_utils.
    ir_utils._max_label = max_label
    topo_order = find_topo_order(callee_blocks)

    # 2. split caller block into two
    new_blocks = []
    new_block = ir.Block(scope, block.loc)
    new_block.body = block.body[i + 1:]
    new_label = ir_utils.next_label()
    func_ir.blocks[new_label] = new_block
    new_blocks.append((new_label, new_block))
    block.body = block.body[:i]
    entry_label = topo_order[0]
    block.body.append(ir.Jump(entry_label, instr.loc))

    # 3. replace Return with assignment to LHS
    numba.core.inline_closurecall._replace_returns(
        callee_blocks, instr.target, new_label)
    # remove the old definition of instr.target too
    if instr.target.name in func_ir._definitions:
        func_ir._definitions[instr.target.name] = []

    # 4. insert all new blocks, and add back definitions
    for label in topo_order:
        callee_block = callee_blocks[label]
        # block scope must point to parent's
        callee_block.scope = scope
        numba.core.inline_closurecall._add_definitions(func_ir, callee_block)
        func_ir.blocks[label] = callee_block
        new_blocks.append((label, callee_block))

    if work_list is not None:
        work_list.extend(new_blocks)
    return callee_blocks
def check(self, test_impl, *args, **kwargs):
    """Compile ``test_impl`` with IR preservation and verify inlining.

    ``args`` are forwarded to both the Python and the jitted function, whose
    results must match. Recognised keyword options (anything else fails an
    assertion):

    * ``inline_expect`` (required, non-empty): maps a callee name (str) to a
      bool. ``True`` requires that no call/operator with that name remains
      in the first block of the preserved IR. ``False`` requires that one
      does remain.
    * ``block_count`` (default 1): expected number of blocks after CFG
      simplification, or ``'SKIP'`` to skip that check.

    Returns:
        The preserved ``FunctionIR`` of the first compiled signature, for
        further analysis by the caller.
    """
    inline_expect = kwargs.pop('inline_expect', None)
    assert inline_expect
    block_count = kwargs.pop('block_count', 1)
    assert not kwargs
    for callee_name, expect_inlined in inline_expect.items():
        assert isinstance(callee_name, str)
        assert isinstance(expect_inlined, bool)

    jitted = njit(pipeline_class=IRPreservingTestPipeline)(test_impl)

    # check they produce the same answer first!
    self.assertEqual(test_impl(*args), jitted(*args))

    # make sure IR doesn't have branches
    first_sig = jitted.signatures[0]
    fir = jitted.overloads[first_sig].metadata['preserved_ir']
    fir.blocks = ir_utils.simplify_CFG(fir.blocks)
    if self._DEBUG:
        print("FIR".center(80, "-"))
        fir.dump()
    if block_count != 'SKIP':
        self.assertEqual(len(fir.blocks), block_count)

    # Only the first block is searched for surviving call sites.
    first_block = next(iter(fir.blocks.values()))
    exprs = list(first_block.find_exprs())
    assert exprs

    for callee_name, expect_inlined in inline_expect.items():
        # `|=` (not `or`) so every expression is inspected, as before.
        found = False
        for expr in exprs:
            if getattr(expr, 'op', False) == 'call':
                found |= fir.get_definition(expr.func).name == callee_name
            elif ir_utils.is_operator_or_getitem(expr):
                found |= expr.fn.__name__ == callee_name
        self.assertFalse(found == expect_inlined)

    return fir  # for use in further analysis
def run_pass(self, state):
    """Run inlining of overloads.

    Builds an ``InlineWorker`` from the compilation ``state``, then processes
    the blocks of ``state.func_ir`` from a work list. Every ``ir.Assign``
    whose value is an ``ir.Expr`` is handed to ``self._do_work_expr`` (under
    ``guard``, together with the work list and the worker). A truthy result
    means the block structure has changed, so scanning of that block stops.

    If anything was inlined:
      * blocks that the CFG reports as dead are deleted;
      * dead code is eliminated using ``state.type_annotation.typemap``;
      * the CFG is simplified.

    Returns:
        True, unconditionally.
    """
    if self._DEBUG:
        print('before overload inline'.center(80, '-'))
        print(state.func_id.unique_name)
        # FunctionIR.dump() prints the IR itself and returns None; wrapping
        # it in print() only emitted a stray "None".
        state.func_ir.dump()
        print(''.center(80, '-'))

    from numba.core.inline_closurecall import (InlineWorker,
                                               callee_ir_validator)
    inline_worker = InlineWorker(
        state.typingctx,
        state.targetctx,
        state.locals,
        state.pipeline,
        state.flags,
        callee_ir_validator,
        state.typemap,
        state.calltypes,
    )

    modified = False
    workfn = self._do_work_expr  # loop invariant
    work_list = list(state.func_ir.blocks.items())
    # use a work list, look for call sites via `ir.Expr.op == call` and
    # then pass these to `self._do_work` to make decisions about inlining.
    while work_list:
        _label, block = work_list.pop()
        for i, instr in enumerate(block.body):
            # TODO: other statements (setitem)
            if not isinstance(instr, ir.Assign):
                continue
            expr = instr.value
            if not isinstance(expr, ir.Expr):
                continue
            if guard(workfn, state, work_list, block, i, expr,
                     inline_worker):
                modified = True
                break  # because block structure changed

    if self._DEBUG:
        print('after overload inline'.center(80, '-'))
        print(state.func_id.unique_name)
        state.func_ir.dump()
        print(''.center(80, '-'))

    if modified:
        # Remove dead blocks, this is safe as it relies on the CFG only.
        cfg = compute_cfg_from_blocks(state.func_ir.blocks)
        for dead in cfg.dead_nodes():
            del state.func_ir.blocks[dead]
        # clean up blocks
        dead_code_elimination(state.func_ir,
                              typemap=state.type_annotation.typemap)
        # clean up unconditional branches that appear due to inlined
        # functions introducing blocks
        state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

    if self._DEBUG:
        print('after overload inline DCE'.center(80, '-'))
        print(state.func_id.unique_name)
        state.func_ir.dump()
        print(''.center(80, '-'))
    return True
def _mk_stencil_parfor(self, label, in_args, out_arr, stencil_ir,
                       index_offsets, target, return_type, stencil_func,
                       arg_to_arr_dict):
    """ Converts a set of stencil kernel blocks to a parfor.

    Outline:
      * copy-propagate and dead-code-eliminate the kernel blocks
        (``stencil_ir.blocks``);
      * create one ``intp`` index variable per dimension of ``in_args[0]``
        and rewrite the kernel's stencil accesses via
        ``_replace_stencil_accesses``;
      * build one ``LoopNest`` per dimension, bounded by
        ``_get_stencil_start_ind`` / ``_get_stencil_last_ind`` (which may
        append setup statements to the returned node list);
      * add a single exit block (maximum label) that stores the kernel
        result into the output array at the loop index;
      * build the parfor init block. If ``out_arr`` is None it allocates the
        output with ``np.full`` (filled with ``cval`` if that option is
        given, else 0). Otherwise, if ``cval`` is given, it emits
        ``out[:] = cval``.

    Args:
        label: label used to look up the equivalence set in
            ``self.array_analysis``.
        in_args: kernel input variables; ``in_args[0]`` determines the
            iteration space (its ndim and shape).
        out_arr: existing output variable, or None to allocate one.
        stencil_ir: IR object whose ``.blocks`` are the kernel body; these
            blocks are modified in place.
        index_offsets, stencil_func, arg_to_arr_dict: forwarded to
            ``_replace_stencil_accesses``. ``stencil_func.options`` is also
            read for ``"cval"``.
        target: variable that receives the output array.
        return_type: type whose ``.dtype`` is the element type of the output
            (presumably an array type -- confirm against caller).

    Returns:
        A list of IR nodes: any bound-computation statements, then the
        ``Parfor``, then ``target = out_arr``.

    Raises:
        ValueError: if the ``cval`` option's type does not match / cannot be
            converted to ``return_type.dtype``.
    """
    gen_nodes = []
    stencil_blocks = stencil_ir.blocks

    if config.DEBUG_ARRAY_OPT >= 1:
        print("_mk_stencil_parfor", label, in_args, out_arr, index_offsets,
              return_type, stencil_func, stencil_blocks)
        ir_utils.dump_blocks(stencil_blocks)

    in_arr = in_args[0]
    # run copy propagate to replace in_args copies (e.g. a = A)
    in_arr_typ = self.typemap[in_arr.name]
    # NOTE(review): `out_cps` is never used below.
    in_cps, out_cps = ir_utils.copy_propagate(stencil_blocks, self.typemap)
    name_var_table = ir_utils.get_name_var_table(stencil_blocks)

    ir_utils.apply_copy_propagate(
        stencil_blocks,
        in_cps,
        name_var_table,
        self.typemap,
        self.calltypes)
    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after copy_propagate")
        ir_utils.dump_blocks(stencil_blocks)
    ir_utils.remove_dead(stencil_blocks, self.func_ir.arg_names, stencil_ir,
                         self.typemap)
    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after removing dead code")
        ir_utils.dump_blocks(stencil_blocks)

    # create parfor vars: one intp index variable per input dimension
    ndims = self.typemap[in_arr.name].ndim
    scope = in_arr.scope
    loc = in_arr.loc
    parfor_vars = []
    for i in range(ndims):
        parfor_var = ir.Var(scope, mk_unique_var(
            "$parfor_index_var"), loc)
        self.typemap[parfor_var.name] = types.intp
        parfor_vars.append(parfor_var)

    # rewrite the kernel's relative accesses in terms of the index vars;
    # returns per-dimension start/end extents used for the loop bounds
    start_lengths, end_lengths = self._replace_stencil_accesses(
        stencil_ir, parfor_vars, in_args, index_offsets, stencil_func,
        arg_to_arr_dict)

    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after replace stencil accesses")
        ir_utils.dump_blocks(stencil_blocks)

    # create parfor loop nests
    loopnests = []
    equiv_set = self.array_analysis.get_equiv_set(label)
    in_arr_dim_sizes = equiv_set.get_shape(in_arr)

    assert ndims == len(in_arr_dim_sizes)
    for i in range(ndims):
        # these helpers may append statements to gen_nodes
        last_ind = self._get_stencil_last_ind(in_arr_dim_sizes[i],
                                              end_lengths[i], gen_nodes,
                                              scope, loc)
        start_ind = self._get_stencil_start_ind(start_lengths[i],
                                                gen_nodes, scope, loc)
        # start from stencil size to avoid invalid array access
        loopnests.append(numba.parfors.parfor.LoopNest(parfor_vars[i],
                         start_ind, last_ind, 1))

    # We have to guarantee that the exit block has maximum label and that
    # there's only one exit block for the parfor body.
    # So, all return statements will change to jump to the parfor exit block.
    parfor_body_exit_label = max(stencil_blocks.keys()) + 1
    stencil_blocks[parfor_body_exit_label] = ir.Block(scope, loc)
    exit_value_var = ir.Var(scope, mk_unique_var("$parfor_exit_value"), loc)
    self.typemap[exit_value_var.name] = return_type.dtype

    # create parfor index var: the single index var for 1D, otherwise a
    # UniTuple(intp, ndims) built in the exit block
    for_replacing_ret = []
    if ndims == 1:
        parfor_ind_var = parfor_vars[0]
    else:
        parfor_ind_var = ir.Var(scope, mk_unique_var(
            "$parfor_index_tuple_var"), loc)
        self.typemap[parfor_ind_var.name] = types.containers.UniTuple(
            types.intp, ndims)
        tuple_call = ir.Expr.build_tuple(parfor_vars, loc)
        tuple_assign = ir.Assign(tuple_call, parfor_ind_var, loc)
        for_replacing_ret.append(tuple_assign)

    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after creating parfor index var")
        ir_utils.dump_blocks(stencil_blocks)

    # empty init block
    init_block = ir.Block(scope, loc)
    if out_arr is None:
        # allocate output: out_arr = np.full(in_arr.shape, fill, dtype)
        # NOTE(review): re-reads the same value already bound above.
        in_arr_typ = self.typemap[in_arr.name]

        shape_name = ir_utils.mk_unique_var("in_arr_shape")
        shape_var = ir.Var(scope, shape_name, loc)
        shape_getattr = ir.Expr.getattr(in_arr, "shape", loc)
        self.typemap[shape_name] = types.containers.UniTuple(
            types.intp, in_arr_typ.ndim)
        init_block.body.extend([ir.Assign(shape_getattr, shape_var, loc)])

        zero_name = ir_utils.mk_unique_var("zero_val")
        zero_var = ir.Var(scope, zero_name, loc)
        if "cval" in stencil_func.options:
            cval = stencil_func.options["cval"]
            # TODO: Loosen this restriction to adhere to casting rules.
            # NOTE(review): this branch demands an exact type match, while
            # the `out_arr` branch below only requires can_convert.
            if return_type.dtype != typing.typeof.typeof(cval):
                raise ValueError(
                    "cval type does not match stencil return type.")

            temp2 = return_type.dtype(cval)
        else:
            temp2 = return_type.dtype(0)
        full_const = ir.Const(temp2, loc)
        self.typemap[zero_name] = return_type.dtype
        init_block.body.extend([ir.Assign(full_const, zero_var, loc)])

        so_name = ir_utils.mk_unique_var("stencil_output")
        out_arr = ir.Var(scope, so_name, loc)
        self.typemap[out_arr.name] = numba.core.types.npytypes.Array(
            return_type.dtype,
            in_arr_typ.ndim,
            in_arr_typ.layout)
        # dtype argument for np.full: the `np.<dtype name>` global attribute
        dtype_g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
        self.typemap[dtype_g_np_var.name] = types.misc.Module(np)
        dtype_g_np = ir.Global('np', np, loc)
        dtype_g_np_assign = ir.Assign(dtype_g_np, dtype_g_np_var, loc)
        init_block.body.append(dtype_g_np_assign)

        dtype_np_attr_call = ir.Expr.getattr(dtype_g_np_var,
                                             return_type.dtype.name, loc)
        dtype_attr_var = ir.Var(scope, mk_unique_var("$np_attr_attr"), loc)
        self.typemap[dtype_attr_var.name] = types.functions.NumberClass(
            return_type.dtype)
        dtype_attr_assign = ir.Assign(dtype_np_attr_call, dtype_attr_var, loc)
        init_block.body.append(dtype_attr_assign)

        stmts = ir_utils.gen_np_call("full",
                                     np.full,
                                     out_arr,
                                     [shape_var, zero_var, dtype_attr_var],
                                     self.typingctx,
                                     self.typemap,
                                     self.calltypes)
        # the new output array has the same shape as the input
        equiv_set.insert_equiv(out_arr, in_arr_dim_sizes)
        init_block.body.extend(stmts)
    else:  # out is present
        if "cval" in stencil_func.options:  # do out[:] = cval
            cval = stencil_func.options["cval"]
            # TODO: Loosen this restriction to adhere to casting rules.
            cval_ty = typing.typeof.typeof(cval)
            if not self.typingctx.can_convert(cval_ty, return_type.dtype):
                msg = "cval type does not match stencil return type."
                raise ValueError(msg)

            # get slice ref
            slice_var = ir.Var(scope, mk_unique_var("$py_g_var"), loc)
            slice_fn_ty = self.typingctx.resolve_value_type(slice)
            self.typemap[slice_var.name] = slice_fn_ty
            slice_g = ir.Global('slice', slice, loc)
            slice_assigned = ir.Assign(slice_g, slice_var, loc)
            init_block.body.append(slice_assigned)

            # type slice() as slice(None, None)
            sig = self.typingctx.resolve_function_type(slice_fn_ty,
                                                       (types.none, ) * 2,
                                                       {})

            callexpr = ir.Expr.call(func=slice_var, args=(), kws=(),
                                    loc=loc)

            self.calltypes[callexpr] = sig
            slice_inst_var = ir.Var(scope, mk_unique_var("$slice_inst"),
                                    loc)
            self.typemap[slice_inst_var.name] = types.slice2_type
            slice_assign = ir.Assign(callexpr, slice_inst_var, loc)
            init_block.body.append(slice_assign)

            # get const val for cval
            cval_const_val = ir.Const(return_type.dtype(cval), loc)
            cval_const_var = ir.Var(scope, mk_unique_var("$cval_const"),
                                    loc)
            self.typemap[cval_const_var.name] = return_type.dtype
            cval_const_assign = ir.Assign(cval_const_val,
                                          cval_const_var, loc)
            init_block.body.append(cval_const_assign)

            # do setitem on `out` array
            setitemexpr = ir.StaticSetItem(out_arr, slice(None, None),
                                           slice_inst_var, cval_const_var,
                                           loc)
            init_block.body.append(setitemexpr)
            sig = signature(types.none, self.typemap[out_arr.name],
                            self.typemap[slice_inst_var.name],
                            self.typemap[out_arr.name].dtype)
            self.calltypes[setitemexpr] = sig

    # kernel returns become stores to exit_value_var + jump to the exit block
    self.replace_return_with_setitem(stencil_blocks, exit_value_var,
                                     parfor_body_exit_label)

    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after replacing return")
        ir_utils.dump_blocks(stencil_blocks)

    # exit block: out_arr[parfor_ind_var] = exit_value_var
    setitem_call = ir.SetItem(out_arr, parfor_ind_var, exit_value_var, loc)
    self.calltypes[setitem_call] = signature(
        types.none, self.typemap[out_arr.name],
        self.typemap[parfor_ind_var.name],
        self.typemap[out_arr.name].dtype
    )
    stencil_blocks[parfor_body_exit_label].body.extend(for_replacing_ret)
    stencil_blocks[parfor_body_exit_label].body.append(setitem_call)

    # simplify CFG of parfor body (exit block could be simplified often)
    # add dummy return to enable CFG
    dummy_loc = ir.Loc("stencilparfor_dummy", -1)
    # NOTE(review): reuses the "$cval_const" prefix for an unrelated dummy.
    ret_const_var = ir.Var(scope, mk_unique_var("$cval_const"), dummy_loc)
    cval_const_assign = ir.Assign(ir.Const(0, loc=dummy_loc),
                                  ret_const_var, dummy_loc)
    stencil_blocks[parfor_body_exit_label].body.append(cval_const_assign)

    stencil_blocks[parfor_body_exit_label].body.append(
        ir.Return(ret_const_var, dummy_loc),
    )
    stencil_blocks = ir_utils.simplify_CFG(stencil_blocks)
    # drop the dummy Return again. NOTE(review): assumes the block holding it
    # still has the maximum label after simplify_CFG.
    stencil_blocks[max(stencil_blocks.keys())].body.pop()

    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after adding SetItem")
        ir_utils.dump_blocks(stencil_blocks)

    pattern = ('stencil', [start_lengths, end_lengths])
    parfor = numba.parfors.parfor.Parfor(loopnests, init_block,
                                         stencil_blocks, loc, parfor_ind_var,
                                         equiv_set, pattern, self.flags)
    gen_nodes.append(parfor)
    gen_nodes.append(ir.Assign(out_arr, target, loc))
    return gen_nodes