def test_inactive_jitdriver(self):
    """Check that a JitDriver with active = False is not traced into.

    loop1 is driven by myjitdriver1, which is deactivated below; loop2
    (driven by myjitdriver2) calls loop1 on every iteration.  The final
    check expects two call_i operations and no int_sub (loop1's
    'n -= m') in the traces.
    """
    myjitdriver1 = JitDriver(greens=[], reds=['n', 'm'],
                             get_printable_location = getloc1)
    myjitdriver2 = JitDriver(greens=['g'], reds=['r'],
                             get_printable_location = getloc2)
    #
    myjitdriver1.active = False    # <===
    #
    def loop1(n, m):
        while n > 0:
            myjitdriver1.can_enter_jit(n=n, m=m)
            myjitdriver1.jit_merge_point(n=n, m=m)
            n -= m
        return n
    #
    def loop2(g, r):
        while r > 0:
            myjitdriver2.can_enter_jit(g=g, r=r)
            myjitdriver2.jit_merge_point(g=g, r=r)
            # NOTE(review): presumably written as '+ (-1)' rather than
            # '- 1' so that loop2 itself emits no int_sub; confirm.
            r += loop1(r, g) + (-1)
        return r
    #
    res = self.meta_interp(loop2, [4, 40], repeat=7, inline=True)
    assert res == loop2(4, 40)
    # we expect no int_sub, but a residual call
    self.check_resops(call_i=2, int_sub=0)
def test_inline_jit_merge_point(self): py.test.skip("fix the test if you want to re-enable this") # test that the machinery to inline jit_merge_points in callers # works. The final user does not need to mess manually with the # _inline_jit_merge_point_ attribute and similar, it is all nicely # handled by @JitDriver.inline() (see next tests) myjitdriver = JitDriver(greens = ['a'], reds = 'auto') def jit_merge_point(a, b): myjitdriver.jit_merge_point(a=a) def add(a, b): jit_merge_point(a, b) return a+b add._inline_jit_merge_point_ = jit_merge_point myjitdriver.inline_jit_merge_point = True def calc(n): res = 0 while res < 1000: res = add(n, res) return res def f(): return calc(1) + calc(3) res = self.meta_interp(f, []) assert res == 1000 + 1002 self.check_resops(int_add=4)
def test_argument_order_more_precision_later_2(self):
    """A second jit_merge_point call with different green kinds must fail.

    The first call passes r1=None and r2=<A instance>.  The second passes
    r1=<A instance> and r2=None.  The second call is expected to raise an
    AssertionError whose repr lists the argument kinds it received.
    """
    class A(object):
        pass

    driver = JitDriver(greens=['r1', 'i1', 'r2', 'f1'], reds=[])
    driver.jit_merge_point(i1=42, r1=None, r2=A(), f1=3.5)

    excinfo = py.test.raises(AssertionError, driver.jit_merge_point,
                             i1=42, r1=A(), r2=None, f1=3.5)
    expected_fragment = "got ['2:REF', '1:INT', '2:REF', '3:FLOAT']"
    assert expected_fragment in repr(excinfo.value)
def __init__(self, name, debugprint, **kwds):
    """Create an 'rsre_'-prefixed JitDriver with a custom location printer.

    debugprint is a sequence of indices into the green args:
      debugprint[0]  -> the pattern object, printed via str() and
                        truncated to 110 chars + '...' if longer than 120;
      debugprint[1]  -> (optional) a number printed as ' at %d';
      debugprint[2]  -> (optional) a number appended as '/%d'.
    Any further keyword arguments are forwarded to JitDriver.__init__.
    """
    JitDriver.__init__(self, name='rsre_' + name, **kwds)

    def get_printable_location(*args):
        pattern_text = str(args[debugprint[0]])
        if len(pattern_text) > 120:
            pattern_text = pattern_text[:110] + '...'
        # Optional position info, built only from the indices supplied.
        info = ''
        if len(debugprint) > 1:
            info = ' at %d' % (args[debugprint[1]],)
            if len(debugprint) > 2:
                info = '%s/%d' % (info, args[debugprint[2]])
        return 're %s%s %s' % (name, info, pattern_text)

    self.get_printable_location = get_printable_location
def test_indirect_call_unknown_object_2(self):
    """Call through function objects returned by an external method.

    State.externfn is passed to StopAtXPolicy (presumably so the JIT
    leaves it as a residual call).  It returns one of three getvalue*
    functions depending on n, and the loop then calls the returned one.
    """
    myjitdriver = JitDriver(greens=[], reds=['x', 'y', 'state'])
    def getvalue2():
        return 2
    def getvalue25():
        return 25
    def getvalue1001():
        return -1001

    class State:
        count = 0
        def externfn(self, n):
            # f starts y at 198 and decrements it once per call
            assert n == 198 - self.count
            self.count += 1
            if n % 5:
                return getvalue2
            elif n % 7:
                return getvalue25
            else:
                return getvalue1001

    def f(y):
        state = State()
        x = 0
        while y > 0:
            myjitdriver.can_enter_jit(x=x, y=y, state=state)
            myjitdriver.jit_merge_point(x=x, y=y, state=state)
            x += state.externfn(y)()
            y -= 1
        return x
    res = self.meta_interp(f, [198],
                           policy=StopAtXPolicy(State.externfn.im_func))
    assert res == f(198)
    # we get two TargetTokens, one for the loop and one for the preamble
    self.check_jitcell_token_count(1)
    self.check_target_token_count(2)
def test_alloc_virtualref_and_then_alloc_structure(self): myjitdriver = JitDriver(greens=[], reds=['n']) # class XY: pass class ExCtx: pass exctx = ExCtx() @dont_look_inside def escapexy(xy): print 'escapexy:', xy.n if xy.n % 5 == 0: vr = exctx.vr print 'accessing via vr:', vr() assert vr() is xy # def f(n): while n > 0: myjitdriver.jit_merge_point(n=n) xy = XY() xy.n = n vr = virtual_ref(xy) # force the virtualref to be allocated exctx.vr = vr # force xy to be allocated escapexy(xy) # clean up exctx.vr = vref_None virtual_ref_finish(vr, xy) n -= 1 return 1 # res = self.meta_interp(f, [15]) assert res == 1 self.check_resops(new_with_vtable=4) # vref, xy
def test_automatic_promotion(self):
    """Run a tiny bytecode interpreter whose program counter is green.

    code is three CO_INCREASE, one CO_JUMP_BACK_3 and a final CO_INCREASE.
    The jump goes back three slots while res <= 100, and steps forward
    once res exceeds 100.  The expected resop counts differ per optimizer
    and come from self.automatic_promotion_result.

    The unused nested helpers 'add' and 'sub' from the previous version
    were never called and have been removed.
    """
    myjitdriver = JitDriver(greens = ['i'],
                            reds = ['res', 'a'])
    CO_INCREASE = 0
    CO_JUMP_BACK_3 = 1
    code = [CO_INCREASE, CO_INCREASE, CO_INCREASE,
            CO_JUMP_BACK_3, CO_INCREASE]

    def main_interpreter_loop(a):
        i = 0
        res = 0
        c = len(code)
        while True:
            myjitdriver.jit_merge_point(res=res, i=i, a=a)
            if i >= c:
                break
            elem = code[i]
            if elem == CO_INCREASE:
                i += a
                res += a
            else:
                if res > 100:
                    i += 1
                else:
                    i = i - 3
                    # backward jump: a possible loop header for the JIT
                    myjitdriver.can_enter_jit(res=res, i=i, a=a)
        return res

    res = self.meta_interp(main_interpreter_loop, [1])
    assert res == main_interpreter_loop(1)
    self.check_trace_count(1)
    # These loops do different numbers of ops based on which optimizer we
    # are testing with.
    self.check_resops(self.automatic_promotion_result)
def test_unroll_issue_1(self):
    """Regression test: methods only B implements, guarded by an opaque check.

    A declares _attrs_ = [] and its checkcls raises NotImplementedError.
    get_value/checkcls are only called after check(a), a dont_look_inside
    isinstance(a, B) test.  run() calls f with B instances and then with a
    plain A, so for A the guarded branch must never execute.
    """
    class A(object):
        _attrs_ = []
        def checkcls(self):
            raise NotImplementedError

    class B(A):
        def __init__(self, b_value):
            self.b_value = b_value
        def get_value(self):
            return self.b_value
        def checkcls(self):
            return self.b_value

    @dont_look_inside
    def check(a):
        return isinstance(a, B)

    jitdriver = JitDriver(greens=[], reds='auto')

    def f(a, xx):
        i = 0
        total = 0
        while i < 10:
            jitdriver.jit_merge_point()
            if check(a):
                if xx & 1:
                    total *= a.checkcls()
                total += a.get_value()
            i += 1
        return total

    def run(n):
        # the first result is overwritten; that call presumably just
        # exercises the 'xx & 1' path before the even-xx call
        bt = f(B(n), 1)
        bt = f(B(n), 2)
        at = f(A(), 3)
        return at * 100000 + bt

    assert run(42) == 420
    res = self.meta_interp(run, [42], backendopt=True)
    assert res == 420
def test_ordered_dict_two_lookups(self):
    """Two identical lookups per iteration into a prebuilt OrderedDict.

    The loop check expects one call_i, two getinteriorfield_gc_i and one
    guard_no_exception (see the XXX note on the ideal count).
    """
    driver = JitDriver(greens=[], reds='auto')
    d = OrderedDict()
    d['a'] = 3
    d['b'] = 4
    indexes = ['a', 'b']
    def f(n):
        s = 0
        while n > 0:
            driver.jit_merge_point()
            s += d[indexes[n & 1]]
            s += d[indexes[n & 1]]
            n -= 1
        return s

    self.meta_interp(f, [10])
    # XXX should be one getinteriorfield_gc.  At least it's one call.
    self.check_simple_loop(call_i=1, getinteriorfield_gc_i=2,
                           guard_no_exception=1)
def test_loop_automatic_reds(self):
    """A JitDriver with reds='auto' must discover every red variable.

    Only m is declared (as a green).  n, res and the dummy a/b/c/d are
    left for the JIT to find.  The result must match the plain run.
    """
    myjitdriver = JitDriver(greens = ['m'], reds = 'auto')
    def f(n, m):
        res = 0
        # try to have lots of red vars, so that if there is an error in
        # the ordering of reds, there are low chances that the test passes
        # by chance
        a = b = c = d = n
        while n > 0:
            myjitdriver.jit_merge_point(m=m)
            n -= 1
            a += 1 # dummy unused red
            b += 2 # dummy unused red
            c += 3 # dummy unused red
            d += 4 # dummy unused red
            res += m*2
        return res
    expected = f(21, 5)
    res = self.meta_interp(f, [21, 5])
    assert res == expected
    # NOTE(review): m is green, which presumably lets m*2 fold away
    # (int_mul=0); confirm against the optimizer.
    self.check_resops(int_sub=2, int_mul=0, int_add=10)
def test_virtualized1(self):
    """A node object replaced every iteration should stay virtual.

    f(10) accumulates value = 10 + 9 + ... + 1 = 55 and extra = 10.  The
    resop check expects no new_with_vtable/new/setfield_gc in the traces.
    """
    myjitdriver = JitDriver(greens = [], reds = ['n', 'node'])
    def f(n):
        node = self._new()
        node.value = 0
        node.extra = 0
        while n > 0:
            myjitdriver.can_enter_jit(n=n, node=node)
            myjitdriver.jit_merge_point(n=n, node=node)
            next = self._new()
            next.value = node.value + n
            next.extra = node.extra + 1
            node = next
            n -= 1
        return node.value * node.extra
    assert f(10) == 55 * 10
    res = self.meta_interp(f, [10])
    assert res == 55 * 10
    self.check_trace_count(1)
    self.check_resops(new_with_vtable=0, setfield_gc=0,
                      getfield_gc=2, new=0)
def test_virtualized2(self):
    """Two node objects replaced every iteration should both stay virtual.

    The JIT result must equal the plain run.  The resop check expects no
    new_with_vtable/new/setfield_gc in the traces.
    """
    myjitdriver = JitDriver(greens = [], reds = ['n', 'node1', 'node2'])
    def f(n):
        node1 = self._new()
        node1.value = 0
        node2 = self._new()
        node2.value = 0
        while n > 0:
            myjitdriver.can_enter_jit(n=n, node1=node1, node2=node2)
            myjitdriver.jit_merge_point(n=n, node1=node1, node2=node2)
            next1 = self._new()
            next1.value = node1.value + n + node2.value
            next2 = self._new()
            next2.value = next1.value
            node1 = next1
            node2 = next2
            n -= 1
        return node1.value * node2.value
    assert f(10) == self.meta_interp(f, [10])
    self.check_resops(new_with_vtable=0, setfield_gc=0, getfield_gc=2,
                      new=0)
def test_simple_loop_with_call(self):
    """Loop containing a call to a dont_look_inside function.

    f(6, 7) sums x=6 seven times and doubles it: 6 * 7 * 2 == 84.  The
    test then asserts the JIT profiler recorded exactly one call.
    """
    @dont_look_inside
    def g(n):
        pass
    myjitdriver = JitDriver(greens = [], reds = ['x', 'y', 'res'])
    def f(x, y):
        res = 0
        while y > 0:
            myjitdriver.can_enter_jit(x=x, y=y, res=res)
            myjitdriver.jit_merge_point(x=x, y=y, res=res)
            res += x
            g(x)
            y -= 1
        return res * 2
    res = self.meta_interp(f, [6, 7])
    assert res == 84
    profiler = pyjitpl._warmrunnerdesc.metainterp_sd.profiler
    assert profiler.calls == 1
def test_stringbuilder_append_len2_2(self):
    """Build a 3-char string with StringBuilder each iteration and read it.

    str1 is computed once from the initial n (10 in this test), so every
    built string is "a10".
    """
    jitdriver = JitDriver(reds=['n', 'str1'], greens=[])
    def f(n):
        str1 = str(n)
        while n > 0:
            jitdriver.jit_merge_point(n=n, str1=str1)
            sb = StringBuilder(4)
            sb.append("a")
            sb.append(str1)
            s = sb.build()
            if len(s) != 3: raise ValueError
            if s[0] != "a": raise ValueError
            if s[1] != "1": raise ValueError
            if s[2] != "0": raise ValueError
            n -= 1
        return n
    res = self.meta_interp(f, [10], backendopt=True)
    assert res == 0
    self.check_resops(call_n=2, call_r=2, # (ll_append_res0, ll_build) * 2 unroll
                      cond_call=0)
def test_two_behaviors(self):
    """Loop whose branch outcome changes pattern over time.

    cases holds 100 + 10 True entries, so x ends at 110.  The loop visits
    cases[len(cases) - 1] down to cases[0].
    """
    myjitdriver = JitDriver(greens = [], reds = ['y', 'x'])
    class Int:
        def __init__(self, value):
            self.value = value
    cases = [True]*100 + [False, True]*10 + [False]*20
    def f(y):
        x = Int(0)
        while y > 0:
            myjitdriver.can_enter_jit(x=x, y=y)
            myjitdriver.jit_merge_point(x=x, y=y)
            y -= 1
            if cases[y]:
                x = Int(x.value + 1)
        return x.value
    res = self.meta_interp(f, [len(cases)])
    assert res == 110
def test_constant_virtual2(self):
    """Virtual node whose field is one of two constants, switching with n.

    Note: 'n & 15 > 7' parses as '(n & 15) > 7' because comparisons bind
    more loosely than '&'.
    """
    myjitdriver = JitDriver(greens = [], reds = ['n', 'sa', 'node'])
    def f(n):
        node = self._new()
        node.value = 1
        sa = 0
        while n > 0:
            myjitdriver.can_enter_jit(n=n, sa=sa, node=node)
            myjitdriver.jit_merge_point(n=n, sa=sa, node=node)
            sa += node.value
            if n & 15 > 7:
                next = self._new()
                next.value = 2
                node = next
            else:
                next = self._new()
                next.value = 3
                node = next
            n -= 1
        return sa
    assert self.meta_interp(f, [31]) == f(31)
def test_stringbuilder_append_empty(self):
    """Appending u"" to a UnicodeBuilder should leave only loop bookkeeping.

    The resop check allows only int_sub, int_gt, guard_true and jump.
    """
    jitdriver = JitDriver(reds=['n'], greens=[])
    def f(n):
        while n > 0:
            jitdriver.jit_merge_point(n=n)
            sb = UnicodeBuilder()
            sb.append(u"")
            s = sb.build()
            if len(s) != 0: raise ValueError
            n -= 1
        return n
    res = self.meta_interp(f, [10], backendopt=True)
    assert res == 0
    self.check_resops({
        'int_sub': 2,
        'int_gt': 2,
        'guard_true': 2,
        'jump': 1
    })
def test_strconcat_guard_fail(self):
    """Concatenate two prebuilt strings per iteration; escape on odd m.

    mylist[n] + mylist[m] is built every iteration, and escape() (a
    dont_look_inside no-op) only sees it when m is odd.

    f always returns 42.  The previous version discarded the meta_interp
    result, so only a crash could fail the test; it is now asserted.
    """
    _str = self._str
    jitdriver = JitDriver(greens = [], reds = ['m', 'n'])
    @dont_look_inside
    def escape(x):
        pass
    mylist = [_str("abc") + _str(i) for i in range(12)]
    def f(n, m):
        while m >= 0:
            jitdriver.can_enter_jit(m=m, n=n)
            jitdriver.jit_merge_point(m=m, n=n)
            s = mylist[n] + mylist[m]
            if m & 1:
                escape(s)
            m -= 1
        return 42
    res = self.meta_interp(f, [6, 10])
    assert res == 42
def test_eq_folded(self):
    """String equality on a green string should fold to a constant.

    With b=True and s='h', s becomes "hello" == global_s.  n therefore
    drops by 2 per iteration and the loop runs 5 times.  The final check
    expects no residual call or pure call in the traces.
    """
    _str = self._str
    jitdriver = JitDriver(greens = ['s'], reds = ['n', 'i'])
    global_s = _str("hello")
    def f(n, b, s):
        if b:
            s += _str("ello")
        else:
            s += _str("allo")
        i = 0
        while n > 0:
            jitdriver.can_enter_jit(s=s, n=n, i=i)
            jitdriver.jit_merge_point(s=s, n=n, i=i)
            n -= 1 + (s == global_s)
            i += 1
        return i

    res = self.meta_interp(f, [10, True, _str('h')], listops=True)
    assert res == 5
    self.check_resops(**{self.CALL: 0, self.CALL_PURE: 0})
def compile_boehm_test():
    """Translate and run a small JIT-enabled program using the Boehm GC.

    main() allocates an X per iteration and passes a fresh 3-element list
    to a dont_look_inside checker.  The program output must parse as an
    int >= 16.
    NOTE(review): the meaning of that output is defined by get_g /
    get_entry, which are not visible here.
    """
    myjitdriver = JitDriver(greens = [], reds = ['n', 'x'])
    @dont_look_inside
    def see(lst, n):
        assert len(lst) == 3
        assert lst[0] == n+10
        assert lst[1] == n+20
        assert lst[2] == n+30
    def main(n, x):
        while n > 0:
            myjitdriver.can_enter_jit(n=n, x=x)
            myjitdriver.jit_merge_point(n=n, x=x)
            y = X()
            y.foo = x.foo
            n -= y.foo
            see([n+10, n+20, n+30], n)
    res = compile_and_run(get_entry(get_g(main)), "boehm", jit=True)
    assert int(res) >= 16
def test_dict_array_write_invalidates_caches(self):
    """Dict mutation between two identical lookups.

    Each iteration looks up d[index], deletes 'cc', looks up d[index]
    again, and re-inserts 'cc'.  index is always 'aa' or 'b' (n & 1 is 0
    or 1), so both lookups succeed.  NOTE(review): per the test name, the
    point is presumably that the del/insert must invalidate any cached
    lookup.
    """
    driver = JitDriver(greens=[], reds='auto')
    indexes = ['aa', 'b', 'cc']
    def f(n):
        d = {'aa': 3, 'b': 4, 'cc': 5}
        s = 0
        while n > 0:
            driver.jit_merge_point()
            index = indexes[n & 1]
            s += d[index]
            del d['cc']
            s += d[index]
            d['cc'] = 3
            n -= 1
        return s

    exp = f(10)
    res = self.meta_interp(f, [10])
    assert res == exp
    self.check_simple_loop(call_i=4, cond_call_value_i=1, call_n=2)
def test_exception_from_outside(self):
    """Exception raised by a function excluded via StopAtXPolicy.

    check(n, 0) always raises MyError for n > 0.  The handler then sets
    n = check(e.n, 1) = n - 5, so f(53) steps 53, 48, ..., 3, -2.
    """
    myjitdriver = JitDriver(greens = [], reds = ['n'])
    def check(n, mode):
        if mode == 0 and n > -100:
            raise MyError(n)
        return n - 5
    def f(n):
        while n > 0:
            myjitdriver.can_enter_jit(n=n)
            myjitdriver.jit_merge_point(n=n)
            try:
                check(n, 0)
            except MyError as e:
                n = check(e.n, 1)
        return n
    assert f(53) == -2
    res = self.meta_interp(f, [53], policy=StopAtXPolicy(check))
    assert res == -2
def test_strslice(self):
    """Slice a constant string with a shrinking start index.

    escape() (a dont_look_inside no-op) only sees the slice when m <= 5.
    f always returns 42.  The previous version discarded the meta_interp
    result, so only a crash could fail the test; it is now asserted.
    """
    _str = self._str
    longstring = _str("foobarbazetc")
    jitdriver = JitDriver(greens = [], reds = ['m', 'n'])

    @dont_look_inside
    def escape(x):
        pass

    def f(n, m):
        # RPython needs to know the slice stop is non-negative
        assert n >= 0
        while m >= 0:
            jitdriver.can_enter_jit(m=m, n=n)
            jitdriver.jit_merge_point(m=m, n=n)
            s = longstring[m:n]
            if m <= 5:
                escape(s)
            m -= 1
        return 42
    res = self.meta_interp(f, [10, 10])
    assert res == 42
def test_single_virtual_forced_in_bridge(self):
    """A per-iteration node passed to an external function only sometimes.

    externfn (excluded via StopAtXPolicy) doubles node.value.  It is
    called only when bit s of n is set, i.e. for runs of 2**s consecutive
    n values.  NOTE(review): per the test name, that rarer path is
    presumably compiled as a bridge where the virtual is forced.
    """
    myjitdriver = JitDriver(greens = [], reds = ['n', 's', 'node'])
    def externfn(node):
        node.value *= 2
    def f(n, s):
        node = self._new()
        node.value = 1
        while n > 0:
            myjitdriver.can_enter_jit(n=n, s=s, node=node)
            myjitdriver.jit_merge_point(n=n, s=s, node=node)
            next = self._new()
            next.value = node.value + 1
            node = next
            if (n>>s) & 1:
                externfn(node)
            n -= 1
        return node.value
    res = self.meta_interp(f, [48, 3], policy=StopAtXPolicy(externfn))
    assert res == f(48, 3)
    res = self.meta_interp(f, [40, 3], policy=StopAtXPolicy(externfn))
    assert res == f(40, 3)
def test_compare_single_char_for_ordering(self):
    """Ordering comparison of a 1-char string against a constant string.

    result counts how many of constant1[9..0] compare greater than "c".
    The resop check expects no string allocation and no call.
    """
    jitdriver = JitDriver(reds=['result', 'n'], greens=[])
    _str = self._str
    constant1 = _str("abcdefghij")

    def cmpstr(x, y):
        return x > _str(y)

    def f(n):
        cmpstr(_str("abc"), "def")  # force x and y to be annotated as strings
        result = 0
        while n >= 0:
            jitdriver.jit_merge_point(n=n, result=result)
            c = constant1[n]
            result += cmpstr(c, "c")
            n -= 1
        return result

    res = self.meta_interp(f, [9])
    assert res == f(9)
    self.check_resops(newstr=0, newunicode=0, call=0)
def test_blackhole_pure(self):
    """Loop calling an @elidable function on a green argument.

    The JIT result must match the plain run, and the JIT profiler must
    record exactly one call.
    """
    @elidable
    def g(n):
        return n + 1
    myjitdriver = JitDriver(greens = ['z'], reds = ['y', 'x', 'res'])
    def f(x, y, z):
        res = 0
        while y > 0:
            myjitdriver.can_enter_jit(x=x, y=y, res=res, z=z)
            myjitdriver.jit_merge_point(x=x, y=y, res=res, z=z)
            res += x
            res += g(z)
            y -= 1
        return res * 2
    res = self.meta_interp(f, [6, 7, 2])
    assert res == f(6, 7, 2)
    profiler = pyjitpl._warmrunnerdesc.metainterp_sd.profiler
    assert profiler.calls == 1
def test_char2string2char(self):
    """Char -> 1-char string -> char round trip should leave no string ops.

    ord(_chr(m)) == m, so f(6) == 6 + 5 + ... + 1 == 21.  The m > 100
    branch never runs here.
    """
    _str, _chr = self._str, self._chr
    jitdriver = JitDriver(greens = [], reds = ['m', 'total'])
    def f(m):
        total = 0
        while m > 0:
            jitdriver.can_enter_jit(m=m, total=total)
            jitdriver.jit_merge_point(m=m, total=total)
            string = _chr(m)
            if m > 100:
                string += string    # forces to be a string
            # read back the character
            c = string[0]
            total += ord(c)
            m -= 1
        return total
    res = self.meta_interp(f, [6])
    assert res == 21
    self.check_resops(newstr=0, strgetitem=0, strsetitem=0, strlen=0,
                      newunicode=0, unicodegetitem=0, unicodesetitem=0,
                      unicodelen=0)
def test_strconcat_escape_char_char(self):
    """Concatenate two single chars into a string that escapes each time.

    f always returns 42.  The previous version discarded the meta_interp
    result, so only the resop checks could fail; the result is now
    asserted too.
    """
    _str, _chr = self._str, self._chr
    jitdriver = JitDriver(greens = [], reds = ['m', 'n'])
    @dont_look_inside
    def escape(x):
        pass
    def f(n, m):
        while m >= 0:
            jitdriver.can_enter_jit(m=m, n=n)
            jitdriver.jit_merge_point(m=m, n=n)
            s = _chr(n) + _chr(m)
            escape(s)
            m -= 1
        return 42
    res = self.meta_interp(f, [6, 7])
    assert res == 42
    if _str is str:
        self.check_resops(call_pure=0, copystrcontent=0,
                          strsetitem=4, call=2, newstr=2)
    else:
        self.check_resops(call_pure=0, unicodesetitem=4, call=2,
                          copyunicodecontent=0, newunicode=2)
def test_simple_recursion(self):
    """JIT loop that recurses into its own entry point on exit.

    For n > 0, main(n) runs f(n + 1), which loops twice and returns
    main(n - 1) * 2.  So main(n) == 2 * main(n - 1), with main(0) == 1.
    The history check expects no call operation.
    """
    myjitdriver = JitDriver(greens=[], reds=['n', 'm'])
    def f(n):
        m = n - 2
        while True:
            myjitdriver.jit_merge_point(n=n, m=m)
            n -= 1
            if m == n:
                return main(n) * 2
            myjitdriver.can_enter_jit(n=n, m=m)
    def main(n):
        if n > 0:
            return f(n+1)
        else:
            return 1
    res = self.meta_interp(main, [20], enable_opts='')
    assert res == main(20)
    self.check_history(call=0)
def test_inline_trace_limit(self):
    """Inlining a deep recursive call must respect the trace limit.

    recursive(n) returns n after n nested calls, so each loop iteration
    effectively does n -= 1 while the inlined recursion grows the trace.
    The checks bound the trace length by TRACE_LIMIT, bound the enter
    count, and expect 6 aborted traces.

    The unused local 'pc = 0' from the previous version has been removed.
    """
    myjitdriver = JitDriver(greens=[], reds=['n'])
    def recursive(n):
        if n > 0:
            return recursive(n - 1) + 1
        return 0
    def loop(n):
        set_param(myjitdriver, "threshold", 10)
        while n:
            myjitdriver.can_enter_jit(n=n)
            myjitdriver.jit_merge_point(n=n)
            n = recursive(n)
            n -= 1
        return n
    TRACE_LIMIT = 66
    res = self.meta_interp(loop, [100], enable_opts='', inline=True,
                           trace_limit=TRACE_LIMIT)
    assert res == 0
    self.check_max_trace_length(TRACE_LIMIT)
    self.check_enter_count_at_most(10) # maybe
    self.check_aborted_count(6)
def test_path_with_operations_not_from_start(self):
    """Green counter k is reset to different values as the loop runs.

    k counts up to 30 and is reset to 4 (twice, while z is 0 or 1), then
    to 15 (resetting z to 0).  Paths therefore re-enter the loop at
    different green values.  f always returns 42.  The previous version
    assigned the meta_interp result but never checked it; it is now
    asserted.
    """
    jitdriver = JitDriver(greens = ['k'], reds = ['n', 'z'])

    def f(n):
        k = 0
        z = 0
        while n > 0:
            jitdriver.can_enter_jit(n=n, k=k, z=z)
            jitdriver.jit_merge_point(n=n, k=k, z=z)
            k += 1
            if k == 30:
                if z == 0 or z == 1:
                    k = 4
                    z += 1
                else:
                    k = 15
                    z = 0
            n -= 1
        return 42

    res = self.meta_interp(f, [200])
    assert res == 42
def test_int_lshift_ovf(self):
    """Count left shifts that overflow, with a shift amount cycling mod LONG_BIT.

    y is masked to 0..LONG_BIT-1 every iteration.  m counts how many
    ovfcheck(x << y) calls raised OverflowError in 100 - n iterations.
    """
    myjitdriver = JitDriver(greens = [], reds = ['n', 'x', 'y', 'm'])
    def f(x, y, n):
        m = 0
        while n < 100:
            myjitdriver.can_enter_jit(n=n, x=x, y=y, m=m)
            myjitdriver.jit_merge_point(n=n, x=x, y=y, m=m)
            y += 1
            y &= (LONG_BIT-1)
            try:
                ovfcheck(x << y)
            except OverflowError:
                m += 1
            n += 1
        return m
    res = self.meta_interp(f, [1, 1, 0], enable_opts='')
    assert res == f(1, 1, 0)
    res = self.meta_interp(f, [809644098, 16, 0],
                           enable_opts='')
    assert res == f(809644098, 16, 0)
def test_virtual_array_with_nulls(self):
    """A 2-element list rebuilt each iteration, with slot 0 left as None.

    NOTE(review): both branches of the (n >> 3) & 1 test do the same
    thing.  Presumably this is deliberate, to produce a guard/branch
    without changing the data; confirm before simplifying.
    """
    class Foo:
        pass
    myjitdriver = JitDriver(greens = [], reds = ['n', 'node'])
    def f(n):
        node = [None, Foo()]
        while n > 0:
            myjitdriver.can_enter_jit(n=n, node=node)
            myjitdriver.jit_merge_point(n=n, node=node)
            newnode = [None] * 2
            if (n >> 3) & 1:
                newnode[1] = node[1]
            else:
                newnode[1] = node[1]
            node = newnode
            n -= 1
        return 42
    assert self.meta_interp(f, [40]) == 42
def test_raw_malloc_only_chars(self):
    """A raw 1-element LONG buffer malloc'ed and freed every iteration.

    i advances by 1 through buffer[0], so f(10) sums 1..10 == 55.  The
    resop check expects the raw array accesses to remain in the trace.
    """
    mydriver = JitDriver(greens=[], reds='auto')
    def f(n):
        i = 0
        res = 0
        while i < n:
            mydriver.jit_merge_point()
            # this is not virtualized because it's not a buffer of chars
            buffer = lltype.malloc(rffi.LONGP.TO, 1, flavor='raw')
            buffer[0] = i+1
            res += buffer[0]
            i = buffer[0]
            lltype.free(buffer, flavor='raw')
        return res
    assert f(10) == 55
    res = self.meta_interp(f, [10])
    assert res == 55
    self.check_trace_count(1)
    self.check_resops(setarrayitem_raw=2, getarrayitem_raw_i=4)
def test_jitdriver_clone():
    """Check JitDriver.inline() and JitDriver.clone().

    Unconditionally skipped by the py.test.skip call at the top.
    NOTE: the string forms passed to py.test.raises are evaluated against
    this function's local variables, so the local names driver, bar and
    foo must not be renamed.
    """
    py.test.skip("@inline off: see skipped failures in test_warmspot.")
    def bar(): pass
    def foo(): pass
    driver = JitDriver(greens=[], reds=[])
    py.test.raises(AssertionError, "driver.inline(bar)(foo)")
    #
    driver = JitDriver(greens=[], reds='auto')
    py.test.raises(AssertionError, "driver.clone()")
    foo = driver.inline(bar)(foo)
    assert foo._inline_jit_merge_point_ == bar
    #
    driver.foo = 'bar'
    driver2 = driver.clone()
    assert driver is not driver2
    assert driver2.foo == 'bar'
    # mutating the original afterwards must not affect the clone
    driver.foo = 'xxx'
    assert driver2.foo == 'bar'
def test_argument_order_ok(self):
    """jit_merge_point accepts greens given an int, an instance and a float."""
    class A(object):
        pass

    driver = JitDriver(greens=['i1', 'r1', 'f1'], reds=[])
    driver.jit_merge_point(i1=42, r1=A(), f1=3.5)
def test_argument_order_accept_r_uint(self):
    """jit_merge_point accepts an r_uint value for a green argument.

    Regression: this used to fail on 64-bit, because r_uint == r_ulonglong.
    """
    driver = JitDriver(greens=['i1'], reds=[])
    unsigned_value = r_uint(42)
    driver.jit_merge_point(i1=unsigned_value)