def run_pass(self, state):
    """
    Create type annotation after type inference
    """
    # Annotation needs ir.Del nodes present; emit them before annotating.
    pp = postproc.PostProcessor(state.func_ir)
    pp.run(emit_dels=True)
    state.type_annotation = type_annotations.TypeAnnotation(
        func_ir=state.func_ir.copy(),
        typemap=state.typemap,
        calltypes=state.calltypes,
        lifted=state.lifted,
        lifted_from=state.lifted_from,
        args=state.args,
        return_type=state.return_type,
        html_output=config.HTML,
    )
    if config.ANNOTATE:
        print("ANNOTATION".center(80, '-'))
        print(state.type_annotation)
        print('=' * 80)
    if config.HTML:
        with open(config.HTML, 'w') as fout:
            state.type_annotation.html_annotate(fout)
    # Strip the dels again now that the annotation has been produced.
    pp.remove_dels()
    return False
def run_pass(self, state):
    """
    Perform any intermediate representation rewrites after type inference.
    """
    # Several registered rewrites make assumptions about, or are picky
    # regarding, the presence of ir.Del nodes. To accommodate them, dels
    # are emitted before this pass runs and stripped again at the end.
    assert state.func_ir
    assert isinstance(getattr(state, "typemap", None), dict)
    assert isinstance(getattr(state, "calltypes", None), dict)
    msg = ("Internal error in post-inference rewriting "
           "pass encountered during compilation of "
           'function "%s"' % (state.func_id.func_name,))
    post_processor = postproc.PostProcessor(state.func_ir)
    post_processor.run(True)
    with fallback_context(state, msg):
        rewrites.rewrite_registry.apply("after-inference", state)
    post_processor.remove_dels()
    return True
def with_lifting(func_ir, typingctx, targetctx, flags, locals):
    """With-lifting transformation

    Rewrite the IR to extract all withs.
    Only the top-level withs are extracted.
    Returns the (the_new_ir, the_lifted_with_ir)
    """
    from numba.core import postproc

    def dispatcher_factory(func_ir, objectmode=False, **kwargs):
        # Build either a LiftedWith or an ObjModeLiftedWith dispatcher
        # for an extracted with-region.
        from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith

        myflags = flags.copy()
        if objectmode:
            # Lifted with-block cannot looplift
            myflags.enable_looplift = False
            # Lifted with-block uses object mode
            myflags.enable_pyobject = True
            myflags.force_pyobject = True
            myflags.no_cpython_wrapper = False
            cls = ObjModeLiftedWith
        else:
            cls = LiftedWith
        return cls(func_ir, typingctx, targetctx, myflags, locals, **kwargs)

    # Locate the with-context regions; nothing to do when there are none.
    withs = find_setupwiths(func_ir.blocks)
    if not withs:
        return func_ir, []
    # Ensure variable lifetime info is available on the IR.
    postproc.PostProcessor(func_ir).run()
    assert func_ir.variable_lifetime
    vlt = func_ir.variable_lifetime
    blocks = func_ir.blocks.copy()
    cfg = vlt.cfg
    _legalize_withs_cfg(withs, cfg, blocks)
    # Mutate each with-region according to the kind of contextmanager.
    sub_irs = []
    for blk_start, blk_end in withs:
        body_blocks = list(_cfg_nodes_in_region(cfg, blk_start, blk_end))
        _legalize_with_head(blocks[blk_start])
        # Find the contextmanager
        cmkind, extra = _get_with_contextmanager(func_ir, blocks, blk_start)
        # Mutate the body and get new IR
        sub = cmkind.mutate_with_body(func_ir, blocks, blk_start, blk_end,
                                      body_blocks, dispatcher_factory,
                                      extra)
        sub_irs.append(sub)
    # Unchanged when no region was mutated; otherwise derive the new IR.
    new_ir = func_ir if not sub_irs else func_ir.derive(blocks)
    return new_ir, sub_irs
def run_pass(self, state):
    # Run array analysis over the IR, normalise the IR afterwards, and
    # keep a copy so idempotence can be checked across repeated runs.
    state.array_analysis = ArrayAnalysis(state.typingctx,
                                         state.func_ir,
                                         state.typemap,
                                         state.calltypes)
    state.array_analysis.run(state.func_ir.blocks)
    pp = postproc.PostProcessor(state.func_ir)
    pp.run()
    state.func_ir_copies.append(state.func_ir.copy())
    if state.test_idempotence and len(state.func_ir_copies) > 1:
        state.test_idempotence(state.func_ir_copies)
    return False
def run_pass(self, state):
    # The function is assumed to consist of exactly one block containing
    # a single call; find that call and inline it.
    assert len(state.func_ir.blocks) == 1
    block, = state.func_ir.blocks.values()
    for idx, stmt in enumerate(block.body):
        if guard(find_callname, state.func_ir, stmt.value) is not None:
            inline_closure_call(state.func_ir, {}, block, idx,
                                lambda: None, state.typingctx, (),
                                state.type_annotation.typemap,
                                state.type_annotation.calltypes)
            break
    # also fix up the IR
    pp = postproc.PostProcessor(state.func_ir)
    pp.run()
    pp.remove_dels()
    return True
def run_frontend(func, inline_closures=False):
    """
    Run the compiler frontend over the given Python function, and return
    the function's canonical Numba IR.

    If inline_closures is Truthy then closure inlining will be run
    """
    # XXX make this a dedicated Pipeline?
    func_id = bytecode.FunctionIdentity.from_function(func)
    interp = interpreter.Interpreter(func_id)
    func_ir = interp.interpret(bytecode.ByteCode(func_id=func_id))
    if inline_closures:
        InlineClosureCallPass(func_ir, cpu.ParallelOptions(False),
                              {}, False).run()
    # Normalise the IR (variable lifetime etc.) before handing it back.
    postproc.PostProcessor(func_ir).run()
    return func_ir
def apply(self, kind, state):
    '''Given a pipeline and a dictionary of basic blocks, exhaustively
    attempt to apply all registered rewrites to all basic blocks.
    '''
    assert kind in self._kinds
    blocks = state.func_ir.blocks
    old_blocks = blocks.copy()
    for rewrite_cls in self.rewrites[kind]:
        # Exhaustively apply a rewrite until it stops matching.
        rewrite = rewrite_cls(state)
        work_list = list(blocks.items())
        while work_list:
            key, block = work_list.pop()
            matches = rewrite.match(state.func_ir, block, state.typemap,
                                    state.calltypes)
            if not matches:
                continue
            if config.DEBUG or config.DUMP_IR:
                print("_" * 70)
                print("REWRITING (%s):" % rewrite_cls.__name__)
                block.dump()
                print("_" * 60)
            new_block = rewrite.apply()
            blocks[key] = new_block
            # Re-queue the rewritten block so the rewrite can match again.
            work_list.append((key, new_block))
            if config.DEBUG or config.DUMP_IR:
                new_block.dump()
                print("_" * 70)
    # If any blocks were changed, perform a sanity check.
    for key, block in blocks.items():
        if block != old_blocks[key]:
            block.verify()
    # Some passes, e.g. _inline_const_arraycall are known to occasionally
    # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
    # things with ir.Del, but the placement is not optimal. The lines below
    # fix-up the IR so that ref counts are valid and optimally placed,
    # see #4093 for context. This has to be run here opposed to in
    # apply() as the CFG needs computing so full IR is needed.
    from numba.core import postproc
    postproc.PostProcessor(state.func_ir).run()
def run_pass(self, state):
    # Strip SSA phi nodes and rebuild the definitions map accordingly.
    state.func_ir = self._strip_phi_nodes(state.func_ir)
    state.func_ir._definitions = build_definitions(state.func_ir.blocks)
    # Rerun postprocessor to update metadata
    postproc.PostProcessor(state.func_ir).run(emit_dels=False)
    # Ensure we are not in objectmode generator
    gen_info = state.func_ir.generator_info
    if gen_info is not None and state.typemap is not None:
        # Rebuild generator type
        # TODO: move this into PostProcessor
        old_gen_type = state.return_type
        state_types = [state.typemap[k] for k in gen_info.state_vars]
        state.return_type = types.Generator(
            gen_func=old_gen_type.gen_func,
            yield_type=old_gen_type.yield_type,
            arg_types=old_gen_type.arg_types,
            state_types=state_types,
            has_finalizer=old_gen_type.has_finalizer,
        )
    return True
def assert_prune(self, func, args_tys, prune, *args, **kwargs):
    """Check that the expected branches have indeed been pruned and that
    the transformed function still executes correctly.

    Parameters
    ----------
    func : the python function to assess
    args_tys : the numba types arguments tuple
    prune : list, one entry per branch. The value in the entry is
        encoded as follows:
        True  - using constant inference only, the True branch will be
                pruned
        False - using constant inference only, the False branch will be
                pruned
        'both' - both branches will be pruned
        None  - under no circumstances should this branch be pruned
    *args : the argument instances to pass to the function to check
        execution is still valid post transform
    **kwargs :
        - flags: compiler.Flags instance to pass to `compile_isolated`,
          permits use of e.g. object mode
    """
    func_ir = compile_to_ir(func)
    before = func_ir.copy()
    if self._DEBUG:
        print("=" * 80)
        print("before inline")
        func_ir.dump()

    # run closure inlining to ensure that nonlocals in closures are visible
    inline_pass = InlineClosureCallPass(func_ir,
                                        cpu.ParallelOptions(False),)
    inline_pass.run()
    # Remove all Dels, and re-run postproc
    post_proc = postproc.PostProcessor(func_ir)
    post_proc.run()

    rewrite_semantic_constants(func_ir, args_tys)
    if self._DEBUG:
        print("=" * 80)
        print("before prune")
        func_ir.dump()

    dead_branch_prune(func_ir, args_tys)

    after = func_ir
    if self._DEBUG:
        print("after prune")
        func_ir.dump()

    before_branches = self.find_branches(before)
    self.assertEqual(len(before_branches), len(prune))

    # what is expected to be pruned
    # BUGFIX: the loop variable used to be named `prune`, shadowing the
    # `prune` parameter; a distinct name avoids the shadowing.
    expect_removed = []
    for idx, prune_spec in enumerate(prune):
        branch = before_branches[idx]
        if prune_spec is True:
            expect_removed.append(branch.truebr)
        elif prune_spec is False:
            expect_removed.append(branch.falsebr)
        elif prune_spec is None:
            pass  # nothing should be removed!
        elif prune_spec == 'both':
            expect_removed.append(branch.falsebr)
            expect_removed.append(branch.truebr)
        else:
            assert 0, "unreachable"

    # compare labels
    original_labels = set(before.blocks.keys())
    new_labels = set(after.blocks.keys())
    # assert that the new labels are precisely the original less the
    # expected pruned labels
    try:
        self.assertEqual(new_labels, original_labels - set(expect_removed))
    except AssertionError as e:
        print("new_labels", sorted(new_labels))
        print("original_labels", sorted(original_labels))
        print("expect_removed", sorted(expect_removed))
        raise e

    supplied_flags = kwargs.pop('flags', False)
    compiler_kws = {'flags': supplied_flags} if supplied_flags else {}
    cres = compile_isolated(func, args_tys, **compiler_kws)
    # BUGFIX: the original tested `args is None`, which can never be true
    # because `*args` always binds a tuple. The emptiness check below is
    # behaviorally identical (splatting an empty tuple equals a no-arg
    # call) but actually exercises the intended branch.
    if not args:
        res = cres.entry_point()
        expected = func()
    else:
        res = cres.entry_point(*args)
        expected = func(*args)
    self.assertEqual(res, expected)
def run_pass(self, state):
    # Re-run the post-processor over the IR, emitting ir.Del nodes.
    postproc.PostProcessor(state.func_ir).run(emit_dels=True)
    return True