コード例 #1
0
    def run_pass(self, state):
        """
        Create type annotation after type inference.

        ``ir.Del`` statements are temporarily inserted into ``state.func_ir``
        and a copy of that IR is used to build a ``TypeAnnotation``, which is
        stored on ``state.type_annotation``. When ``config.ANNOTATE`` is set
        the annotation is printed; when ``config.HTML`` is set it is written
        as HTML to the path held in ``config.HTML``. Finally the ``ir.Del``
        statements are stripped again.

        The stripping runs in a ``finally`` clause so that a failure while
        annotating (e.g. the HTML output file cannot be opened) does not
        leave the temporary Dels behind in ``state.func_ir``.

        Always returns False.
        """
        # add back in dels; the annotation is built from a copy of the IR
        # taken while they are present.
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run(emit_dels=True)
        try:
            state.type_annotation = type_annotations.TypeAnnotation(
                func_ir=state.func_ir.copy(),
                typemap=state.typemap,
                calltypes=state.calltypes,
                lifted=state.lifted,
                lifted_from=state.lifted_from,
                args=state.args,
                return_type=state.return_type,
                html_output=config.HTML)

            if config.ANNOTATE:
                print("ANNOTATION".center(80, '-'))
                print(state.type_annotation)
                print('=' * 80)
            if config.HTML:
                with open(config.HTML, 'w') as fout:
                    state.type_annotation.html_annotate(fout)
        finally:
            # now remove dels, even if annotating above raised
            post_proc.remove_dels()
        return False
コード例 #2
0
    def run_pass(self, state):
        """
        Perform any intermediate representation rewrites after type
        inference.

        Applies every rewrite registered under the ``"after-inference"``
        kind in ``rewrites.rewrite_registry`` to ``state``. The rewrites run
        inside ``fallback_context`` with a message naming the function being
        compiled.

        ``ir.Del`` statements are inserted before the rewrites run and
        removed afterwards. The removal is in a ``finally`` clause so the IR
        is not left holding Dels when a rewrite raises.

        Always returns True.
        """
        # a bunch of these passes are either making assumptions or rely on some
        # very picky and slightly bizarre state particularly in relation to
        # ir.Del presence. To accommodate, ir.Dels are added ahead of running
        # this pass and stripped at the end.

        # Ensure we have an IR and type information.
        assert state.func_ir
        assert isinstance(getattr(state, "typemap", None), dict)
        assert isinstance(getattr(state, "calltypes", None), dict)
        msg = (
            "Internal error in post-inference rewriting "
            "pass encountered during compilation of "
            'function "%s"'
            % (state.func_id.func_name,)
        )

        pp = postproc.PostProcessor(state.func_ir)
        pp.run(emit_dels=True)
        try:
            with fallback_context(state, msg):
                rewrites.rewrite_registry.apply("after-inference", state)
        finally:
            # Strip the temporary Dels even if a rewrite raised.
            pp.remove_dels()
        return True
コード例 #3
0
def with_lifting(func_ir, typingctx, targetctx, flags, locals):
    """With-lifting transformation.

    Rewrites the IR so that ``with`` regions are extracted. Only the
    top-level withs are extracted.

    Returns a ``(new_ir, sub_irs)`` pair. When *func_ir* contains no with
    regions the result is ``(func_ir, [])``. Otherwise ``sub_irs`` holds, in
    region order, whatever each context-manager kind's ``mutate_with_body``
    returned, and ``new_ir`` is derived from the mutated block map.
    """
    from numba.core import postproc

    def dispatcher_factory(func_ir, objectmode=False, **kwargs):
        # Build the dispatcher for one lifted region. The dispatcher classes
        # are imported at call time rather than at module level.
        from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith

        lifted_flags = flags.copy()
        if not objectmode:
            return LiftedWith(func_ir, typingctx, targetctx, lifted_flags,
                              locals, **kwargs)
        # An object-mode lifted region cannot loop-lift and must be
        # compiled as a (forced) pyobject function.
        lifted_flags.enable_looplift = False
        lifted_flags.enable_pyobject = True
        lifted_flags.force_pyobject = True
        lifted_flags.no_cpython_wrapper = False
        return ObjModeLiftedWith(func_ir, typingctx, targetctx, lifted_flags,
                                 locals, **kwargs)

    # Locate the with-context regions; nothing to do when there are none.
    regions = find_setupwiths(func_ir.blocks)
    if not regions:
        return func_ir, []

    # Make sure variable-lifetime information (and its CFG) is available.
    postproc.PostProcessor(func_ir).run()
    assert func_ir.variable_lifetime
    lifetime = func_ir.variable_lifetime
    new_blocks = func_ir.blocks.copy()
    cfg = lifetime.cfg
    _legalize_withs_cfg(regions, cfg, new_blocks)

    # Mutate each region according to the kind of its context manager.
    sub_irs = []
    for head, tail in regions:
        region_nodes = list(_cfg_nodes_in_region(cfg, head, tail))
        _legalize_with_head(new_blocks[head])
        cm_kind, cm_extra = _get_with_contextmanager(func_ir, new_blocks,
                                                     head)
        sub_irs.append(
            cm_kind.mutate_with_body(func_ir, new_blocks, head, tail,
                                     region_nodes, dispatcher_factory,
                                     cm_extra)
        )

    # Only derive a new IR when something was actually lifted.
    new_ir = func_ir.derive(new_blocks) if sub_irs else func_ir
    return new_ir, sub_irs
コード例 #4
0
 def run_pass(self, state):
     """
     Run array analysis over ``state.func_ir`` and snapshot the result.

     The ``ArrayAnalysis`` instance is stored on ``state.array_analysis``
     and the IR post-processor is re-run. A copy of the resulting IR is then
     appended to ``state.func_ir_copies``. If ``state.test_idempotence`` is
     set and at least two copies exist, the list of copies is passed to it.

     Always returns False.
     """
     analysis = ArrayAnalysis(state.typingctx, state.func_ir,
                              state.typemap, state.calltypes)
     state.array_analysis = analysis
     analysis.run(state.func_ir.blocks)

     postproc.PostProcessor(state.func_ir).run()

     snapshots = state.func_ir_copies
     snapshots.append(state.func_ir.copy())
     checker = state.test_idempotence
     if checker and len(snapshots) > 1:
         checker(snapshots)
     return False
コード例 #5
0
 def run_pass(self, state):
     """
     Inline the first call found in a single-block function, then fix up IR.

     ``state.func_ir`` is asserted to contain exactly one block. Its
     statements are scanned for the first one whose ``value`` is recognised
     by ``find_callname`` (under ``guard``), and that statement is inlined
     with ``inline_closure_call``. The IR post-processor is then run and
     ``ir.Del`` statements are removed.

     Statements without a ``value`` attribute are skipped rather than
     raising ``AttributeError``.

     Always returns True.
     """
     # assuming the function has one block with one call inside
     assert len(state.func_ir.blocks) == 1
     block = next(iter(state.func_ir.blocks.values()))
     for i, stmt in enumerate(block.body):
         # Not every IR statement necessarily has a ``value``; the attribute
         # lookup happens outside ``guard``, so skip such statements.
         if not hasattr(stmt, 'value'):
             continue
         if guard(find_callname, state.func_ir, stmt.value) is not None:
             inline_closure_call(state.func_ir, {}, block, i, lambda: None,
                                 state.typingctx, (),
                                 state.type_annotation.typemap,
                                 state.type_annotation.calltypes)
             break
     # also fix up the IR
     post_proc = postproc.PostProcessor(state.func_ir)
     post_proc.run()
     post_proc.remove_dels()
     return True
コード例 #6
0
ファイル: compiler.py プロジェクト: zsoltc89/numba
def run_frontend(func, inline_closures=False):
    """
    Translate *func* to Numba IR using only the compiler frontend.

    The function's bytecode is interpreted into IR. When *inline_closures*
    is truthy, ``InlineClosureCallPass`` is also run over that IR. The IR
    post-processor runs last, and the resulting IR is returned.
    """
    # XXX make this a dedicated Pipeline?
    identity = bytecode.FunctionIdentity.from_function(func)
    translator = interpreter.Interpreter(identity)
    code = bytecode.ByteCode(func_id=identity)
    ir_func = translator.interpret(code)

    if inline_closures:
        closure_inliner = InlineClosureCallPass(
            ir_func, cpu.ParallelOptions(False), {}, False,
        )
        closure_inliner.run()

    postproc.PostProcessor(ir_func).run()
    return ir_func
コード例 #7
0
ファイル: registry.py プロジェクト: ArtShp/DataScience
    def apply(self, kind, state):
        '''Exhaustively apply every rewrite registered for *kind* to all
        basic blocks of ``state.func_ir``.

        For each rewrite class registered under *kind*, an instance is built
        from *state* and matched against every block. Whenever it matches,
        the block is replaced by the rewrite's output, and that output is
        queued to be matched again. A rewrite therefore keeps applying until
        it stops matching.

        Afterwards, every block that is new or differs from its pre-rewrite
        version is verified, and the IR post-processor is re-run on
        ``state.func_ir``.

        *kind* must be one of the registry's known kinds (asserted).
        '''
        assert kind in self._kinds
        blocks = state.func_ir.blocks
        old_blocks = blocks.copy()
        for rewrite_cls in self.rewrites[kind]:
            # Exhaustively apply a rewrite until it stops matching.
            rewrite = rewrite_cls(state)
            work_list = list(blocks.items())
            while work_list:
                key, block = work_list.pop()
                matches = rewrite.match(state.func_ir, block, state.typemap,
                                        state.calltypes)
                if matches:
                    if config.DEBUG or config.DUMP_IR:
                        print("_" * 70)
                        print("REWRITING (%s):" % rewrite_cls.__name__)
                        block.dump()
                        print("_" * 60)
                    new_block = rewrite.apply()
                    blocks[key] = new_block
                    work_list.append((key, new_block))
                    if config.DEBUG or config.DUMP_IR:
                        new_block.dump()
                        print("_" * 70)
        # If any blocks were changed, perform a sanity check. A label that
        # did not exist before the rewrites ran (a rewrite is handed the
        # whole ``state``) is treated as changed rather than raising KeyError.
        for key, block in blocks.items():
            old_block = old_blocks.get(key)
            if old_block is None or block != old_block:
                block.verify()

        # Some passes, e.g. _inline_const_arraycall are known to occasionally
        # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
        # things with ir.Del, but the placement is not optimal. The lines below
        # fix-up the IR so that ref counts are valid and optimally placed,
        # see #4093 for context. This has to be run here, as opposed to in
        # each individual rewrite's apply(), because computing the CFG needs
        # the full IR rather than a single block.
        from numba.core import postproc
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run()
コード例 #8
0
    def run_pass(self, state):
        """
        Strip phi nodes from ``state.func_ir`` and refresh derived metadata.

        The IR returned by ``self._strip_phi_nodes`` replaces
        ``state.func_ir``. Its definitions map is rebuilt, and the IR
        post-processor is re-run without emitting ``ir.Del`` statements.

        For a generator with type information available (``state.typemap``
        is not None), ``state.return_type`` is rebuilt as a
        ``types.Generator``. It keeps the previous generator type's
        attributes, but its ``state_types`` are looked up from the current
        generator state variables.

        Always returns True.
        """
        func_ir = self._strip_phi_nodes(state.func_ir)
        state.func_ir = func_ir
        func_ir._definitions = build_definitions(func_ir.blocks)
        # Rerun postprocessor to update metadata
        postproc.PostProcessor(func_ir).run(emit_dels=False)

        # Skip non-generators and object-mode generators (the latter are
        # recognised here by the absence of a typemap).
        is_typed_generator = (func_ir.generator_info is not None
                              and state.typemap is not None)
        if not is_typed_generator:
            return True

        # Rebuild generator type
        # TODO: move this into PostProcessor
        old_gentype = state.return_type
        typemap = state.typemap
        new_state_types = [typemap[var]
                           for var in func_ir.generator_info.state_vars]
        state.return_type = types.Generator(
            gen_func=old_gentype.gen_func,
            yield_type=old_gentype.yield_type,
            arg_types=old_gentype.arg_types,
            state_types=new_state_types,
            has_finalizer=old_gentype.has_finalizer,
        )
        return True
コード例 #9
0
ファイル: test_analysis.py プロジェクト: zhaijf1992/numba
    def assert_prune(self, func, args_tys, prune, *args, **kwargs):
        """Check that dead-branch pruning removes exactly the expected branches.

        func: the Python function to assess.
        args_tys: the tuple of Numba types for the arguments.
        prune: a list with one entry per branch found in the IR before
            pruning. Each entry is encoded as follows:
              True   -- using constant inference only, the True branch will
                        be pruned;
              False  -- using constant inference only, the False branch will
                        be pruned;
              None   -- under no circumstances should this branch be pruned;
              'both' -- both the True and the False branch will be pruned.
        *args: the argument instances to pass to the function to check that
            execution is still valid post transform.
        **kwargs:
            flags: a ``compiler.Flags`` instance to pass to
                ``compile_isolated``; permits use of e.g. object mode.

        After checking the pruned labels, the function is compiled and its
        result on ``*args`` is compared against the pure-Python result.
        """
        func_ir = compile_to_ir(func)
        before = func_ir.copy()
        if self._DEBUG:
            print("=" * 80)
            print("before inline")
            func_ir.dump()

        # run closure inlining to ensure that nonlocals in closures are visible
        inline_pass = InlineClosureCallPass(
            func_ir,
            cpu.ParallelOptions(False),
        )
        inline_pass.run()

        # Re-run postproc after inlining.
        # NOTE(review): the original intent here was also to remove all
        # ir.Dels; confirm that PostProcessor.run() with defaults does so.
        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()

        rewrite_semantic_constants(func_ir, args_tys)
        if self._DEBUG:
            print("=" * 80)
            print("before prune")
            func_ir.dump()

        dead_branch_prune(func_ir, args_tys)

        after = func_ir
        if self._DEBUG:
            print("after prune")
            func_ir.dump()

        before_branches = self.find_branches(before)
        self.assertEqual(len(before_branches), len(prune))

        # what is expected to be pruned; the loop variable is deliberately
        # not named ``prune`` so it does not shadow the parameter.
        expect_removed = []
        for branch, expected_prune in zip(before_branches, prune):
            if expected_prune is True:
                expect_removed.append(branch.truebr)
            elif expected_prune is False:
                expect_removed.append(branch.falsebr)
            elif expected_prune is None:
                pass  # nothing should be removed!
            elif expected_prune == 'both':
                expect_removed.append(branch.falsebr)
                expect_removed.append(branch.truebr)
            else:
                # An explicit failure (unlike ``assert``) survives ``-O``.
                self.fail("invalid prune encoding: %r" % (expected_prune,))

        # compare labels
        original_labels = set(before.blocks)
        new_labels = set(after.blocks)
        # assert that the new labels are precisely the original less the
        # expected pruned labels
        try:
            self.assertEqual(new_labels, original_labels - set(expect_removed))
        except AssertionError:
            print("new_labels", sorted(new_labels))
            print("original_labels", sorted(original_labels))
            print("expect_removed", sorted(expect_removed))
            raise

        supplied_flags = kwargs.pop('flags', False)
        compiler_kws = {'flags': supplied_flags} if supplied_flags else {}
        cres = compile_isolated(func, args_tys, **compiler_kws)
        # ``args`` is the (possibly empty) *args tuple, never None, so this
        # also covers the zero-argument case.
        res = cres.entry_point(*args)
        expected = func(*args)
        self.assertEqual(res, expected)
コード例 #10
0
 def run_pass(self, state):
     """
     Insert ``ir.Del`` statements into ``state.func_ir``.

     Runs the IR post-processor with ``emit_dels=True``; always returns True.
     """
     postproc.PostProcessor(state.func_ir).run(emit_dels=True)
     return True