Example 1
    def test_inline_call_branch_pruning(self):
        # the branch pruning pass should run during inlining so that
        # functions with type checks can be inlined
        @njit
        def foo(A=None):
            if A is None:
                return 2
            else:
                return A

        def test_impl(A=None):
            return foo(A)

        @register_pass(analysis_only=False, mutates_CFG=True)
        class PruningInlineTestPass(FunctionPass):
            _name = "pruning_inline_test_pass"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                # assuming the function has one block with one call inside
                assert len(state.func_ir.blocks) == 1
                block = list(state.func_ir.blocks.values())[0]
                for i, stmt in enumerate(block.body):
                    if guard(find_callname, state.func_ir,
                             stmt.value) is not None:
                        inline_closure_call(
                            state.func_ir,
                            {},
                            block,
                            i,
                            foo.py_func,
                            state.typingctx,
                            (state.type_annotation.typemap[
                                stmt.value.args[0].name], ),
                            state.type_annotation.typemap,
                            state.calltypes,
                        )
                        break
                return True

        class InlineTestPipelinePrune(compiler.CompilerBase):
            def define_pipelines(self):
                pm = gen_pipeline(self.state, PruningInlineTestPass)
                pm.finalize()
                return [pm]

        # make sure inline_closure_call runs in the full pipeline
        j_func = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)
        A = 3
        self.assertEqual(test_impl(A), j_func(A))
        self.assertEqual(test_impl(), j_func())

        # make sure the IR doesn't have branches
        fir = j_func.overloads[(
            types.Omitted(None), )].metadata["preserved_ir"]
        fir.blocks = simplify_CFG(fir.blocks)
        self.assertEqual(len(fir.blocks), 1)
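The test relies on a gen_pipeline helper defined elsewhere in the test module, plus a pass that stores the typed IR under the "preserved_ir" metadata key. A minimal sketch of what such a helper could look like, assuming numba's DefaultPassBuilder and an illustrative PreserveIRForTest pass (both the pass and the metadata handling are assumptions, not the exact test-suite code):

from numba.core.compiler import DefaultPassBuilder
from numba.core.compiler_machinery import FunctionPass, register_pass
from numba.core.typed_passes import AnnotateTypes


@register_pass(mutates_CFG=False, analysis_only=True)
class PreserveIRForTest(FunctionPass):
    # illustrative: stash a copy of the typed IR so the test can read it back
    # from the overload's metadata
    _name = "preserve_ir_for_test"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        state.metadata['preserved_ir'] = state.func_ir.copy()
        return False


def gen_pipeline(state, test_pass):
    # start from the standard nopython pipeline and splice the custom pass in
    # after type annotation, so typemap/calltypes are already populated
    pm = DefaultPassBuilder.define_nopython_pipeline(state)
    pm.add_pass_after(test_pass, AnnotateTypes)
    pm.add_pass_after(PreserveIRForTest, test_pass)
    return pm  # the caller finalizes it, as in the example above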
Example 2
def dufunc_inliner(func_ir, calltypes, typemap, typingctx, targetctx):
    _DEBUG = False
    modified = False

    if _DEBUG:
        print("GUFunc before inlining DUFunc".center(80, "-"))
        print(func_ir.dump())

    work_list = list(func_ir.blocks.items())
    # use a work list: look for call sites via `ir.Expr.op == call` and then
    # pass them to `_inline` to decide whether to inline.
    while work_list:
        label, block = work_list.pop()
        for i, instr in enumerate(block.body):

            if isinstance(instr, ir.Assign):
                expr = instr.value
                if isinstance(expr, ir.Expr):
                    call_node = _is_dufunc_callsite(expr, block)
                    if call_node:
                        py_func = call_node.value._dispatcher.py_func
                        workfn = _inline(
                            func_ir,
                            work_list,
                            block,
                            i,
                            expr,
                            py_func,
                            typemap,
                            calltypes,
                            typingctx,
                            targetctx,
                        )
                        if workfn:
                            modified = True
                            break  # because block structure changed
                    else:
                        continue
    if _DEBUG:
        print("GUFunc after inlining DUFunc".center(80, "-"))
        print(func_ir.dump())
        print("".center(80, "-"))

    if modified:
        # clean up leftover load instructions. This step is needed or else
        # SpirvLowerer would complain
        dead_code_elimination(func_ir, typemap=typemap)
        # clean up unconditional branches that appear due to inlined
        # functions introducing blocks
        func_ir.blocks = simplify_CFG(func_ir.blocks)

    if _DEBUG:
        print("GUFunc after inlining DUFunc, DCE, SimplyCFG".center(80, "-"))
        print(func_ir.dump())
        print("".center(80, "-"))

    return True
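The _is_dufunc_callsite helper used above is not shown. A minimal sketch of what it might do, assuming it returns the ir.Global/ir.FreeVar node that defines the callee when the call targets a DUFunc (the real helper in the source project may apply extra checks):

from numba.core import ir
from numba.np.ufunc.dufunc import DUFunc


def _is_dufunc_callsite(expr, block):
    # for a call expression, find the statement in this block that defines the
    # callee variable and return it if it binds a DUFunc
    if expr.op != "call":
        return None
    for stmt in block.body:
        if isinstance(stmt, ir.Assign) and stmt.target.name == expr.func.name:
            val = stmt.value
            if isinstance(val, (ir.Global, ir.FreeVar)) and \
                    isinstance(val.value, DUFunc):
                return val
    return None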
Example 3
    def run_pass(self, state):
        rewrite_ndarray_function_name_pass = RewriteNdarrayFunctions(
            state, rewrite_function_name_map)

        mutated = rewrite_ndarray_function_name_pass.run()

        if mutated:
            remove_dead(state.func_ir.blocks, state.func_ir.arg_names,
                        state.func_ir)
        state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

        return mutated
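This run_pass is shown without its enclosing class. In the source project it would sit on a registered compiler pass, roughly along the lines of the skeleton below (the class name and pass name here are illustrative, not the project's actual ones):

from numba.core.compiler_machinery import FunctionPass, register_pass


@register_pass(mutates_CFG=True, analysis_only=False)
class RewriteNdarrayFunctionNamesPass(FunctionPass):
    _name = "rewrite_ndarray_function_names_pass"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        # body as in the example above: run RewriteNdarrayFunctions, then
        # remove_dead and simplify_CFG, and report whether the IR changed
        ...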
Example 4
    def run_pass(self, state):
        """Run inlining of overloads
        """
        if self._DEBUG:
            print('before overload inline'.center(80, '-'))
            print(state.func_ir.dump())
            print(''.center(80, '-'))
        modified = False
        work_list = list(state.func_ir.blocks.items())
        # use a work list: look for `call` and `getattr` expressions and pass
        # them to `self._do_work_call` / `self._do_work_getattr` to decide
        # whether to inline.
        while work_list:
            label, block = work_list.pop()
            for i, instr in enumerate(block.body):
                if isinstance(instr, ir.Assign):
                    expr = instr.value
                    if isinstance(expr, ir.Expr):
                        if expr.op == 'call':
                            workfn = self._do_work_call
                        elif expr.op == 'getattr':
                            workfn = self._do_work_getattr
                        else:
                            continue

                        if guard(workfn, state, work_list, block, i, expr):
                            modified = True
                            break  # because block structure changed

        if self._DEBUG:
            print('after overload inline'.center(80, '-'))
            print(state.func_ir.dump())
            print(''.center(80, '-'))

        if modified:
            # clean up blocks
            dead_code_elimination(state.func_ir,
                                  typemap=state.type_annotation.typemap)
            # clean up unconditional branches that appear due to inlined
            # functions introducing blocks
            state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

        if self._DEBUG:
            print('after overload inline DCE'.center(80, '-'))
            print(state.func_ir.dump())
            print(''.center(80, '-'))

        return True
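A note on the guard(workfn, ...) idiom used here and in several other examples on this page: numba.core.ir_utils.guard calls the given function and converts a raised GuardException into a None return value, while require raises GuardException when a precondition does not hold. A small, self-contained illustration:

from numba.core.ir_utils import guard, require


def _call_expr_only(expr):
    # raises GuardException unless `expr` is a call expression, which makes
    # guard() return None instead of propagating the exception
    require(expr.op == 'call')
    return expr

# guard(_call_expr_only, expr) evaluates to `expr` for call sites and to None
# otherwise, so `if guard(workfn, ...)` above only fires when the worker ran
# without a guard failure and returned a truthy "I changed the IR" result.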
Example 5
def inline_new_blocks(func_ir, block, i, callee_blocks, work_list=None):
    # adapted from inline_closure_call; step numbers follow the original
    scope = block.scope
    instr = block.body[i]

    # 1. relabel callee_ir by adding an offset
    callee_blocks = add_offset_to_labels(callee_blocks,
                                         ir_utils._max_label + 1)
    callee_blocks = ir_utils.simplify_CFG(callee_blocks)
    max_label = max(callee_blocks.keys())
    #    reset globals in ir_utils before we use it
    ir_utils._max_label = max_label
    topo_order = find_topo_order(callee_blocks)

    # 5. split caller blocks into two
    new_blocks = []
    new_block = ir.Block(scope, block.loc)
    new_block.body = block.body[i + 1:]
    new_label = ir_utils.next_label()
    func_ir.blocks[new_label] = new_block
    new_blocks.append((new_label, new_block))
    block.body = block.body[:i]
    min_label = topo_order[0]
    block.body.append(ir.Jump(min_label, instr.loc))

    # 6. replace Return with assignment to LHS
    numba.core.inline_closurecall._replace_returns(callee_blocks, instr.target,
                                                   new_label)
    #    remove the old definition of instr.target too
    if (instr.target.name in func_ir._definitions):
        func_ir._definitions[instr.target.name] = []

    # 7. insert all new blocks, and add back definitions
    for label in topo_order:
        # block scope must point to parent's
        block = callee_blocks[label]
        block.scope = scope
        numba.core.inline_closurecall._add_definitions(func_ir, block)
        func_ir.blocks[label] = block
        new_blocks.append((label, block))

    if work_list is not None:
        for block in new_blocks:
            work_list.append(block)
    return callee_blocks
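For reference, numba.core.inline_closurecall._replace_returns (used in step 6 above) rewrites every ir.Return in the callee blocks into an assignment to the call's target followed by a jump to the block holding the caller's remaining statements. A rough behaviour-only sketch, not the library's exact implementation:

from numba.core import ir


def _replace_returns_sketch(blocks, target, return_label):
    for block in blocks.values():
        for i, stmt in enumerate(block.body):
            if isinstance(stmt, ir.Return):
                # a Return is always the block terminator, so it is safe to
                # swap it for an assignment plus an explicit jump
                block.body[i] = ir.Assign(stmt.value, target, stmt.loc)
                block.body.append(ir.Jump(return_label, stmt.loc))
                break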
Example 6
    def check(self, test_impl, *args, **kwargs):
        inline_expect = kwargs.pop('inline_expect', None)
        assert inline_expect
        block_count = kwargs.pop('block_count', 1)
        assert not kwargs
        for k, v in inline_expect.items():
            assert isinstance(k, str)
            assert isinstance(v, bool)

        j_func = njit(pipeline_class=IRPreservingTestPipeline)(test_impl)

        # check they produce the same answer first!
        self.assertEqual(test_impl(*args), j_func(*args))

        # check the simplified IR has the expected number of blocks
        fir = j_func.overloads[j_func.signatures[0]].metadata['preserved_ir']
        fir.blocks = ir_utils.simplify_CFG(fir.blocks)
        if self._DEBUG:
            print("FIR".center(80, "-"))
            fir.dump()
        if block_count != 'SKIP':
            self.assertEqual(len(fir.blocks), block_count)
        block = next(iter(fir.blocks.values()))

        # if we don't expect the function to be inlined, make sure a 'call'
        # to it is still present
        exprs = [x for x in block.find_exprs()]
        assert exprs
        for k, v in inline_expect.items():
            found = False
            for expr in exprs:
                if getattr(expr, 'op', False) == 'call':
                    func_defn = fir.get_definition(expr.func)
                    found |= func_defn.name == k
                elif ir_utils.is_operator_or_getitem(expr):
                    found |= expr.fn.__name__ == k
            # `found` records whether a call/operator for `k` survived; it
            # must be the opposite of the inline expectation
            self.assertNotEqual(found, v)

        return fir  # for use in further analysis
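An illustrative use of this helper in a test method (bar is hypothetical, njit is assumed to be imported as in the examples above, and inline='always' should leave no call to bar in the final IR):

    def test_inline_simple_function(self):
        @njit(inline='always')
        def bar(x):
            return x + 1

        def impl(x):
            return bar(x) * 2

        # expect `bar` to be fully inlined and the body to collapse to a
        # single block
        self.check(impl, 10, inline_expect={'bar': True}, block_count=1)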
Example 7
    def run_pass(self, state):
        """Run inlining of overloads
        """
        if self._DEBUG:
            print('before overload inline'.center(80, '-'))
            print(state.func_id.unique_name)
            print(state.func_ir.dump())
            print(''.center(80, '-'))
        from numba.core.inline_closurecall import (InlineWorker,
                                                   callee_ir_validator)
        inline_worker = InlineWorker(
            state.typingctx,
            state.targetctx,
            state.locals,
            state.pipeline,
            state.flags,
            callee_ir_validator,
            state.typemap,
            state.calltypes,
        )
        modified = False
        work_list = list(state.func_ir.blocks.items())
        # use a work list: pass each `ir.Expr` found in an assignment to
        # `self._do_work_expr` to decide whether to inline.
        while work_list:
            label, block = work_list.pop()
            for i, instr in enumerate(block.body):
                # TODO: handle other statement kinds (e.g. setitem)
                if isinstance(instr, ir.Assign):
                    expr = instr.value
                    if isinstance(expr, ir.Expr):
                        workfn = self._do_work_expr

                        if guard(workfn, state, work_list, block, i, expr,
                                 inline_worker):
                            modified = True
                            break  # because block structure changed

        if self._DEBUG:
            print('after overload inline'.center(80, '-'))
            print(state.func_id.unique_name)
            print(state.func_ir.dump())
            print(''.center(80, '-'))

        if modified:
            # Remove dead blocks, this is safe as it relies on the CFG only.
            cfg = compute_cfg_from_blocks(state.func_ir.blocks)
            for dead in cfg.dead_nodes():
                del state.func_ir.blocks[dead]
            # clean up blocks
            dead_code_elimination(state.func_ir,
                                  typemap=state.type_annotation.typemap)
            # clean up unconditional branches that appear due to inlined
            # functions introducing blocks
            state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

        if self._DEBUG:
            print('after overload inline DCE'.center(80, '-'))
            print(state.func_id.unique_name)
            print(state.func_ir.dump())
            print(''.center(80, '-'))
        return True
Example 8
    def _mk_stencil_parfor(self, label, in_args, out_arr, stencil_ir,
                           index_offsets, target, return_type, stencil_func,
                           arg_to_arr_dict):
        """ Converts a set of stencil kernel blocks to a parfor.
        """
        gen_nodes = []
        stencil_blocks = stencil_ir.blocks

        if config.DEBUG_ARRAY_OPT >= 1:
            print("_mk_stencil_parfor", label, in_args, out_arr, index_offsets,
                  return_type, stencil_func, stencil_blocks)
            ir_utils.dump_blocks(stencil_blocks)

        in_arr = in_args[0]
        # run copy propagate to replace in_args copies (e.g. a = A)
        in_arr_typ = self.typemap[in_arr.name]
        in_cps, out_cps = ir_utils.copy_propagate(stencil_blocks, self.typemap)
        name_var_table = ir_utils.get_name_var_table(stencil_blocks)

        ir_utils.apply_copy_propagate(stencil_blocks, in_cps, name_var_table,
                                      self.typemap, self.calltypes)
        if config.DEBUG_ARRAY_OPT >= 1:
            print("stencil_blocks after copy_propagate")
            ir_utils.dump_blocks(stencil_blocks)
        ir_utils.remove_dead(stencil_blocks, self.func_ir.arg_names,
                             stencil_ir, self.typemap)
        if config.DEBUG_ARRAY_OPT >= 1:
            print("stencil_blocks after removing dead code")
            ir_utils.dump_blocks(stencil_blocks)

        # create parfor vars
        ndims = self.typemap[in_arr.name].ndim
        scope = in_arr.scope
        loc = in_arr.loc
        parfor_vars = []
        for i in range(ndims):
            parfor_var = ir.Var(scope, mk_unique_var("$parfor_index_var"), loc)
            self.typemap[parfor_var.name] = types.intp
            parfor_vars.append(parfor_var)

        start_lengths, end_lengths = self._replace_stencil_accesses(
            stencil_ir, parfor_vars, in_args, index_offsets, stencil_func,
            arg_to_arr_dict)

        if config.DEBUG_ARRAY_OPT >= 1:
            print("stencil_blocks after replace stencil accesses")
            ir_utils.dump_blocks(stencil_blocks)

        # create parfor loop nests
        loopnests = []
        equiv_set = self.array_analysis.get_equiv_set(label)
        in_arr_dim_sizes = equiv_set.get_shape(in_arr)

        assert ndims == len(in_arr_dim_sizes)
        for i in range(ndims):
            last_ind = self._get_stencil_last_ind(in_arr_dim_sizes[i],
                                                  end_lengths[i], gen_nodes,
                                                  scope, loc)
            start_ind = self._get_stencil_start_ind(start_lengths[i],
                                                    gen_nodes, scope, loc)
            # start from stencil size to avoid invalid array access
            loopnests.append(
                numba.parfors.parfor.LoopNest(parfor_vars[i], start_ind,
                                              last_ind, 1))

        # We have to guarantee that the exit block has the maximum label and
        # that there is only one exit block for the parfor body, so all return
        # statements are changed to jump to the parfor exit block.
        parfor_body_exit_label = max(stencil_blocks.keys()) + 1
        stencil_blocks[parfor_body_exit_label] = ir.Block(scope, loc)
        exit_value_var = ir.Var(scope, mk_unique_var("$parfor_exit_value"),
                                loc)
        self.typemap[exit_value_var.name] = return_type.dtype

        # create parfor index var
        for_replacing_ret = []
        if ndims == 1:
            parfor_ind_var = parfor_vars[0]
        else:
            parfor_ind_var = ir.Var(scope,
                                    mk_unique_var("$parfor_index_tuple_var"),
                                    loc)
            self.typemap[parfor_ind_var.name] = types.containers.UniTuple(
                types.intp, ndims)
            tuple_call = ir.Expr.build_tuple(parfor_vars, loc)
            tuple_assign = ir.Assign(tuple_call, parfor_ind_var, loc)
            for_replacing_ret.append(tuple_assign)

        if config.DEBUG_ARRAY_OPT >= 1:
            print("stencil_blocks after creating parfor index var")
            ir_utils.dump_blocks(stencil_blocks)

        # empty init block
        init_block = ir.Block(scope, loc)
        if out_arr is None:
            in_arr_typ = self.typemap[in_arr.name]

            shape_name = ir_utils.mk_unique_var("in_arr_shape")
            shape_var = ir.Var(scope, shape_name, loc)
            shape_getattr = ir.Expr.getattr(in_arr, "shape", loc)
            self.typemap[shape_name] = types.containers.UniTuple(
                types.intp, in_arr_typ.ndim)
            init_block.body.extend([ir.Assign(shape_getattr, shape_var, loc)])

            zero_name = ir_utils.mk_unique_var("zero_val")
            zero_var = ir.Var(scope, zero_name, loc)
            if "cval" in stencil_func.options:
                cval = stencil_func.options["cval"]
                # TODO: Loosen this restriction to adhere to casting rules.
                if return_type.dtype != typing.typeof.typeof(cval):
                    raise ValueError(
                        "cval type does not match stencil return type.")

                temp2 = return_type.dtype(cval)
            else:
                temp2 = return_type.dtype(0)
            full_const = ir.Const(temp2, loc)
            self.typemap[zero_name] = return_type.dtype
            init_block.body.extend([ir.Assign(full_const, zero_var, loc)])

            so_name = ir_utils.mk_unique_var("stencil_output")
            out_arr = ir.Var(scope, so_name, loc)
            self.typemap[out_arr.name] = numba.core.types.npytypes.Array(
                return_type.dtype, in_arr_typ.ndim, in_arr_typ.layout)
            dtype_g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
            self.typemap[dtype_g_np_var.name] = types.misc.Module(np)
            dtype_g_np = ir.Global('np', np, loc)
            dtype_g_np_assign = ir.Assign(dtype_g_np, dtype_g_np_var, loc)
            init_block.body.append(dtype_g_np_assign)

            dtype_np_attr_call = ir.Expr.getattr(dtype_g_np_var,
                                                 return_type.dtype.name, loc)
            dtype_attr_var = ir.Var(scope, mk_unique_var("$np_attr_attr"), loc)
            self.typemap[dtype_attr_var.name] = types.functions.NumberClass(
                return_type.dtype)
            dtype_attr_assign = ir.Assign(dtype_np_attr_call, dtype_attr_var,
                                          loc)
            init_block.body.append(dtype_attr_assign)

            stmts = ir_utils.gen_np_call("full", np.full, out_arr,
                                         [shape_var, zero_var, dtype_attr_var],
                                         self.typingctx, self.typemap,
                                         self.calltypes)
            equiv_set.insert_equiv(out_arr, in_arr_dim_sizes)
            init_block.body.extend(stmts)
        else:  # out is present
            if "cval" in stencil_func.options:  # do out[:] = cval
                cval = stencil_func.options["cval"]
                # TODO: Loosen this restriction to adhere to casting rules.
                cval_ty = typing.typeof.typeof(cval)
                if not self.typingctx.can_convert(cval_ty, return_type.dtype):
                    msg = "cval type does not match stencil return type."
                    raise ValueError(msg)

                # get slice ref
                slice_var = ir.Var(scope, mk_unique_var("$py_g_var"), loc)
                slice_fn_ty = self.typingctx.resolve_value_type(slice)
                self.typemap[slice_var.name] = slice_fn_ty
                slice_g = ir.Global('slice', slice, loc)
                slice_assigned = ir.Assign(slice_g, slice_var, loc)
                init_block.body.append(slice_assigned)

                sig = self.typingctx.resolve_function_type(
                    slice_fn_ty, (types.none, ) * 2, {})

                callexpr = ir.Expr.call(func=slice_var,
                                        args=(),
                                        kws=(),
                                        loc=loc)

                self.calltypes[callexpr] = sig
                slice_inst_var = ir.Var(scope, mk_unique_var("$slice_inst"),
                                        loc)
                self.typemap[slice_inst_var.name] = types.slice2_type
                slice_assign = ir.Assign(callexpr, slice_inst_var, loc)
                init_block.body.append(slice_assign)

                # get const val for cval
                cval_const_val = ir.Const(return_type.dtype(cval), loc)
                cval_const_var = ir.Var(scope, mk_unique_var("$cval_const"),
                                        loc)
                self.typemap[cval_const_var.name] = return_type.dtype
                cval_const_assign = ir.Assign(cval_const_val, cval_const_var,
                                              loc)
                init_block.body.append(cval_const_assign)

                # do setitem on `out` array
                setitemexpr = ir.StaticSetItem(out_arr, slice(None, None),
                                               slice_inst_var, cval_const_var,
                                               loc)
                init_block.body.append(setitemexpr)
                sig = signature(types.none, self.typemap[out_arr.name],
                                self.typemap[slice_inst_var.name],
                                self.typemap[out_arr.name].dtype)
                self.calltypes[setitemexpr] = sig

        self.replace_return_with_setitem(stencil_blocks, exit_value_var,
                                         parfor_body_exit_label)

        if config.DEBUG_ARRAY_OPT >= 1:
            print("stencil_blocks after replacing return")
            ir_utils.dump_blocks(stencil_blocks)

        setitem_call = ir.SetItem(out_arr, parfor_ind_var, exit_value_var, loc)
        self.calltypes[setitem_call] = signature(
            types.none, self.typemap[out_arr.name],
            self.typemap[parfor_ind_var.name],
            self.typemap[out_arr.name].dtype)
        stencil_blocks[parfor_body_exit_label].body.extend(for_replacing_ret)
        stencil_blocks[parfor_body_exit_label].body.append(setitem_call)

        # simplify the CFG of the parfor body (the exit block can often be merged)
        # add dummy return to enable CFG
        dummy_loc = ir.Loc("stencilparfor_dummy", -1)
        ret_const_var = ir.Var(scope, mk_unique_var("$cval_const"), dummy_loc)
        cval_const_assign = ir.Assign(ir.Const(0, loc=dummy_loc),
                                      ret_const_var, dummy_loc)
        stencil_blocks[parfor_body_exit_label].body.append(cval_const_assign)

        stencil_blocks[parfor_body_exit_label].body.append(
            ir.Return(ret_const_var, dummy_loc), )
        stencil_blocks = ir_utils.simplify_CFG(stencil_blocks)
        stencil_blocks[max(stencil_blocks.keys())].body.pop()

        if config.DEBUG_ARRAY_OPT >= 1:
            print("stencil_blocks after adding SetItem")
            ir_utils.dump_blocks(stencil_blocks)

        pattern = ('stencil', [start_lengths, end_lengths])
        parfor = numba.parfors.parfor.Parfor(loopnests, init_block,
                                             stencil_blocks, loc,
                                             parfor_ind_var, equiv_set,
                                             pattern, self.flags)
        gen_nodes.append(parfor)
        gen_nodes.append(ir.Assign(out_arr, target, loc))
        return gen_nodes
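For context, this is the machinery that turns a @stencil kernel like the following (illustrative) into a parfor: each output element is computed from a fixed neighborhood of the input, and cval fills the positions that fall outside the border:

import numpy as np
from numba import njit, stencil


@stencil(cval=0.0)
def kernel(a):
    # relative indexing into the neighborhood of each element
    return 0.25 * (a[0, -1] + a[0, 1] + a[-1, 0] + a[1, 0])


@njit(parallel=True)
def smooth(a):
    return kernel(a)


smooth(np.ones((8, 8)))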