def test_inactive_jitdriver(self):
    """Check that a JitDriver with ``active = False`` is not traced into.

    loop2 calls loop1 on every iteration; loop1's driver is deactivated.
    After meta-interpreting loop2 with inline=True the test expects loop1
    to stay a residual call: no int_sub (from loop1's ``n -= m``) and one
    call in the loops.
    """
    myjitdriver1 = JitDriver(greens=[], reds=["n", "m"],
                             get_printable_location=getloc1)
    myjitdriver2 = JitDriver(greens=["g"], reds=["r"],
                             get_printable_location=getloc2)
    #
    myjitdriver1.active = False    # <===
    #
    def loop1(n, m):
        while n > 0:
            myjitdriver1.can_enter_jit(n=n, m=m)
            myjitdriver1.jit_merge_point(n=n, m=m)
            n -= m
        return n
    #
    def loop2(g, r):
        while r > 0:
            myjitdriver2.can_enter_jit(g=g, r=r)
            myjitdriver2.jit_merge_point(g=g, r=r)
            r += loop1(r, g) + (-1)
        return r
    #
    res = self.meta_interp(loop2, [4, 40], repeat=7, inline=True)
    assert res == loop2(4, 40)
    # we expect no int_sub, but a residual call
    self.check_loops(int_sub=0, call=1)
def test_argument_order_more_precision_later_2(self):
    """A second jit_merge_point call with swapped green values must fail.

    The first call passes None for r1 and an instance for r2.  The second
    call swaps them and is expected to raise AssertionError, whose text
    lists the kinds that were seen.
    """
    class A(object):
        pass
    driver = JitDriver(greens=['r1', 'i1', 'r2', 'f1'], reds=[])
    driver.jit_merge_point(i1=42, r1=None, r2=A(), f1=3.5)
    excinfo = raises(AssertionError, driver.jit_merge_point,
                     i1=42, r1=A(), r2=None, f1=3.5)
    expected = "got ['2:REF', '1:INT', '2:REF', '3:FLOAT']"
    assert expected in repr(excinfo.value)
def __init__(self, name, debugprint, **kwds):
    JitDriver.__init__(self, **kwds)
    #
    def get_printable_location(*args):
        # 'debugprint' holds indices into 'args'.  Entry 0 selects the
        # pattern object (printed with str(), shortened if very long).
        # Optional entries 1 and 2 select integers appended as
        # ' at X' and '/Y'.
        text = str(args[debugprint[0]])
        if len(text) > 120:
            text = text[:110] + '...'
        position = ''
        if len(debugprint) > 1:
            position = ' at %d' % (args[debugprint[1]],)
            if len(debugprint) > 2:
                position = '%s/%d' % (position, args[debugprint[2]])
        return '%s%s %s' % (name, position, text)
    #
    self.get_printable_location = get_printable_location
def loop(self, program, pc, bracket_map):
    """Interpret the Brainfuck ``program`` on this tape, starting at ``pc``.

    ``bracket_map`` is indexed by the position of a '[' or ']' and gives
    the position of the matching bracket (see the comments below).
    Returns the program counter after the loop exits (pc >= len(program)).

    NOTE(review): a fresh JitDriver is built on every call; RPython usually
    expects drivers to be prebuilt at module level -- confirm intended.
    """
    jitdriver2 = JitDriver(greens=['pc', 'program', 'bracket_map'],
                           reds = ['tape'])
    while pc < len(program):
        # this tape is the only red variable; pc/program/bracket_map are green
        jitdriver2.jit_merge_point(pc=pc, tape=self, program=program,
                                   bracket_map=bracket_map)
        code = program[pc]
        if code == ">":
            self.advance()
        elif code == "<":
            self.devance()
        elif code == "+":
            self.inc()
        elif code == "-":
            self.dec()
        elif code == ".":
            # print
            os.write(1, chr(self.get()))
        elif code == ",":
            # read from stdin
            # NOTE(review): os.read returns '' at EOF, so [0] would raise
            # IndexError there -- EOF is not handled
            self.set(ord(os.read(0, 1)[0]))
        elif code == "[" and self.get() == 0:
            # Skip forward to the matching ]
            pc = bracket_map[pc]
        elif code == "]" and self.get() != 0:
            # Skip back to the matching [
            pc = bracket_map[pc]
        pc += 1
    return pc
def test_outer_and_inner_loop(self):
    """Trace an interpreter whose code contains an inner and an outer loop.

    Opcode 0 just advances p.  Opcode 1 jumps to 3 - p (a backward jump,
    hence can_enter_jit) while i < p * 20; otherwise it bumps j, restarts
    i from j and advances.  The result must match plain interpretation;
    the loop count is only loosely pinned (see XXX below).
    """
    jitdriver = JitDriver(greens=['p', 'code'], reds=['i', 'j', 'total'])
    class Code:
        def __init__(self, lst):
            self.lst = lst
    codes = [Code([]), Code([0, 0, 1, 1])]
    def interpret(num):
        code = codes[num]
        p = 0
        i = 0
        j = 0
        total = 0
        while p < len(code.lst):
            jitdriver.jit_merge_point(code=code, p=p, i=i, j=j, total=total)
            total += i
            e = code.lst[p]
            if e == 0:
                p += 1
            elif e == 1:
                if i < p * 20:
                    p = 3 - p
                    i += 1
                    jitdriver.can_enter_jit(code=code, p=p, j=j, i=i,
                                            total=total)
                else:
                    j += 1
                    i = j
                    p += 1
        return total
    res = self.meta_interp(interpret, [1])
    assert res == interpret(1)
    # XXX it's unsure how many loops should be there
    self.check_loop_count(3)
def test_list_length_1(self):
    """Reads through a quasi-immutable list field are folded away.

    Foo declares ``lst?[*]`` in _immutable_fields_ (presumably: a
    quasi-immutable field whose items are immutable -- verify).  Each
    iteration reads foo.lst[1] and len(foo.lst).  The loops are expected
    to contain no getarrayitem/arraylen/getfield ops, and each loop to
    record exactly one QuasiImmut dependency.
    """
    myjitdriver = JitDriver(greens=['foo'], reds=['x', 'total'])
    class Foo:
        _immutable_fields_ = ['lst?[*]']
        def __init__(self, lst):
            self.lst = lst
    class A:
        pass
    def f(a, x):
        # the 'a' argument only seeds lst1[1]; the name is reused below
        lst1 = [0, 0]
        lst1[1] = a
        foo = Foo(lst1)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(foo=foo, x=x, total=total)
            # make it a Constant after optimization only
            a = A()
            a.foo = foo
            foo = a.foo
            # read a quasi-immutable field out of it
            total += foo.lst[1]
            # also read the length
            total += len(foo.lst)
            x -= 1
        return total
    #
    res = self.meta_interp(f, [100, 7])
    assert res == 714    # 7 iterations * (100 + len 2)
    self.check_resops(getarrayitem_gc_pure=0, guard_not_invalidated=2,
                      arraylen_gc=0, getarrayitem_gc=0, getfield_gc=0)
    #
    from pypy.jit.metainterp.warmspot import get_stats
    loops = get_stats().loops
    for loop in loops:
        assert len(loop.quasi_immutable_deps) == 1
        assert isinstance(loop.quasi_immutable_deps.keys()[0], QuasiImmut)
def test_really_run():
    """Check that the jitprof output has not changed.

    Runs a 10-iteration loop under ll_meta_interp with the Profiler and
    DEBUG_PROFILE, captures stderr, keeps its last JITPROF_LINES lines and
    parses them with parse_prof.  The exact counter values are asserted,
    so this breaks whenever someone touches jitprof.py.
    """
    mydriver = JitDriver(reds=['i', 'n'], greens=[])
    def f(n):
        i = 0
        while i < n:
            mydriver.can_enter_jit(i=i, n=n)
            mydriver.jit_merge_point(i=i, n=n)
            i += 1
    cap = py.io.StdCaptureFD()
    try:
        ll_meta_interp(f, [10], CPUClass=runner.LLtypeCPU,
                       type_system='lltype', ProfilerClass=Profiler,
                       debug_level=DEBUG_PROFILE)
    finally:
        # always restore the real stdout/stderr file descriptors
        out, err = cap.reset()
    err = "\n".join(err.splitlines()[-JITPROF_LINES:])
    print err
    assert err.count("\n") == JITPROF_LINES - 1
    info = parse_prof(err)
    # assert did not crash
    # asserts below are a bit delicate, possibly they might be deleted
    assert info.tracing_no == 1
    assert info.asm_no == 1
    assert info.blackhole_no == 1
    assert info.backend_no == 1
    assert info.ops.total == 2
    assert info.ops.calls == 0
    assert info.ops.pure_calls == 0
    assert info.recorded_ops.total == 2
    assert info.recorded_ops.calls == 0
    assert info.recorded_ops.pure_calls == 0
    assert info.guards == 1
    assert info.blackholed_ops.total == 0
    assert info.blackholed_ops.pure_calls == 0
    assert info.opt_ops == 6
    assert info.opt_guards == 1
    assert info.forcings == 0
def test_indirect_call_unknown_object_3(self):
    """Call a function object obtained from a residual call.

    State.externfn is passed to StopAtXPolicy, presumably so that it is not
    traced.  It returns one of three functions depending on y, and f calls
    whichever it gets.  The assert in externfn checks that it is called
    with y = 198, 197, ... in order.
    """
    myjitdriver = JitDriver(greens=[], reds=['x', 'y', 'z', 'state'])
    def getvalue2():
        return 2
    def getvalue25():
        return 25
    def getvalue1001():
        return -1001
    class State:
        count = 0
        def externfn(self, n):
            assert n == 198 - self.count
            self.count += 1
            if n % 5:
                return getvalue2
            elif n % 7:
                return getvalue25
            else:
                return getvalue1001
    def f(y):
        state = State()
        x = z = 0
        while y > 0:
            myjitdriver.can_enter_jit(x=x, y=y, z=z, state=state)
            myjitdriver.jit_merge_point(x=x, y=y, z=z, state=state)
            x += z
            z = state.externfn(y)()
            y -= 1
        return x
    res = self.meta_interp(f, [198],
                           policy=StopAtXPolicy(State.externfn.im_func))
    assert res == f(198)
    # we get four TargetTokens: one for each of the 3 getvalue functions,
    # and one entering from the interpreter (the preamble)
    self.check_jitcell_token_count(1)
    self.check_target_token_count(4)
def test_recursion_cant_call_assembler_directly(self):
    """Recursive portal calls create exactly one temporary callback.

    compile.compile_tmp_callback is wrapped so that every loop token it
    returns is recorded.  With inline=True the test expects:
    - two call_assembler ops and no call_may_force;
    - exactly one temporary callback;
    - on backends exposing _redirected_call_assembler (llgraph), that the
      temporary token was later redirected.
    """
    driver = JitDriver(greens = ['codeno'], reds = ['i', 'j'],
                       get_printable_location = lambda codeno : str(codeno))
    def portal(codeno, j):
        i = 1
        while 1:
            driver.jit_merge_point(codeno=codeno, i=i, j=j)
            if (i >> 1) == 1:
                if j == 0:
                    return
                portal(2, j - 1)
            elif i == 5:
                return
            i += 1
            driver.can_enter_jit(codeno=codeno, i=i, j=j)
    portal(2, 5)
    from pypy.jit.metainterp import compile, pyjitpl
    # reset, presumably so that the meta_interp below installs a fresh one
    pyjitpl._warmrunnerdesc = None
    trace = []
    def my_ctc(*args):
        looptoken = original_ctc(*args)
        trace.append(looptoken)
        return looptoken
    original_ctc = compile.compile_tmp_callback
    try:
        compile.compile_tmp_callback = my_ctc
        self.meta_interp(portal, [2, 5], inline=True)
        self.check_resops(call_may_force=0, call_assembler=2)
    finally:
        # always undo the monkey-patch
        compile.compile_tmp_callback = original_ctc
    # check that we made a temporary callback
    assert len(trace) == 1
    # and that we later redirected it to something else
    try:
        redirected = pyjitpl._warmrunnerdesc.cpu._redirected_call_assembler
    except AttributeError:
        pass    # not the llgraph backend
    else:
        print redirected
        assert redirected.keys() == trace
def test_bug(self):
    """Regression test, run with all optimizations disabled (enable_opts='').

    Each iteration allocates an X, sets and reads x.arg, and breaks out of
    the loop once n <= 0.  For n = 31: 31 -> 26 -> ... -> 1 -> -4.
    """
    jitdriver = JitDriver(greens=[], reds=['n'])
    class X(object):
        pass
    def f(n):
        while n > -100:
            jitdriver.can_enter_jit(n=n)
            jitdriver.jit_merge_point(n=n)
            x = X()
            x.arg = 5
            if n <= 0:
                break
            n -= x.arg
            x.arg = 6   # prevents 'x.arg' from being annotated as constant
        return n
    res = self.meta_interp(f, [31], enable_opts='')
    assert res == -4
def test_oosend_base(self):
    """Method calls on a receiver whose class depends on x's parity.

    Odd x starts with W1 (incr adds 1); even x starts with W2 (incr adds
    100).  incr() returns a new instance on every iteration.  So
    f(3, 14) == 3 + 14 and f(4, 14) == 4 + 100 * 14.  The loops are
    expected to contain no class guards and no allocations.
    """
    myjitdriver = JitDriver(greens=[], reds=['x', 'y', 'w'])
    class Base:
        pass
    class W1(Base):
        def __init__(self, x):
            self.x = x
        def incr(self):
            return W1(self.x + 1)
        def getvalue(self):
            return self.x
    class W2(Base):
        def __init__(self, y):
            self.y = y
        def incr(self):
            return W2(self.y + 100)
        def getvalue(self):
            return self.y
    def f(x, y):
        if x & 1:
            w = W1(x)
        else:
            w = W2(x)
        while y > 0:
            myjitdriver.can_enter_jit(x=x, y=y, w=w)
            myjitdriver.jit_merge_point(x=x, y=y, w=w)
            w = w.incr()
            y -= 1
        return w.getvalue()
    res = self.meta_interp(f, [3, 14])
    assert res == 17
    res = self.meta_interp(f, [4, 14])
    assert res == 1404
    self.check_loops(guard_class=0, new_with_vtable=0, new=0)
def test_cannot_be_virtual(self):
    """List operations on a list rebuilt every iteration.

    The red variable l is replaced by a fresh ``[3] * 100`` after reading
    l[n], and index 3 is written twice.  The test pins the remaining call,
    setarrayitem_gc and getarrayitem_gc counts (per the test name the
    fresh list presumably cannot be kept virtual -- confirm).
    """
    jitdriver = JitDriver(greens=[], reds=['n', 'l'])
    def f(n):
        l = [3] * 100
        while n > 0:
            jitdriver.can_enter_jit(n=n, l=l)
            jitdriver.jit_merge_point(n=n, l=l)
            x = l[n]
            l = [3] * 100
            l[3] = x
            l[3] = x + 1
            n -= 1
        return l[0]
    res = self.meta_interp(f, [10], listops=True)
    assert res == f(10)
    # one setitem should be gone by now
    self.check_loops(call=1, setarrayitem_gc=2, getarrayitem_gc=1)
def test_loop(self):
    """An exception raised by a called helper terminates the traced loop.

    check() raises IndexError once n < 0.  f catches it outside the loop
    and returns n: 54 -> 44 -> ... -> 4 -> -6.
    """
    myjitdriver = JitDriver(greens=[], reds=['n'])
    def check(n):
        if n < 0:
            raise IndexError
    def f(n):
        try:
            while True:
                myjitdriver.can_enter_jit(n=n)
                myjitdriver.jit_merge_point(n=n)
                check(n)
                n = n - 10
        except IndexError:
            return n
    res = self.meta_interp(f, [54])
    assert res == -6
def test_int_mod_ovf_zer(self):
    """ovfcheck(i % x) inside a try with ZeroDivisionError/OverflowError.

    With x == 0 every iteration raises ZeroDivisionError; with x == 1 none
    does.  The test only checks that meta_interp runs to completion.  y is
    never used except as a red variable.
    """
    myjitdriver = JitDriver(greens=[], reds=['i', 'x', 'y'])
    def f(x, y):
        i = 0
        while i < 10:
            myjitdriver.can_enter_jit(x=x, y=y, i=i)
            myjitdriver.jit_merge_point(x=x, y=y, i=i)
            try:
                ovfcheck(i % x)
                i += 1
            except ZeroDivisionError:
                i += 1
            except OverflowError:
                i += 2
        return 0
    self.meta_interp(f, [0, 0])
    self.meta_interp(f, [1, 0])
def test_raise(self):
    """An exception raised inside the traced loop propagates to the caller.

    f raises ValueError once n < 0.  main catches it and returns 132.
    """
    myjitdriver = JitDriver(greens=[], reds=['n'])
    def f(n):
        while True:
            myjitdriver.can_enter_jit(n=n)
            myjitdriver.jit_merge_point(n=n)
            if n < 0:
                raise ValueError
            n = n - 1
    def main(n):
        try:
            f(n)
        except ValueError:
            return 132
    res = self.meta_interp(main, [13])
    assert res == 132
def test_virtual_escaping_via_list(self):
    """Objects popped from a list feed the loop counter (currently skipped).

    l holds Stuff(n - i) for i in range(n).  Each iteration pops the last
    one and subtracts its x from n.  For n = 20 the pops yield 1, 2, 3, 4,
    5, 6, so f(20) == -1.
    """
    py.test.skip("unsupported")
    jitdriver = JitDriver(greens = [], reds = ['n', 'l'])
    class Stuff(object):
        def __init__(self, x):
            self.x = x
    def f(n):
        l = [Stuff(n-i) for i in range(n)]
        while n > 0:
            jitdriver.can_enter_jit(n=n, l=l)
            jitdriver.jit_merge_point(n=n, l=l)
            s = l.pop()
            n -= s.x
        # without this return f() yields None, and the assert below would
        # compare None == None and check nothing
        return n
    res = self.meta_interp(f, [20])
    assert res == f(20)
    self.check_loops(pop=1, getfield_gc=1)
def test_stuff_escapes_via_setitem(self):
    """A freshly allocated object stored into a list escapes (skipped).

    Each iteration appends Stuff(3) to l and subtracts the stored
    element's x from n, so f(30) == 0.

    NOTE(review): the name mentions setitem; the intent may have been
    ``l[0] = s`` instead of append -- confirm against check_loops(append=1).
    """
    py.test.skip("unsupported")
    jitdriver = JitDriver(greens = [], reds = ['n', 'l'])
    class Stuff(object):
        def __init__(self, x):
            self.x = x
    def f(n):
        l = [None]
        while n > 0:
            jitdriver.can_enter_jit(n=n, l=l)
            jitdriver.jit_merge_point(n=n, l=l)
            s = Stuff(3)
            l.append(s)
            # read back the element just stored; l[0] is always None, so
            # 'l[0].x' would raise AttributeError
            n -= l[-1].x
        return n
    res = self.meta_interp(f, [30])
    assert res == 0
    self.check_loops(append=1)
def test_nested_loops_bridge(self):
    """Nested loops over bytecode "iajb+JI" with data-dependent branches.

    'J' (pc 5) jumps back two to 'b' while j.val < n; 'I' (pc 6) jumps back
    five to 'a' while i.val < n.  The '+' opcode conditionally adds 7 or 42,
    so some paths diverge inside the loops.  Checks that no trace is aborted
    and that three target tokens exist.
    """
    class Int(object):
        def __init__(self, val):
            self.val = val
    myjitdriver = JitDriver(greens = ['pc'], reds = ['n', 'sa', 'i', 'j'])
    bytecode = "iajb+JI"
    def f(n):
        pc = sa = 0
        i = j = Int(0)
        while pc < len(bytecode):
            myjitdriver.jit_merge_point(pc=pc, n=n, sa=sa, i=i, j=j)
            op = bytecode[pc]
            if op == 'i':
                i = Int(0)
            elif op == 'j':
                j = Int(0)
            elif op == '+':
                if i.val < n-8:
                    sa += 7
                if j.val < n-16:
                    sa += 42
                sa += i.val * j.val
            elif op == 'a':
                i = Int(i.val + 1)
            elif op == 'b':
                j = Int(j.val + 1)
            elif op == 'J':
                if j.val < n:
                    pc -= 2
                    myjitdriver.can_enter_jit(pc=pc, n=n, sa=sa, i=i, j=j)
                    continue
            elif op == 'I':
                if i.val < n:
                    pc -= 5
                    myjitdriver.can_enter_jit(pc=pc, n=n, sa=sa, i=i, j=j)
                    continue
            pc += 1
        return sa
    res = self.meta_interp(f, [32])
    assert res == f(32)
    self.check_aborted_count(0)
    self.check_target_token_count(3)
def test_cannot_merge(self):
    """A list filled through two independent branches, then popped (skipped).

    Each iteration builds a new list that gets n - 3 (if n < 20) and/or
    n - 4 (if n > 5), and pops the last element into n.  For f(30):
    30, 26, 22, 18, 14, 10, 6, 2, -1.  Expected to be fully virtualized.
    """
    py.test.skip("unsupported")
    jitdriver = JitDriver(greens=[], reds=['n'])
    def f(n):
        while n > 0:
            jitdriver.can_enter_jit(n=n)
            jitdriver.jit_merge_point(n=n)
            lst = []
            if n < 20:
                lst.append(n - 3)
            if n > 5:
                lst.append(n - 4)
            n = lst.pop()
        return n
    res = self.meta_interp(f, [30])
    assert res == -1
    self.check_all_virtualized()
def test_double_frame_array(self):
    """Two setup2() objects in one loop, only one of them virtualizable.

    xy2 is declared in ``virtualizables``; other is not.  Accesses to
    xy2.inst_l2 are expected to be virtualized away, while the reads and
    writes of other.inst_l2 stay as real operations (see the per-line
    comments and the final check_loops).
    """
    myjitdriver = JitDriver(greens=[], reds=['n', 'xy2', 'other'],
                            virtualizables=['xy2'])
    ARRAY = lltype.GcArray(lltype.Signed)
    def f(n):
        xy2 = self.setup2()
        xy2.inst_x = 10
        xy2.inst_l1 = lltype.malloc(ARRAY, 1)
        xy2.inst_l1[0] = 1982731
        xy2.inst_l2 = lltype.malloc(ARRAY, 1)
        xy2.inst_l2[0] = 10000
        other = self.setup2()
        other.inst_x = 15
        other.inst_l1 = lltype.malloc(ARRAY, 2)
        other.inst_l1[0] = 189182
        other.inst_l1[1] = 58421
        other.inst_l2 = lltype.malloc(ARRAY, 2)
        other.inst_l2[0] = 181
        other.inst_l2[1] = 189
        while n > 0:
            myjitdriver.can_enter_jit(xy2=xy2, n=n, other=other)
            myjitdriver.jit_merge_point(xy2=xy2, n=n, other=other)
            promote_virtualizable(other, 'inst_l2')
            length = len(other.inst_l2)        # getfield_gc/arraylen_gc
            value = other.inst_l2[0]           # getfield_gc/getarrayitem_gc
            other.inst_l2[0] = value + length  # getfield_gc/setarrayitem_gc
            promote_virtualizable(xy2, 'inst_l2')
            xy2.inst_l2[0] = value + 100       # virtualized away
            n -= 1
        promote_virtualizable(xy2, 'inst_l2')
        return xy2.inst_l2[0]
    expected = f(20)
    res = self.meta_interp(f, [20], optimizer=OPTIMIZER_SIMPLE)
    assert res == expected
    self.check_loops(getfield_gc=3, setfield_gc=0,
                     arraylen_gc=1, getarrayitem_gc=1, setarrayitem_gc=1)
def test_vlist_with_default_read(self):
    """Reads from a freshly built list, one of them never written.

    Each iteration builds ``[0] * 20``, sets index 3 to 5 and reads
    indices 3 (as -17) and 5.  f(10) returns x == 5 once n drops below 3.
    The loops are expected to contain no setarrayitem_gc, getarrayitem_gc
    or call operations.
    """
    jitdriver = JitDriver(greens=[], reds=['n'])
    def f(n):
        l = [1] * 20
        while n > 0:
            jitdriver.can_enter_jit(n=n)
            jitdriver.jit_merge_point(n=n)
            l = [0] * 20
            l[3] = 5
            # l[-17] is l[3] (just set to 5); l[5] is a default element (0)
            x = l[-17] + l[5]
            if n < 3:
                return x
            n -= 1
        return l[0]
    res = self.meta_interp(f, [10], listops=True)
    assert res == f(10)
    self.check_loops(setarrayitem_gc=0, getarrayitem_gc=0, call=0)
def test_loop_kept_alive(self):
    """A loop that is re-entered often survives with loop_longevity=2.

    f calls g 15 times, and g runs a 10-iteration loop.  loop_longevity
    presumably lets unused loops be freed (verify).  Here only the loop and
    its entry bridge should exist.
    """
    myjitdriver = JitDriver(greens=[], reds=['n'])
    def g():
        n = 10
        while n > 0:
            myjitdriver.can_enter_jit(n=n)
            myjitdriver.jit_merge_point(n=n)
            n = n - 1
        return 21
    def f():
        for i in range(15):
            g()
        return 42
    res = self.meta_interp(f, [], loop_longevity=2)
    assert res == 42
    # we should see only the loop and the entry bridge
    self.check_target_token_count(2)
def test_guard_failure_and_then_exception_in_inlined_function(self):
    """An inlined call first fails a guard, then later raises.

    main interprets "c-l", which calls f on the inlinable "---ir---" (no
    'l', so can_inline returns True), subtracts 1 and loops while n > 0.
    Once n drops below 200 the inner code sets flag at 'i' and raises at
    'r'.  The outer f catches the exception and returns its own n.
    """
    from pypy.rpython.annlowlevel import hlstr
    def p(code, pc):
        code = hlstr(code)
        return "%s %d %s" % (code, pc, code[pc])
    def c(code, pc):
        return "l" not in hlstr(code)
    myjitdriver = JitDriver(greens=['code', 'pc'], reds=['n', 'flag'],
                            get_printable_location=p, can_inline=c)
    def f(code, n):
        pc = 0
        flag = False
        while pc < len(code):
            myjitdriver.jit_merge_point(n=n, code=code, pc=pc, flag=flag)
            op = code[pc]
            if op == "-":
                n -= 1
            elif op == "c":
                try:
                    n = f("---ir---", n)
                except Exception:
                    return n
            elif op == "i":
                if n < 200:
                    flag = True
            elif op == "r":
                if flag:
                    raise Exception
            elif op == "l":
                if n > 0:
                    myjitdriver.can_enter_jit(n=n, code=code, pc=0, flag=flag)
                    pc = 0
                    continue
            else:
                assert 0
            pc += 1
        return n
    def main(n):
        return f("c-l", n)
    print main(1000)
    res = self.meta_interp(main, [1000], optimizer=OPTIMIZER_SIMPLE,
                           inline=True)
    assert res == main(1000)
def test_three_cases(self):
    """A loop whose step size depends on the current value (three cases).

    node.x decreases by 1 while x >= 40, by 2 for 20 <= x < 40 and by 3
    below 20, allocating a new Node for every step.  The tree-loop count
    is pinned to 2.
    """
    class Node:
        def __init__(self, x):
            self.x = x
    myjitdriver = JitDriver(greens = [], reds = ['node'])
    def f(n):
        node = Node(n)
        while node.x > 0:
            myjitdriver.can_enter_jit(node=node)
            myjitdriver.jit_merge_point(node=node)
            if node.x < 40:
                if node.x < 20:
                    node = Node(node.x - 1)
                node = Node(node.x - 1)
            node = Node(node.x - 1)
        return node.x
    res = self.meta_interp(f, [55])
    assert res == f(55)
    self.check_tree_loop_count(2)
def test_indirect_calls_not_followed(self):
    """After the traced loop, f calls h or g through a variable.

    n ends at 0, so g (which calls h) is chosen and 42 is returned.  The
    test also checks self.seen_frames (presumably the frames the test
    harness recorded -- verify).
    """
    myjitdriver = JitDriver(greens = [], reds = ['n'])
    def h():
        return 42
    def g():
        return h()
    def f(n):
        while n > 0:
            myjitdriver.can_enter_jit(n=n)
            myjitdriver.jit_merge_point(n=n)
            n -= 1
        if n < 0:
            call = h
        else:
            call = g
        return call()
    res = self.meta_interp(f, [7])
    assert res == 42
    assert self.seen_frames == ['f', 'f']
def test_is_not_virtual_none(self):
    """Reading ``.virtual`` of vref_None from a dont_look_inside helper.

    res1 starts at -42 and is overwritten by residual(vref_None) on each
    iteration.  The test expects 0, i.e. a false ``.virtual``.
    """
    myjitdriver = JitDriver(greens=[], reds=['n', 'res1'])
    @dont_look_inside
    def residual(vref):
        return vref.virtual
    #
    def f(n):
        res1 = -42
        while n > 0:
            myjitdriver.jit_merge_point(n=n, res1=res1)
            res1 = residual(vref_None)
            n -= 1
        return res1
    #
    res = self.meta_interp(f, [10])
    assert res == 0
def test_bug_1(self):
    """Regression test: a residual helper re-enters the JIT-ed function.

    opaque is passed to StopAtXPolicy.  On the last iteration of f(1) it
    calls f(0) 20 times.  The test runs with optimizations disabled and
    repeat=2; f(1) must still return the value on its stack (1).
    """
    myjitdriver = JitDriver(greens=[], reds=['n', 'i', 'stack'])
    def opaque(n, i):
        if n == 1 and i == 19:
            for j in range(20):
                res = f(0)      # recurse repeatedly, 20 times
                assert res == 0
    def f(n):
        stack = [n]
        i = 0
        while i < 20:
            myjitdriver.can_enter_jit(n=n, i=i, stack=stack)
            myjitdriver.jit_merge_point(n=n, i=i, stack=stack)
            opaque(n, i)
            i += 1
        return stack.pop()
    res = self.meta_interp(f, [1], enable_opts='', repeat=2,
                           policy=StopAtXPolicy(opaque))
    assert res == 1
def test_directly_call_assembler_return(self):
    """A non-inlinable recursive portal call whose result is used.

    portal(2) calls portal(1) in every iteration and keeps its return value
    in k.  can_inline always returns False, and the history is expected to
    contain one call_assembler.
    """
    driver = JitDriver(greens=['codeno'], reds=['i', 'k'],
                       get_printable_location=lambda codeno: str(codeno),
                       can_inline=lambda codeno: False)
    def portal(codeno):
        i = 0
        k = codeno
        while i < 10:
            driver.can_enter_jit(codeno=codeno, i=i, k=k)
            driver.jit_merge_point(codeno=codeno, i=i, k=k)
            if codeno == 2:
                k = portal(1)
            i += 1
        return k
    self.meta_interp(portal, [2], inline=True)
    self.check_history(call_assembler=1)
def test_simple_force_always(self): myjitdriver = JitDriver(greens=[], reds=['n']) # A = lltype.GcArray(lltype.Signed) class XY: pass class ExCtx: pass exctx = ExCtx() # @dont_look_inside def externalfn(n): m = exctx.topframeref().n assert m == n return 1 # def f(n): while n > 0: myjitdriver.can_enter_jit(n=n) myjitdriver.jit_merge_point(n=n) xy = XY() xy.next1 = lltype.malloc(A, 0) xy.next2 = lltype.malloc(A, 0) xy.next3 = lltype.malloc(A, 0) xy.n = n exctx.topframeref = vref = virtual_ref(xy) n -= externalfn(n) xy.next1 = lltype.nullptr(A) xy.next2 = lltype.nullptr(A) xy.next3 = lltype.nullptr(A) virtual_ref_finish(vref, xy) exctx.topframeref = vref_None # self.meta_interp(f, [15]) self.check_resops( new_with_vtable=4, # XY(), the vref new_array=6) # next1/2/3 self.check_aborted_count(0)
def test_generalize_loop(self):
    """A loop object that starts escaping after some iterations.

    Each iteration allocates a new A.  Once i < 10 it is also passed to
    extern, which is passed to StopAtXPolicy.  Per the name, the loop
    presumably has to be generalized then.  fn(20) == 1 + 20.
    """
    myjitdriver = JitDriver(greens=[], reds=['i', 'obj'])
    class A:
        def __init__(self, n):
            self.n = n
    def extern(obj):
        pass
    def fn(i):
        obj = A(1)
        while i > 0:
            myjitdriver.can_enter_jit(i=i, obj=obj)
            myjitdriver.jit_merge_point(i=i, obj=obj)
            obj = A(obj.n + 1)
            if i < 10:
                extern(obj)
            i -= 1
        return obj.n
    res = self.meta_interp(fn, [20], policy=StopAtXPolicy(extern))
    assert res == 21
def get_interpreter(self, codes, always_inline=False):
    """Build a tiny bytecode interpreter driven by a JitDriver.

    ``codes`` is indexed by the ``codenum`` passed to the returned
    ``interpret(codenum, n, i)``.  Each code is a string of one-character
    opcodes:
      "0" ADD        n += 1
      "1" JUMP_BACK  return 42 if n > 20, else jump back two ops
                     (a loop header for can_enter_jit)
      "2" CALL       n = interpret(1, n, 1)
      "3" EXIT       return n
    Any other opcode raises NotImplementedError.

    With ``always_inline`` every location may be inlined; otherwise only
    codes that contain no JUMP_BACK opcode are inlinable.
    """
    ADD = "0"
    JUMP_BACK = "1"
    CALL = "2"
    EXIT = "3"

    if always_inline:
        def can_inline(*args):
            return True
    else:
        def can_inline(code, i):
            code = hlstr(code)
            # code containing a backward jump is a loop: don't inline it
            return JUMP_BACK not in code

    jitdriver = JitDriver(greens=['code', 'i'], reds=['n'],
                          can_inline=can_inline)

    def interpret(codenum, n, i):
        code = codes[codenum]
        while i < len(code):
            jitdriver.jit_merge_point(n=n, i=i, code=code)
            op = code[i]
            if op == ADD:
                n += 1
                i += 1
            elif op == CALL:
                n = interpret(1, n, 1)
                i += 1
            elif op == JUMP_BACK:
                if n > 20:
                    return 42
                i -= 2
                jitdriver.can_enter_jit(n=n, i=i, code=code)
            elif op == EXIT:
                return n
            else:
                raise NotImplementedError
        return n

    return interpret
def test_append_pop(self):
    """A short-lived list built and popped within one iteration (skipped).

    lst becomes [5, n], lst[0] is reduced by len(lst) to 3, and n becomes
    n - 3.  For f(31): 31 -> 28 -> ... -> 1 -> -2.  The list is expected
    to be fully virtualized.
    """
    py.test.skip("unsupported")
    jitdriver = JitDriver(greens=[], reds=['n'])
    def f(n):
        while n > 0:
            jitdriver.can_enter_jit(n=n)
            jitdriver.jit_merge_point(n=n)
            lst = []
            lst.append(5)
            lst.append(n)
            lst[0] -= len(lst)
            three = lst[0]
            n = lst.pop() - three
        return n
    res = self.meta_interp(f, [31])
    assert res == -2
    self.check_all_virtualized()
def test_alternating_loops(self):
    """A loop branching on each bit of ``pattern`` (LSB first).

    0xF0F0F0 yields alternating runs of four 0 bits and four 1 bits.  Both
    branches are empty; only the number of traces is checked, and the
    expected count depends on whether optimizations are enabled.
    """
    myjitdriver = JitDriver(greens=[], reds=['pattern'])
    def f(pattern):
        while pattern > 0:
            myjitdriver.can_enter_jit(pattern=pattern)
            myjitdriver.jit_merge_point(pattern=pattern)
            if pattern & 1:
                pass
            else:
                pass
            pattern >>= 1
        return 42
    self.meta_interp(f, [0xF0F0F0])
    if self.enable_opts:
        self.check_trace_count(3)
    else:
        self.check_trace_count(2)
def test_nested_loops_1(self):
    """Two nested loops over the bytecode "iajb+JI".

    'J' (pc 5) jumps back to 'b' while j.val < n; 'I' (pc 6) jumps back to
    'a' while i.val < n.  Fresh Int objects are allocated for i and j.
    Checks that no trace is aborted, that three target tokens exist and
    that two int_mul ops remain.
    """
    class Int(object):
        def __init__(self, val):
            self.val = val
    bytecode = "iajb+JI"
    def get_printable_location(i):
        return "%d: %s" % (i, bytecode[i])
    myjitdriver = JitDriver(greens = ['pc'], reds = ['n', 'sa', 'i', 'j'],
                            get_printable_location=get_printable_location)
    def f(n):
        pc = sa = 0
        i = j = Int(0)
        while pc < len(bytecode):
            myjitdriver.jit_merge_point(pc=pc, n=n, sa=sa, i=i, j=j)
            op = bytecode[pc]
            if op == 'i':
                i = Int(0)
            elif op == 'j':
                j = Int(0)
            elif op == '+':
                sa += (i.val + 2) * (j.val + 2)
            elif op == 'a':
                i = Int(i.val + 1)
            elif op == 'b':
                j = Int(j.val + 1)
            elif op == 'J':
                if j.val < n:
                    pc -= 2
                    myjitdriver.can_enter_jit(pc=pc, n=n, sa=sa, i=i, j=j)
                    continue
            elif op == 'I':
                if i.val < n:
                    pc -= 5
                    myjitdriver.can_enter_jit(pc=pc, n=n, sa=sa, i=i, j=j)
                    continue
            pc += 1
        return sa
    res = self.meta_interp(f, [10])
    assert res == f(10)
    self.check_aborted_count(0)
    self.check_target_token_count(3)
    self.check_resops(int_mul=2)
def test_argument_order_ok(self):
    """jit_merge_point accepts greens given as an int, an instance, a float.

    The greens are declared in that order (i1, r1, f1); the call must not
    raise.
    """
    class A(object):
        pass
    driver = JitDriver(greens=['i1', 'r1', 'f1'], reds=[])
    instance = A()
    driver.jit_merge_point(i1=42, r1=instance, f1=3.5)
def test_argument_order_accept_r_uint(self):
    """An r_uint green value must be accepted by jit_merge_point.

    This used to fail on 64-bit, because r_uint == r_ulonglong there.
    """
    unsigned_value = r_uint(42)
    driver = JitDriver(greens=['i1'], reds=[])
    driver.jit_merge_point(i1=unsigned_value)