def test_replace_exitswitch_by_constant_bug():
    """Run a fixed backend-optimisation pipeline over a small switch function.

    The test contains no assertions: it passes as long as annotating,
    rtyping and transforming ``fn`` does not raise.  Judging by its name it
    is a regression test for a bug hit when an exitswitch is replaced by a
    constant -- NOTE(review): the original bug report is not visible here.
    """
    class X:
        pass
    def constant9():
        # Goes through an instance attribute, overwritten before being
        # read, so the returned value (9) comes from a field read rather
        # than from a literal.
        x = X()
        x.n = 3
        x.n = 9
        return x.n
    def fn():
        n = constant9()
        # An if/elif chain comparing the same variable against several
        # constants; the final branch returns n itself.
        if n == 1: return 5
        elif n == 2: return 6
        elif n == 3: return 8
        elif n == 4: return -123
        elif n == 5: return 12973
        else: return n

    # Annotate and rtype fn (which takes no arguments).
    t = TranslationContext()
    a = t.buildannotator()
    a.build_types(fn, [])
    rtyper = t.buildrtyper()
    rtyper.specialize()
    # presumably t.graphs[0] is the graph of fn -- TODO confirm
    graph = t.graphs[0]
    remove_same_as(graph)
    merge_if_blocks_once(graph)
    # Inline (threshold 20) and remove mallocs over the whole translator,
    # then join blocks in the graph picked above.
    from rpython.translator.backendopt import malloc, inline
    inline.auto_inlining(t, 20)
    malloc.remove_mallocs(t, t.graphs)
    from rpython.translator import simplify
    simplify.join_blocks(graph)
def check_auto_inlining(
    self,
    func,
    sig,
    multiplier=None,
    call_count_check=False,
    remove_same_as=False,
    heuristic=None,
    const_fold_first=False,
):
    """Translate ``func``, run auto-inlining on it and return a runner.

    ``multiplier`` scales INLINE_THRESHOLD_FOR_TEST when it is not None.
    ``call_count_check`` instruments the inline candidates and passes an
    always-true ``call_count_pred`` to the inliner.  ``remove_same_as``
    runs removenoops.remove_same_as on every graph before inlining.
    ``heuristic`` is forwarded to auto_inlining only when it is not None.
    ``const_fold_first`` constant-folds every graph and eliminates empty
    blocks before anything else.

    Returns ``(eval_func, t)``: ``eval_func(args)`` evaluates the graph of
    ``func`` with an LLInterpreter, ``t`` is the translator.
    """
    translator = self.translate(func, sig)

    if const_fold_first:
        from rpython.translator.backendopt.constfold import constant_fold_graph
        from rpython.translator.simplify import eliminate_empty_blocks

        for g in translator.graphs:
            constant_fold_graph(g)
            eliminate_empty_blocks(g)

    if option.view:
        translator.view()

    # Check the graphs before inlining too, so that pre-existing breakage
    # is not blamed on the inliner.
    sanity_check(translator)

    limit = INLINE_THRESHOLD_FOR_TEST
    if multiplier is not None:
        limit *= multiplier

    if call_count_check:
        call_count_pred = lambda lbl: True
        instrument_inline_candidates(translator.graphs, limit)
    else:
        call_count_pred = None

    if remove_same_as:
        for g in translator.graphs:
            removenoops.remove_same_as(g)

    # Only mention the heuristic when the caller supplied one.
    extra = {}
    if heuristic is not None:
        extra["heuristic"] = heuristic
    auto_inlining(translator, limit, call_count_pred=call_count_pred, **extra)

    # inline! done -- re-check and optionally show the result
    sanity_check(translator)
    if option.view:
        translator.view()

    interp = LLInterpreter(translator.rtyper)

    def eval_func(args):
        return interp.eval_graph(graphof(translator, func), args)

    return eval_func, translator
def check_auto_inlining(self, func, sig, multiplier=None,
                        call_count_check=False, remove_same_as=False,
                        heuristic=None, const_fold_first=False):
    """Translate ``func``, auto-inline its graphs, and hand back a runner.

    The inlining threshold is INLINE_THRESHOLD_FOR_TEST, multiplied by
    ``multiplier`` when one is given.  Optional steps, each enabled by its
    flag: constant folding plus empty-block elimination before anything
    else (``const_fold_first``), instrumenting the inline candidates with
    an always-true call-count predicate (``call_count_check``), and
    running removenoops.remove_same_as over every graph
    (``remove_same_as``).  ``heuristic`` is passed on to auto_inlining
    only when it is not None.

    Returns a pair ``(eval_func, t)`` where ``eval_func(args)`` evaluates
    the graph of ``func`` through an LLInterpreter.
    """
    trans = self.translate(func, sig)
    show = option.view

    if const_fold_first:
        from rpython.translator.backendopt.constfold import constant_fold_graph
        from rpython.translator.simplify import eliminate_empty_blocks
        for one_graph in trans.graphs:
            constant_fold_graph(one_graph)
            eliminate_empty_blocks(one_graph)
    if show:
        trans.view()

    # sanity-check before inlining as well (so we don't blame the inliner)
    sanity_check(trans)

    cutoff = INLINE_THRESHOLD_FOR_TEST
    if multiplier is not None:
        cutoff *= multiplier

    predicate = None
    if call_count_check:
        predicate = lambda lbl: True
        instrument_inline_candidates(trans.graphs, cutoff)

    if remove_same_as:
        for one_graph in trans.graphs:
            removenoops.remove_same_as(one_graph)

    options = {"heuristic": heuristic} if heuristic is not None else {}
    auto_inlining(trans, cutoff, call_count_pred=predicate, **options)

    sanity_check(trans)
    if option.view:
        trans.view()

    llinterp = LLInterpreter(trans.rtyper)

    def eval_func(args):
        return llinterp.eval_graph(graphof(trans, func), args)

    return eval_func, trans
def inline_and_remove(t, graphs, threshold=BIG_THRESHOLD,
                      heuristic=inline.inlining_heuristic):
    """Inline around malloc-removal candidates, then remove the mallocs.

    Returns False when no callgraph of candidates was found or when the
    inliner reports that nothing was inlined; otherwise returns whatever
    ``remove_mallocs`` returns for the candidate callers.
    """
    callgraph, caller_candidates = find_malloc_removal_candidates(t, graphs)
    n_candidates = len(caller_candidates)
    log.inlineandremove("found %s malloc removal candidates" % n_candidates)

    # Nothing to inline into: bail out early.
    if not callgraph:
        return False

    inlined = inline.auto_inlining(
        t,
        callgraph=callgraph,
        threshold=threshold,
        heuristic=heuristic,
    )
    if not inlined:
        return False

    log.inlineandremove("inlined %d callsites." % (inlined,))
    return remove_mallocs(t, caller_candidates.keys())
def inline_inlineable_portals(self):
    """
    Find all the graphs which have been decorated with @jitdriver.inline
    and inline them in the callers, making them JIT portals. Then, create
    a fresh copy of the jitdriver for each of those new portals, because
    they cannot share the same one.  See
    test_ajit::test_inline_jit_merge_point

    Works in place on ``self.translator.graphs`` and returns nothing.
    """
    from rpython.translator.backendopt.inline import (
        inlinable_static_callers, auto_inlining)

    # graph of an @inline-decorated function -> (fished op, jmp_graph)
    jmp_calls = {}

    def get_jmp_call(graph, _inline_jit_merge_point_):
        # there might be multiple calls to the @inlined function: the
        # first time we see it, we remove the call to the jit_merge_point
        # and we remember the corresponding op. Then, we create a new call
        # to it every time we need a new one (i.e., for each callsite
        # which becomes a new portal)
        try:
            op, jmp_graph = jmp_calls[graph]
        except KeyError:
            op, jmp_graph = fish_jmp_call(graph, _inline_jit_merge_point_)
            jmp_calls[graph] = op, jmp_graph
        #
        # clone the op: same opname and a copy of the args, but a fresh
        # result variable of the same concretetype
        newargs = op.args[:]
        newresult = Variable()
        newresult.concretetype = op.result.concretetype
        op = SpaceOperation(op.opname, newargs, newresult)
        return op, jmp_graph

    def fish_jmp_call(graph, _inline_jit_merge_point_):
        # graph is function which has been decorated with
        # @jitdriver.inline, so its very first op is a call to the
        # function which contains the actual jit_merge_point: fish it!
        #
        # Use the 'graph' argument here, not the enclosing loop variable
        # 'callee': the two only coincide because of how the single call
        # site below happens to invoke get_jmp_call().
        jmp_block, op_jmp_call = next(graph.iterblockops())
        msg = ("The first operation of an _inline_jit_merge_point_ graph must be "
               "a direct_call to the function passed to @jitdriver.inline()")
        assert op_jmp_call.opname == 'direct_call', msg
        jmp_funcobj = op_jmp_call.args[0].value._obj
        assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
        # the call is taken out of the decorated graph; clones of it are
        # inserted into the callers instead
        jmp_block.operations.remove(op_jmp_call)
        return op_jmp_call, jmp_funcobj.graph

    # find all the graphs which call an @inline_in_portal function
    callgraph = inlinable_static_callers(self.translator.graphs,
                                         store_calls=True)
    new_callgraph = []
    new_portals = set()
    inlined_jit_merge_points = set()
    for caller, block, op_call, callee in callgraph:
        func = getattr(callee, 'func', None)
        _inline_jit_merge_point_ = getattr(func, '_inline_jit_merge_point_',
                                           None)
        if _inline_jit_merge_point_:
            _inline_jit_merge_point_._always_inline_ = True
            inlined_jit_merge_points.add(_inline_jit_merge_point_)
            op_jmp_call, jmp_graph = get_jmp_call(callee,
                                                  _inline_jit_merge_point_)
            #
            # now we move the op_jmp_call from callee to caller, just
            # before op_call. We assume that the args passed to
            # op_jmp_call are the very same which are received by callee
            # (i.e., the one passed to op_call)
            assert len(op_call.args) == len(op_jmp_call.args)
            op_jmp_call.args[1:] = op_call.args[1:]
            idx = block.operations.index(op_call)
            block.operations.insert(idx, op_jmp_call)
            #
            # finally, we signal that we want to inline op_jmp_call into
            # caller, so that finally the actual call to
            # driver.jit_merge_point will be seen there
            new_callgraph.append((caller, jmp_graph))
            new_portals.add(caller)

    # inline them!
    inline_threshold = 0.1  # we rely on the _always_inline_ set above
    auto_inlining(self.translator, inline_threshold, new_callgraph)
    # clean up _always_inline_ = True, it can explode later
    for item in inlined_jit_merge_points:
        del item._always_inline_

    # make a fresh copy of the JitDriver in all newly created
    # jit_merge_points
    self.clone_inlined_jit_merge_points(new_portals)
def inline_inlineable_portals(self):
    """
    Find all the graphs which have been decorated with @jitdriver.inline
    and inline them in the callers, making them JIT portals. Then, create
    a fresh copy of the jitdriver for each of those new portals, because
    they cannot share the same one.  See
    test_ajit::test_inline_jit_merge_point

    The translator's graphs are modified in place; nothing is returned.
    """
    from rpython.translator.backendopt.inline import (
        inlinable_static_callers, auto_inlining)

    # cache: decorated graph -> (original jit_merge_point call op, graph
    # of the function that op calls)
    jmp_calls = {}

    def get_jmp_call(graph, _inline_jit_merge_point_):
        # there might be multiple calls to the @inlined function: the
        # first time we see it, we remove the call to the jit_merge_point
        # and we remember the corresponding op. Then, we create a new call
        # to it every time we need a new one (i.e., for each callsite
        # which becomes a new portal)
        try:
            op, jmp_graph = jmp_calls[graph]
        except KeyError:
            op, jmp_graph = fish_jmp_call(graph, _inline_jit_merge_point_)
            jmp_calls[graph] = op, jmp_graph
        #
        # clone the op, giving the clone its own result variable
        newargs = op.args[:]
        newresult = Variable()
        newresult.concretetype = op.result.concretetype
        op = SpaceOperation(op.opname, newargs, newresult)
        return op, jmp_graph

    def fish_jmp_call(graph, _inline_jit_merge_point_):
        # graph is function which has been decorated with
        # @jitdriver.inline, so its very first op is a call to the
        # function which contains the actual jit_merge_point: fish it!
        #
        # Read the first operation from the 'graph' parameter.  Relying
        # on the outer loop's 'callee' variable instead would silently
        # inspect the wrong graph if this helper were ever called with
        # anything other than the current loop item.
        jmp_block, op_jmp_call = next(graph.iterblockops())
        msg = (
            "The first operation of an _inline_jit_merge_point_ graph must be "
            "a direct_call to the function passed to @jitdriver.inline()")
        assert op_jmp_call.opname == 'direct_call', msg
        jmp_funcobj = op_jmp_call.args[0].value._obj
        assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
        # drop the call from the decorated graph; callers get clones
        jmp_block.operations.remove(op_jmp_call)
        return op_jmp_call, jmp_funcobj.graph

    # find all the graphs which call an @inline_in_portal function
    callgraph = inlinable_static_callers(self.translator.graphs,
                                         store_calls=True)
    new_callgraph = []
    new_portals = set()
    inlined_jit_merge_points = set()
    for caller, block, op_call, callee in callgraph:
        func = getattr(callee, 'func', None)
        _inline_jit_merge_point_ = getattr(func, '_inline_jit_merge_point_',
                                           None)
        if _inline_jit_merge_point_:
            _inline_jit_merge_point_._always_inline_ = True
            inlined_jit_merge_points.add(_inline_jit_merge_point_)
            op_jmp_call, jmp_graph = get_jmp_call(
                callee, _inline_jit_merge_point_)
            #
            # now we move the op_jmp_call from callee to caller, just
            # before op_call. We assume that the args passed to
            # op_jmp_call are the very same which are received by callee
            # (i.e., the one passed to op_call)
            assert len(op_call.args) == len(op_jmp_call.args)
            op_jmp_call.args[1:] = op_call.args[1:]
            idx = block.operations.index(op_call)
            block.operations.insert(idx, op_jmp_call)
            #
            # finally, we signal that we want to inline op_jmp_call into
            # caller, so that finally the actual call to
            # driver.jit_merge_point will be seen there
            new_callgraph.append((caller, jmp_graph))
            new_portals.add(caller)

    # inline them!
    inline_threshold = 0.1  # we rely on the _always_inline_ set above
    auto_inlining(self.translator, inline_threshold, new_callgraph)
    # clean up _always_inline_ = True, it can explode later
    for item in inlined_jit_merge_points:
        del item._always_inline_

    # make a fresh copy of the JitDriver in all newly created
    # jit_merge_points
    self.clone_inlined_jit_merge_points(new_portals)