def _mk_stencil_parfor(self, label, in_args, out_arr, stencil_ir,
                       index_offsets, target, return_type, stencil_func,
                       arg_to_arr_dict):
    """Convert a set of stencil kernel blocks to a parfor.

    Parameters
    ----------
    label
        Label handed to ``self.array_analysis.get_equiv_set`` to obtain the
        equivalence set from which the input array's shape is read.
    in_args
        Arguments of the stencil call. ``in_args[0]`` is used as the input
        array: its type gives the number of dimensions and its shape
        drives the loop bounds.
    out_arr
        Output array variable, or ``None``. When ``None`` a new array is
        created in the parfor init block with ``np.full``.
    stencil_ir
        IR of the stencil kernel; its ``blocks`` become the parfor body.
    index_offsets
        Forwarded unchanged to ``_replace_stencil_accesses``.
    target
        Variable that receives the output array after the parfor.
    return_type
        Return type of the stencil; ``return_type.dtype`` is used as the
        element type of the result.
    stencil_func
        The stencil function object; ``stencil_func.options`` is consulted
        for ``"cval"``.
    arg_to_arr_dict
        Forwarded unchanged to ``_replace_stencil_accesses``.

    Returns
    -------
    list
        The generated IR nodes. The list is also handed to the loop-bound
        helpers; it ends with the ``Parfor`` node followed by the
        assignment of the output array to ``target``.

    Raises
    ------
    ValueError
        If a ``cval`` option is given whose type differs from the stencil
        return dtype (new output array) or cannot be converted to it
        (existing ``out`` array).
    """
    gen_nodes = []
    stencil_blocks = stencil_ir.blocks

    if config.DEBUG_ARRAY_OPT >= 1:
        print("_mk_stencil_parfor", label, in_args, out_arr, index_offsets,
              return_type, stencil_func, stencil_blocks)
        ir_utils.dump_blocks(stencil_blocks)

    in_arr = in_args[0]
    in_arr_typ = self.typemap[in_arr.name]
    # run copy propagate to replace in_args copies (e.g. a = A)
    in_cps, _ = ir_utils.copy_propagate(stencil_blocks, self.typemap)
    name_var_table = ir_utils.get_name_var_table(stencil_blocks)
    ir_utils.apply_copy_propagate(
        stencil_blocks, in_cps, name_var_table, self.typemap, self.calltypes)
    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after copy_propagate")
        ir_utils.dump_blocks(stencil_blocks)
    ir_utils.remove_dead(stencil_blocks, self.func_ir.arg_names, stencil_ir,
                         self.typemap)
    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after removing dead code")
        ir_utils.dump_blocks(stencil_blocks)

    # create parfor vars: one intp index variable per input dimension
    ndims = in_arr_typ.ndim
    scope = in_arr.scope
    loc = in_arr.loc
    parfor_vars = []
    for _ in range(ndims):
        parfor_var = ir.Var(scope, mk_unique_var("$parfor_index_var"), loc)
        self.typemap[parfor_var.name] = types.intp
        parfor_vars.append(parfor_var)

    start_lengths, end_lengths = self._replace_stencil_accesses(
        stencil_ir, parfor_vars, in_args, index_offsets, stencil_func,
        arg_to_arr_dict)

    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after replace stencil accesses")
        ir_utils.dump_blocks(stencil_blocks)

    # create parfor loop nests
    loopnests = []
    equiv_set = self.array_analysis.get_equiv_set(label)
    in_arr_dim_sizes = equiv_set.get_shape(in_arr)

    assert ndims == len(in_arr_dim_sizes)
    for i in range(ndims):
        last_ind = self._get_stencil_last_ind(
            in_arr_dim_sizes[i], end_lengths[i], gen_nodes, scope, loc)
        start_ind = self._get_stencil_start_ind(
            start_lengths[i], gen_nodes, scope, loc)
        # start from stencil size to avoid invalid array access
        loopnests.append(numba.parfors.parfor.LoopNest(
            parfor_vars[i], start_ind, last_ind, 1))

    # We have to guarantee that the exit block has maximum label and that
    # there's only one exit block for the parfor body.
    # So, all return statements will change to jump to the parfor exit block.
    parfor_body_exit_label = max(stencil_blocks.keys()) + 1
    stencil_blocks[parfor_body_exit_label] = ir.Block(scope, loc)
    exit_value_var = ir.Var(scope, mk_unique_var("$parfor_exit_value"), loc)
    self.typemap[exit_value_var.name] = return_type.dtype

    # create parfor index var; for more than one dimension the individual
    # index variables are packed into a tuple inside the exit block
    for_replacing_ret = []
    if ndims == 1:
        parfor_ind_var = parfor_vars[0]
    else:
        parfor_ind_var = ir.Var(
            scope, mk_unique_var("$parfor_index_tuple_var"), loc)
        self.typemap[parfor_ind_var.name] = types.containers.UniTuple(
            types.intp, ndims)
        tuple_call = ir.Expr.build_tuple(parfor_vars, loc)
        tuple_assign = ir.Assign(tuple_call, parfor_ind_var, loc)
        for_replacing_ret.append(tuple_assign)

    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after creating parfor index var")
        ir_utils.dump_blocks(stencil_blocks)

    # empty init block
    init_block = ir.Block(scope, loc)
    # Identity test: ``==`` would dispatch to the IR node's own ``__eq__``.
    if out_arr is None:
        # No ``out`` given: emit np.full(in_arr.shape, fill, dtype).
        shape_name = ir_utils.mk_unique_var("in_arr_shape")
        shape_var = ir.Var(scope, shape_name, loc)
        shape_getattr = ir.Expr.getattr(in_arr, "shape", loc)
        self.typemap[shape_name] = types.containers.UniTuple(
            types.intp, in_arr_typ.ndim)
        init_block.body.extend([ir.Assign(shape_getattr, shape_var, loc)])

        zero_name = ir_utils.mk_unique_var("zero_val")
        zero_var = ir.Var(scope, zero_name, loc)
        if "cval" in stencil_func.options:
            cval = stencil_func.options["cval"]
            # TODO: Loosen this restriction to adhere to casting rules.
            if return_type.dtype != typing.typeof.typeof(cval):
                raise ValueError(
                    "cval type does not match stencil return type.")

            temp2 = return_type.dtype(cval)
        else:
            temp2 = return_type.dtype(0)
        full_const = ir.Const(temp2, loc)
        self.typemap[zero_name] = return_type.dtype
        init_block.body.extend([ir.Assign(full_const, zero_var, loc)])

        so_name = ir_utils.mk_unique_var("stencil_output")
        out_arr = ir.Var(scope, so_name, loc)
        self.typemap[out_arr.name] = numba.core.types.npytypes.Array(
            return_type.dtype, in_arr_typ.ndim, in_arr_typ.layout)
        dtype_g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
        self.typemap[dtype_g_np_var.name] = types.misc.Module(np)
        dtype_g_np = ir.Global('np', np, loc)
        dtype_g_np_assign = ir.Assign(dtype_g_np, dtype_g_np_var, loc)
        init_block.body.append(dtype_g_np_assign)

        dtype_np_attr_call = ir.Expr.getattr(
            dtype_g_np_var, return_type.dtype.name, loc)
        dtype_attr_var = ir.Var(scope, mk_unique_var("$np_attr_attr"), loc)
        self.typemap[dtype_attr_var.name] = types.functions.NumberClass(
            return_type.dtype)
        dtype_attr_assign = ir.Assign(
            dtype_np_attr_call, dtype_attr_var, loc)
        init_block.body.append(dtype_attr_assign)

        stmts = ir_utils.gen_np_call(
            "full", np.full, out_arr,
            [shape_var, zero_var, dtype_attr_var],
            self.typingctx, self.typemap, self.calltypes)
        equiv_set.insert_equiv(out_arr, in_arr_dim_sizes)
        init_block.body.extend(stmts)
    else:  # out is present
        if "cval" in stencil_func.options:  # do out[:] = cval
            cval = stencil_func.options["cval"]
            # TODO: Loosen this restriction to adhere to casting rules.
            cval_ty = typing.typeof.typeof(cval)
            if not self.typingctx.can_convert(cval_ty, return_type.dtype):
                msg = "cval type does not match stencil return type."
                raise ValueError(msg)

            # get slice ref
            slice_var = ir.Var(scope, mk_unique_var("$py_g_var"), loc)
            slice_fn_ty = self.typingctx.resolve_value_type(slice)
            self.typemap[slice_var.name] = slice_fn_ty
            slice_g = ir.Global('slice', slice, loc)
            slice_assigned = ir.Assign(slice_g, slice_var, loc)
            init_block.body.append(slice_assigned)

            sig = self.typingctx.resolve_function_type(
                slice_fn_ty, (types.none,) * 2, {})

            callexpr = ir.Expr.call(func=slice_var, args=(), kws=(), loc=loc)

            self.calltypes[callexpr] = sig
            slice_inst_var = ir.Var(scope, mk_unique_var("$slice_inst"), loc)
            self.typemap[slice_inst_var.name] = types.slice2_type
            slice_assign = ir.Assign(callexpr, slice_inst_var, loc)
            init_block.body.append(slice_assign)

            # get const val for cval
            cval_const_val = ir.Const(return_type.dtype(cval), loc)
            cval_const_var = ir.Var(scope, mk_unique_var("$cval_const"), loc)
            self.typemap[cval_const_var.name] = return_type.dtype
            cval_const_assign = ir.Assign(
                cval_const_val, cval_const_var, loc)
            init_block.body.append(cval_const_assign)

            # do setitem on `out` array
            setitemexpr = ir.StaticSetItem(
                out_arr, slice(None, None), slice_inst_var,
                cval_const_var, loc)
            init_block.body.append(setitemexpr)
            sig = signature(types.none, self.typemap[out_arr.name],
                            self.typemap[slice_inst_var.name],
                            self.typemap[out_arr.name].dtype)
            self.calltypes[setitemexpr] = sig

    self.replace_return_with_setitem(stencil_blocks, exit_value_var,
                                     parfor_body_exit_label)

    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after replacing return")
        ir_utils.dump_blocks(stencil_blocks)

    # The exit block stores the kernel result at the current parfor index.
    setitem_call = ir.SetItem(out_arr, parfor_ind_var, exit_value_var, loc)
    self.calltypes[setitem_call] = signature(
        types.none, self.typemap[out_arr.name],
        self.typemap[parfor_ind_var.name],
        self.typemap[out_arr.name].dtype
    )
    stencil_blocks[parfor_body_exit_label].body.extend(for_replacing_ret)
    stencil_blocks[parfor_body_exit_label].body.append(setitem_call)

    # simplify CFG of parfor body (exit block could be simplified often)
    # add dummy return to enable CFG
    dummy_loc = ir.Loc("stencilparfor_dummy", -1)
    ret_const_var = ir.Var(scope, mk_unique_var("$cval_const"), dummy_loc)
    cval_const_assign = ir.Assign(
        ir.Const(0, loc=dummy_loc), ret_const_var, dummy_loc)
    stencil_blocks[parfor_body_exit_label].body.append(cval_const_assign)

    stencil_blocks[parfor_body_exit_label].body.append(
        ir.Return(ret_const_var, dummy_loc),
    )
    stencil_blocks = ir_utils.simplify_CFG(stencil_blocks)
    # drop the last statement of the highest-labelled block again
    # (the dummy return appended above)
    stencil_blocks[max(stencil_blocks.keys())].body.pop()

    if config.DEBUG_ARRAY_OPT >= 1:
        print("stencil_blocks after adding SetItem")
        ir_utils.dump_blocks(stencil_blocks)

    pattern = ('stencil', [start_lengths, end_lengths])
    parfor = numba.parfors.parfor.Parfor(
        loopnests, init_block, stencil_blocks, loc, parfor_ind_var,
        equiv_set, pattern, self.flags)
    gen_nodes.append(parfor)
    gen_nodes.append(ir.Assign(out_arr, target, loc))
    return gen_nodes
def _add_index_offsets(self, index_list, index_offsets, new_body,
                       scope, loc):
    """Add each entry of ``index_list`` to the matching offset.

    When every entry of both lists is a Python ``int`` the sums are
    computed directly and returned as a list of ints; ``new_body`` is left
    untouched. Otherwise IR is generated for each position: integer
    entries are first materialised as constant variables, slice-typed
    operands are delegated to ``self._add_offset_to_slice``, and all other
    pairs become an ``operator.add`` binop. The generated statements are
    appended to ``new_body`` and the list of result variables is returned.
    """
    assert len(index_list) == len(index_offsets)

    # Fast path: plain integers everywhere, no IR needed.
    if all(isinstance(entry, int) for entry in index_list + index_offsets):
        return [add(base, shift)
                for base, shift in zip(index_list, index_offsets)]

    emitted = []
    results = []
    for pos, base in enumerate(index_list):
        shift = index_offsets[pos]

        # new_index = old_index + offset
        if isinstance(base, int):
            literal = base
            base = ir.Var(scope, mk_unique_var("old_index_var"), loc)
            self.typemap[base.name] = types.intp
            emitted.append(ir.Assign(ir.Const(literal, loc), base, loc))

        if isinstance(shift, int):
            literal = shift
            shift = ir.Var(scope, mk_unique_var("offset_var"), loc)
            self.typemap[shift.name] = types.intp
            emitted.append(ir.Assign(ir.Const(literal, loc), shift, loc))

        base_is_slice = (
            isinstance(base, slice)
            or isinstance(self.typemap[base.name], types.misc.SliceType)
        )
        if base_is_slice:
            # only one arg can be slice
            assert self.typemap[shift.name] == types.intp
            results.append(
                self._add_offset_to_slice(base, shift, emitted, scope, loc))
            continue

        shift_is_slice = (
            isinstance(shift, slice)
            or isinstance(self.typemap[shift.name], types.misc.SliceType)
        )
        if shift_is_slice:
            # only one arg can be slice
            assert self.typemap[base.name] == types.intp
            results.append(
                self._add_offset_to_slice(shift, base, emitted, scope, loc))
            continue

        # Plain intp + intp addition.
        total = ir.Var(scope, mk_unique_var("offset_stencil_index"), loc)
        self.typemap[total.name] = types.intp
        addition = ir.Expr.binop(operator.add, base, shift, loc)
        self.calltypes[addition] = self.typingctx.resolve_function_type(
            operator.add, (types.intp, types.intp), {})
        emitted.append(ir.Assign(addition, total, loc))
        results.append(total)

    new_body.extend(emitted)
    return results
def dead_branch_prune(func_ir, called_args):
    """
    Removes dead branches based on constant inference from function args.
    This directly mutates the IR.

    func_ir is the IR
    called_args are the actual arguments with which the function is called

    Only branches whose predicate is defined by a call to the global
    ``bool`` are considered. A pruned branch is replaced by an
    unconditional jump, blocks that become unreachable are deleted, and
    rewritten binop conditions are replaced by a constant recording which
    side was taken.
    """
    from numba.core.ir_utils import (get_definition, guard, find_const,
                                     GuardException)

    DEBUG = 0

    def find_branches(func_ir):
        # find *all* branches
        branches = []
        for blk in func_ir.blocks.values():
            branch_or_jump = blk.body[-1]
            if isinstance(branch_or_jump, ir.Branch):
                branch = branch_or_jump
                pred = guard(get_definition, func_ir, branch.cond.name)
                # The definition need not be an expression; a node without
                # an ``op`` attribute must be skipped, not raise
                # AttributeError.
                if pred is not None and getattr(pred, "op", None) == "call":
                    function = guard(get_definition, func_ir, pred.func)
                    if (function is not None and
                            isinstance(function, ir.Global) and
                            function.value is bool):
                        condition = guard(get_definition, func_ir,
                                          pred.args[0])
                        if condition is not None:
                            branches.append((branch, condition, blk))
        return branches

    def do_prune(branch, take_truebr, blk):
        # The branch is passed explicitly rather than read from the
        # enclosing loop variable through the closure.
        keep = branch.truebr if take_truebr else branch.falsebr
        # replace the branch with a direct jump
        jmp = ir.Jump(keep, loc=branch.loc)
        blk.body[-1] = jmp
        return 1 if keep == branch.truebr else 0

    def prune_by_type(branch, condition, blk, *conds):
        # this prunes a given branch and fixes up the IR
        # at least one needs to be a NoneType
        lhs_cond, rhs_cond = conds
        lhs_none = isinstance(lhs_cond, types.NoneType)
        rhs_none = isinstance(rhs_cond, types.NoneType)
        if lhs_none or rhs_none:
            try:
                take_truebr = condition.fn(lhs_cond, rhs_cond)
            except Exception:
                # best effort: a condition that cannot be evaluated here
                # simply leaves the branch in place
                return False, None
            if DEBUG > 0:
                kill = branch.falsebr if take_truebr else branch.truebr
                print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
                      condition.fn)
            taken = do_prune(branch, take_truebr, blk)
            return True, taken
        return False, None

    def prune_by_value(branch, condition, blk, *conds):
        lhs_cond, rhs_cond = conds
        try:
            take_truebr = condition.fn(lhs_cond, rhs_cond)
        except Exception:
            # best effort, see prune_by_type
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
                  condition.fn)
        taken = do_prune(branch, take_truebr, blk)
        return True, taken

    def prune_by_predicate(branch, pred, blk):
        try:
            # Just to prevent accidents, whilst already guarded, ensure this
            # is an ir.Const
            if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
                raise TypeError('Expected constant Numba IR node')
            take_truebr = bool(pred.value)
        except TypeError:
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, pred)
        taken = do_prune(branch, take_truebr, blk)
        return True, taken

    class Unknown(object):
        # sentinel type: "this operand could not be resolved to a constant"
        pass

    def resolve_input_arg_const(input_arg_idx):
        """
        Resolves an input arg to a constant (if possible)
        """
        input_arg_ty = called_args[input_arg_idx]

        # comparing to None?
        if isinstance(input_arg_ty, types.NoneType):
            return input_arg_ty

        # is it a kwarg default
        if isinstance(input_arg_ty, types.Omitted):
            val = input_arg_ty.value
            if isinstance(val, types.NoneType):
                return val
            elif val is None:
                return types.NoneType('none')

        # literal type, return the type itself so comparisons like `x == None`
        # still work as e.g. x = types.int64 will never be None/NoneType so
        # the branch can still be pruned
        return getattr(input_arg_ty, 'literal_type', Unknown())

    if DEBUG > 1:
        print("before".center(80, '-'))
        print(func_ir.dump())

    # This looks for branches where:
    # at least one arg of the condition is in input args and const
    # at least one an arg of the condition is a const
    # if the condition is met it will replace the branch with a jump
    branch_info = find_branches(func_ir)
    # stores conditions that have no impact post prune
    nullified_conditions = []

    for branch, condition, blk in branch_info:
        const_conds = []
        if isinstance(condition, ir.Expr) and condition.op == 'binop':
            prune = prune_by_value
            for arg in [condition.lhs, condition.rhs]:
                resolved_const = Unknown()
                arg_def = guard(get_definition, func_ir, arg)
                if isinstance(arg_def, ir.Arg):
                    # it's an e.g. literal argument to the function
                    resolved_const = resolve_input_arg_const(arg_def.index)
                    prune = prune_by_type
                else:
                    # it's some const argument to the function, cannot use
                    # guard here as the const itself may be None
                    try:
                        resolved_const = find_const(func_ir, arg)
                        if resolved_const is None:
                            resolved_const = types.NoneType('none')
                    except GuardException:
                        pass

                if not isinstance(resolved_const, Unknown):
                    const_conds.append(resolved_const)

            # lhs/rhs are consts
            if len(const_conds) == 2:
                # prune the branch, switch the branch for an unconditional
                # jump
                prune_stat, taken = prune(branch, condition, blk,
                                          *const_conds)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(
                        nullified(condition, taken, True))
        else:
            # see if this is a branch on a constant value predicate
            resolved_const = Unknown()
            try:
                pred_call = get_definition(func_ir, branch.cond)
                resolved_const = find_const(func_ir, pred_call.args[0])
                if resolved_const is None:
                    resolved_const = types.NoneType('none')
            except GuardException:
                pass

            if not isinstance(resolved_const, Unknown):
                prune_stat, taken = prune_by_predicate(branch, condition,
                                                       blk)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(
                        nullified(condition, taken, False))

    # 'ERE BE DRAGONS...
    # It is the evaluation of the condition expression that often trips up
    # type inference, so ideally it would be removed as it is effectively
    # rendered dead by the unconditional jump if a branch was pruned.
    # However, there may be references to the condition that exist in
    # multiple places (e.g. dels) and we cannot run DCE here as typing has
    # not taken place to give enough information to run DCE safely. Upshot
    # of all this is the condition gets rewritten below into a benign const
    # that typing will be happy with and DCE can remove it and its reference
    # post typing when it is safe to do so (if desired). It is required that
    # the const is assigned a value that indicates the branch taken as its
    # mutated value would be read in the case of object mode fall back in
    # place of the condition itself. For completeness the
    # func_ir._definitions and ._consts are also updated to make the IR
    # state self consistent.
    deadcond = [x.condition for x in nullified_conditions]
    for _, cond, blk in branch_info:
        if cond in deadcond:
            for x in blk.body:
                if isinstance(x, ir.Assign) and x.value is cond:
                    # rewrite the condition as a true/false bit
                    nullified_info = nullified_conditions[
                        deadcond.index(cond)]
                    # only do a rewrite of conditions, predicates need to
                    # retain their value as they may be used later.
                    if nullified_info.rewrite_stmt:
                        branch_bit = nullified_info.taken_br
                        x.value = ir.Const(branch_bit, loc=x.loc)
                        # update the specific definition to the new const
                        defns = func_ir._definitions[x.target.name]
                        repl_idx = defns.index(cond)
                        defns[repl_idx] = x.value

    # Remove dead blocks, this is safe as it relies on the CFG only.
    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for dead in cfg.dead_nodes():
        del func_ir.blocks[dead]

    # if conditions were nullified then consts were rewritten, update
    if nullified_conditions:
        func_ir._consts = consts.ConstantInference(func_ir)

    if DEBUG > 1:
        print("after".center(80, '-'))
        print(func_ir.dump())
def run(self):
    """Find every StencilFunc call in the IR and convert it to a parfor.

    Calls whose callee resolves to a ``StencilFunc`` are replaced in place
    by the nodes produced by ``_mk_stencil_parfor``. Calls that resolve to
    ``numba.stencil`` itself are replaced by a constant ``0``. Nothing is
    changed when the function contains no StencilFunc calls at all.
    """
    from numba.stencils.stencil import StencilFunc

    # Collect the names of all call targets that are StencilFuncs.
    call_table, _ = get_call_table(self.func_ir.blocks)
    stencil_calls = []
    stencil_dict = {}
    for callee_name, callees in call_table.items():
        for callee in callees:
            if not isinstance(callee, StencilFunc):
                continue
            stencil_calls.append(callee_name)
            stencil_dict[callee_name] = callee
    if not stencil_calls:
        return  # return early if no stencil calls found

    # find and transform stencil calls
    for label, block in self.func_ir.blocks.items():
        # Walk a snapshot of the body backwards so that splicing new nodes
        # in at ``pos`` never shifts a position still to be visited.
        snapshot = list(block.body)
        for pos in reversed(range(len(snapshot))):
            stmt = snapshot[pos]
            is_call_assign = (
                isinstance(stmt, ir.Assign)
                and isinstance(stmt.value, ir.Expr)
                and stmt.value.op == "call"
            )
            if not is_call_assign:
                continue

            call = stmt.value
            if call.func.name in stencil_calls:
                # Found a call to a StencilFunc.
                kws = dict(call.kws)
                # Map each positional argument number to the argument.
                input_dict = dict(enumerate(call.args))
                in_args = call.args
                arg_typemap = tuple(
                    self.typemap[arg.name] for arg in in_args)
                if any(isinstance(arg_type, types.BaseTuple)
                       for arg_type in arg_typemap):
                    raise ValueError(
                        "Tuple parameters not supported "
                        "for stencil kernels in parallel=True mode.")

                out_arr = kws.get("out")

                # Get the StencilFunc object corresponding to this call.
                sf = stencil_dict[call.func.name]
                stencil_ir, rt, arg_to_arr_dict = get_stencil_ir(
                    sf, self.typingctx, arg_typemap,
                    block.scope, block.loc, input_dict,
                    self.typemap, self.calltypes)
                index_offsets = sf.options.get("index_offsets", None)
                gen_nodes = self._mk_stencil_parfor(
                    label, in_args, out_arr, stencil_ir, index_offsets,
                    stmt.target, rt, sf, arg_to_arr_dict)
                block.body = (
                    block.body[:pos] + gen_nodes + block.body[pos + 1:])
            elif (guard(find_callname, self.func_ir, call)
                    == ("stencil", "numba")):
                # Found a call to a stencil via numba.stencil():
                # remove dummy stencil() call
                stmt.value = ir.Const(0, stmt.loc)
def add_indices_to_kernel(self, kernel, index_names, ndim,
                          neighborhood, standard_indexed, typemap, calltypes):
    """
    Transforms the stencil kernel as specified by the user into one
    that includes each dimension's index variable as part of the getitem
    calls.  So, in effect array[-1] becomes array[index0-1].

    ``kernel.blocks`` is rewritten in place, and ``typemap`` / ``calltypes``
    gain entries for the slice_addition calls that are generated.

    Returns a tuple ``(neighborhood, relatively_indexed)``.
    ``neighborhood`` is the one passed in, or, when ``None`` was passed, a
    list of ``[min, max]`` pairs per dimension computed from the constant
    indices found in the kernel. ``relatively_indexed`` is the set of kernel
    argument names that were accessed with relative indexing.

    Raises ValueError when the given neighborhood has the wrong number of
    dimensions, when the kernel assigns to one of its array arguments, when
    a needed index is not constant, or when an index does not match the
    array dimensionality.
    """
    const_dict = {}
    kernel_consts = []

    def emit_slice_addition(body, scope, loc, slice_operand, slice_typ,
                            index_operand, result_var):
        # Append ``result_var = slice_addition(slice_operand, index_operand)``
        # to ``body`` and record the typing of the call.
        sa_var = ir.Var(
            scope, ir_utils.mk_unique_var("slice_addition"), loc)
        sa_func = numba.njit(slice_addition)
        sa_func_typ = types.functions.Dispatcher(sa_func)
        typemap[sa_var.name] = sa_func_typ
        g_sa = ir.Global("slice_addition", sa_func, loc)
        body.append(ir.Assign(g_sa, sa_var, loc))
        slice_addition_call = ir.Expr.call(
            sa_var, [slice_operand, index_operand], (), loc)
        calltypes[slice_addition_call] = sa_func_typ.get_call_type(
            self._typingctx, [slice_typ, types.intp], {})
        body.append(ir.Assign(slice_addition_call, result_var, loc))

    if config.DEBUG_ARRAY_OPT >= 1:
        print("add_indices_to_kernel", ndim, neighborhood)
        ir_utils.dump_blocks(kernel.blocks)

    if neighborhood is None:
        need_to_calc_kernel = True
    else:
        need_to_calc_kernel = False
        if len(neighborhood) != ndim:
            raise ValueError("%d dimensional neighborhood specified for %d "
                             "dimensional input array"
                             % (len(neighborhood), ndim))

    tuple_table = ir_utils.get_tuple_table(kernel.blocks)

    relatively_indexed = set()

    for block in kernel.blocks.values():
        scope = block.scope
        loc = block.loc
        new_body = []
        for stmt in block.body:
            if (isinstance(stmt, ir.Assign) and
                    isinstance(stmt.value, ir.Const)):
                if config.DEBUG_ARRAY_OPT >= 1:
                    print("remembering in const_dict", stmt.target.name,
                          stmt.value.value)
                # Remember consts for use later.
                const_dict[stmt.target.name] = stmt.value.value
            if ((isinstance(stmt, ir.Assign)
                    and isinstance(stmt.value, ir.Expr)
                    and stmt.value.op in ['setitem', 'static_setitem']
                    and stmt.value.value.name in kernel.arg_names) or
                (isinstance(stmt, ir.SetItem)
                    and stmt.target.name in kernel.arg_names)):
                raise ValueError("Assignments to arrays passed to stencil "
                                 "kernels is not allowed.")
            if (isinstance(stmt, ir.Assign)
                    and isinstance(stmt.value, ir.Expr)
                    and stmt.value.op in ['getitem', 'static_getitem']
                    and stmt.value.value.name in kernel.arg_names
                    and stmt.value.value.name not in standard_indexed):
                # We found a getitem from the input array.
                if stmt.value.op == 'getitem':
                    stmt_index_var = stmt.value.index
                else:
                    # allow static_getitem since rewrite passes are applied
                    stmt_index_var = stmt.value.index_var

                relatively_indexed.add(stmt.value.value.name)

                # Store the index used after looking up the variable in
                # the const dictionary.
                if need_to_calc_kernel:
                    assert hasattr(stmt_index_var, 'name')

                    if stmt_index_var.name in tuple_table:
                        kernel_consts += [tuple_table[stmt_index_var.name]]
                    elif stmt_index_var.name in const_dict:
                        kernel_consts += [const_dict[stmt_index_var.name]]
                    else:
                        raise ValueError(
                            "stencil kernel index is not "
                            "constant, 'neighborhood' option required")

                if ndim == 1:
                    # Single dimension always has index variable 'index0'.
                    # tmpvar will hold the real index and is computed by
                    # adding the relative offset in stmt.value.index to
                    # the current absolute location in index0.
                    index_var = ir.Var(scope, index_names[0], loc)
                    tmpname = ir_utils.mk_unique_var("stencil_index")
                    tmpvar = ir.Var(scope, tmpname, loc)
                    stmt_index_var_typ = typemap[stmt_index_var.name]
                    # If the array is indexed with a slice then we
                    # have to add the index value with a call to
                    # slice_addition.
                    if isinstance(stmt_index_var_typ, types.misc.SliceType):
                        emit_slice_addition(
                            new_body, scope, loc, stmt_index_var,
                            stmt_index_var_typ, index_var, tmpvar)
                    else:
                        acc_call = ir.Expr.binop(
                            operator.add, stmt_index_var, index_var, loc)
                        new_body.append(ir.Assign(acc_call, tmpvar, loc))
                    new_body.append(ir.Assign(
                        ir.Expr.getitem(stmt.value.value, tmpvar, loc),
                        stmt.target, loc))
                else:
                    index_vars = []
                    s_index_name = ir_utils.mk_unique_var("stencil_index")
                    s_index_var = ir.Var(scope, s_index_name, loc)
                    const_index_vars = []
                    ind_stencils = []

                    stmt_index_var_typ = typemap[stmt_index_var.name]
                    # Same idea as above but you have to extract
                    # individual elements out of the tuple indexing
                    # expression and add the corresponding index variable
                    # to them and then reconstitute as a tuple that can
                    # index the array.
                    for dim in range(ndim):
                        tmpname = ir_utils.mk_unique_var("const_index")
                        tmpvar = ir.Var(scope, tmpname, loc)
                        new_body.append(
                            ir.Assign(ir.Const(dim, loc), tmpvar, loc))
                        const_index_vars += [tmpvar]
                        index_var = ir.Var(scope, index_names[dim], loc)
                        index_vars += [index_var]

                        tmpname = ir_utils.mk_unique_var("ind_stencil_index")
                        tmpvar = ir.Var(scope, tmpname, loc)
                        ind_stencils += [tmpvar]
                        getitemname = ir_utils.mk_unique_var("getitem")
                        getitemvar = ir.Var(scope, getitemname, loc)
                        getitemcall = ir.Expr.getitem(
                            stmt_index_var, const_index_vars[dim], loc)
                        new_body.append(
                            ir.Assign(getitemcall, getitemvar, loc))
                        # Get the type of this particular part of the
                        # index tuple.
                        if isinstance(stmt_index_var_typ, types.ConstSized):
                            one_index_typ = stmt_index_var_typ[dim]
                        else:
                            one_index_typ = stmt_index_var_typ[:]
                        # If the array is indexed with a slice then we
                        # have to add the index value with a call to
                        # slice_addition.
                        if isinstance(one_index_typ, types.misc.SliceType):
                            emit_slice_addition(
                                new_body, scope, loc, getitemvar,
                                one_index_typ, index_vars[dim], tmpvar)
                        else:
                            acc_call = ir.Expr.binop(
                                operator.add, getitemvar,
                                index_vars[dim], loc)
                            new_body.append(
                                ir.Assign(acc_call, tmpvar, loc))

                    tuple_call = ir.Expr.build_tuple(ind_stencils, loc)
                    new_body.append(ir.Assign(tuple_call, s_index_var, loc))
                    new_body.append(ir.Assign(
                        ir.Expr.getitem(stmt.value.value, s_index_var, loc),
                        stmt.target, loc))
            else:
                new_body.append(stmt)
        block.body = new_body

    if need_to_calc_kernel:
        # Find the size of the kernel by finding the maximum absolute value
        # index used in the kernel specification.
        neighborhood = [[0, 0] for _ in range(ndim)]
        if len(kernel_consts) == 0:
            raise ValueError("Stencil kernel with no accesses to "
                             "relatively indexed arrays.")

        for index in kernel_consts:
            if isinstance(index, (tuple, list)):
                offsets = index
            elif isinstance(index, int):
                offsets = (index,)
            else:
                raise ValueError(
                    "Non-tuple or non-integer used as stencil index.")
            # Validate the rank before touching ``neighborhood`` so that an
            # over-long index reports this error rather than an IndexError.
            if len(offsets) != ndim:
                raise ValueError(
                    "Stencil index does not match array dimensionality.")
            for dim, te in enumerate(offsets):
                if isinstance(te, ir.Var) and te.name in const_dict:
                    te = const_dict[te.name]
                if isinstance(te, int):
                    neighborhood[dim][0] = min(neighborhood[dim][0], te)
                    neighborhood[dim][1] = max(neighborhood[dim][1], te)
                else:
                    raise ValueError(
                        "stencil kernel index is not constant,"
                        "'neighborhood' option required")

    return (neighborhood, relatively_indexed)
def test_const(self):
    """Const nodes compare by value; the location does not matter."""
    reference = ir.Const(1, self.loc1)
    same_value_same_loc = ir.Const(1, self.loc1)
    same_value_other_loc = ir.Const(1, self.loc2)
    other_value = ir.Const(2, self.loc1)
    self.check(
        reference,
        same=[same_value_same_loc, same_value_other_loc],
        different=[other_value],
    )
def _fix_multi_exit_blocks(func_ir, exit_nodes, *, split_condition=None):
    """Rewrite the FunctionIR so that all given exit nodes funnel through
    one common exit node.

    Parameters
    ----------
    func_ir :
        The FunctionIR. Mutated inplace.
    exit_nodes :
        The original exit nodes. A sequence of block keys.
    split_condition : callable or None
        If not None, it is a callable with the signature
        `split_condition(statement)` that determines if the `statement` is
        the splitting point (e.g. `POP_BLOCK`) in an exit node.
        If it's None, the exit node is not split.

    Returns
    -------
    tuple
        ``(func_ir, common_label)`` where ``common_label`` is the key of
        the newly added common block.
    """
    # Convert the following:
    #
    #          |           |
    #      +-------+   +-------+
    #      | exit0 |   | exit1 |
    #      +-------+   +-------+
    #          |           |
    #      +-------+   +-------+
    #      | after0|   | after1|
    #      +-------+   +-------+
    #          |           |
    #
    # To roughly:
    #
    #          |           |
    #      +-------+   +-------+
    #      | exit0 |   | exit1 |
    #      +-------+   +-------+
    #          |           |
    #          +-----+-----+
    #                |
    #           +---------+
    #           | common  |
    #           +---------+
    #                |
    #            +-------+
    #            | post  |
    #            +-------+
    #                |
    #          +-----+-----+
    #          |           |
    #      +-------+   +-------+
    #      | after0|   | after1|
    #      +-------+   +-------+

    blocks = func_ir.blocks
    # Any block will do as the source of the scope.
    any_blk = min(func_ir.blocks.values())
    scope = any_blk.scope

    # Fresh labels: common block, post block, then the switch blocks.
    common_label = max(func_ir.blocks) + 1
    post_label = common_label + 1
    first_switch_label = post_label + 1

    # The new common block for the new exit.
    common_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    blocks[common_label] = common_block
    # The new block after the exit.
    post_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    blocks[post_label] = post_block

    # Adjust each exit node
    remainings = []
    for cp_value, node_key in enumerate(exit_nodes):
        blk = blocks[node_key]

        if split_condition is None:
            # no splitting: only the last statement is cut off
            split_at = -1
        else:
            # split the block at the first statement accepted by the callback
            for split_at, stmt in enumerate(blk.body):
                if split_condition(stmt):
                    break

        head = blk.body[:split_at]
        tail = blk.body[split_at:]
        remainings.append(tail)

        # Add control-point variable to mark which exit block this is.
        blk.body = head
        loc = blk.loc
        cp_assign = ir.Assign(
            value=ir.Const(cp_value, loc=loc),
            target=scope.get_or_define("$cp", loc=loc),
            loc=loc,
        )
        blk.body.append(cp_assign)
        # Replace terminator with a jump to the common block
        assert not blk.is_terminated
        blk.body.append(ir.Jump(common_label, loc=ir.unknown_loc))

    if split_condition is not None:
        # Move the splitting statement to the common block
        common_block.body.append(remainings[0][0])
    assert not common_block.is_terminated
    # Append jump from common block to post block
    common_block.body.append(ir.Jump(post_label, loc=loc))

    # Make if-else tree to jump to target: one new label per remainder
    remain_blocks = [first_switch_label + offset
                     for offset in range(len(remainings))]

    switch_block = post_block
    loc = ir.unknown_loc
    for case_no, remain in enumerate(remainings):
        match_expr = scope.redefine("$cp_check", loc=loc)
        match_rhs = scope.redefine("$cp_rhs", loc=loc)

        # Do comparison to match control-point variable to the exit block
        rhs_assign = ir.Assign(
            value=ir.Const(case_no, loc=loc), target=match_rhs, loc=loc)
        switch_block.body.append(rhs_assign)

        # Add assignment for the comparison
        comparison = ir.Expr.binop(
            fn=operator.eq, lhs=scope.get("$cp"), rhs=match_rhs, loc=loc)
        switch_block.body.append(
            ir.Assign(value=comparison, target=match_expr, loc=loc))

        # Insert jump to the next case
        (jump_target,) = remain[-1].get_targets()
        next_label = remain_blocks[case_no]
        switch_block.body.append(
            ir.Branch(match_expr, jump_target, next_label, loc=loc))
        switch_block = ir.Block(scope=scope, loc=loc)
        blocks[next_label] = switch_block

    # Add the final jump
    switch_block.body.append(ir.Jump(jump_target, loc=loc))

    return func_ir, common_label
def op_BUILD_CONST_KEY_MAP(self, inst, keys, keytmps, values, res):
    """Emit a ``build_map`` expression for a dict with constant keys.

    The assignment that built the constant key tuple ``keys`` is removed
    from the current block together with the assignments of its individual
    items. Each key is then stored again as an ``ir.Const`` into the
    matching entry of ``keytmps``, paired with the entry of ``values`` at
    the same position, and the resulting ``build_map`` expression is stored
    into ``res``.
    """
    # Unpack the constant key-tuple and reused build_map which takes
    # a sequence of (key, value) pair.
    keyvar = self.get(keys)
    keytup = None
    # TODO: refactor this pattern. occurred several times.
    # A distinct loop name is used so the ``inst`` parameter is not
    # clobbered.
    for stmt in self.current_block.body:
        if isinstance(stmt, ir.Assign) and stmt.target is keyvar:
            self.current_block.remove(stmt)
            # scan up the block looking for the values, remove them
            # and find their name strings
            named_items = []
            for x in stmt.value.items:
                for y in self.current_block.body[::-1]:
                    if x == y.target:
                        self.current_block.remove(y)
                        named_items.append(y.value.value)
                        break
            keytup = named_items
            break
    # Fail with a clear message instead of an unbound-local error when the
    # key tuple definition is missing from the current block.
    assert keytup is not None, (
        "BUILD_CONST_KEY_MAP: definition of the constant key tuple was "
        "not found in the current block")
    assert len(keytup) == len(values)
    keyconsts = [ir.Const(value=x, loc=self.loc) for x in keytup]
    for kval, tmp in zip(keyconsts, keytmps):
        self.store(kval, tmp)
    items = list(zip(map(self.get, keytmps), map(self.get, values)))

    def resolve_const(v):
        # A value counts as literal only when it has exactly one definition
        # and that definition is an ir.Const.
        defns = self.definitions[v]
        if len(defns) != 1:
            return _UNKNOWN_VALUE(self.get(v).name)
        defn = defns[0]
        if not isinstance(defn, ir.Const):
            return _UNKNOWN_VALUE(self.get(v).name)
        return defn.value

    # sort out literal values; non-literal ones become _UNKNOWN_VALUE
    literal_dict = {x: resolve_const(y) for x, y in zip(keytup, values)}

    # to deal with things like {'a': 1, 'a': 'cat', 'b': 2, 'a': 2j}
    # store the index of the actual used value for a given key, this is
    # used when lowering to pull the right value out into the tuple repr
    # of a mixed value type dictionary.
    value_indexes = {k: i for i, k in enumerate(keytup)}

    # NOTE(review): ``size`` is hard-coded to 2 regardless of how many
    # items there are -- confirm whether ``len(items)`` is intended.
    expr = ir.Expr.build_map(items=items,
                             size=2,
                             literal_value=literal_dict,
                             value_indexes=value_indexes,
                             loc=self.loc)

    self.store(expr, res)