def test_location(self):
    """Every recorded merge-point location should carry the green value.

    The driver's only green, ``n``, never changes inside ``f`` (it stays
    123), so each entry of ``get_stats().locations`` is expected to equal
    ``(0, 123)`` and at least four entries must have been recorded.
    """
    def get_printable_location(n):
        return 'GREEN IS %d.' % n

    myjitdriver = JitDriver(greens=['n'], reds=['m'],
                            get_printable_location=get_printable_location)

    def f(n, m):
        while m > 0:
            myjitdriver.can_enter_jit(n=n, m=m)
            myjitdriver.jit_merge_point(n=n, m=m)
            m -= 1

    self.meta_interp(f, [123, 10])
    recorded = get_stats().locations
    assert len(recorded) >= 4
    for location in recorded:
        assert location == (0, 123)
def test_loop(self):
    """A simple accumulating loop compiles to a single small loop.

    ``f(6, 7)`` adds 6 seven times.  Besides the result, the test pins the
    operation counts of the compiled loop.  When ``self.basic`` is set, it
    also checks that there is exactly one ``guard_true`` and that its fail
    args are three ``BoxInt`` boxes.
    """
    myjitdriver = JitDriver(greens=[], reds=['x', 'y', 'res'])

    def f(x, y):
        res = 0
        while y > 0:
            myjitdriver.can_enter_jit(x=x, y=y, res=res)
            myjitdriver.jit_merge_point(x=x, y=y, res=res)
            res += x
            y -= 1
        return res

    res = self.meta_interp(f, [6, 7])
    assert res == 42
    self.check_loop_count(1)
    self.check_loops({'guard_true': 1, 'int_add': 1, 'int_sub': 1,
                      'int_gt': 1, 'jump': 1})
    if not self.basic:
        return
    guards = [op for op in get_stats().loops[0]._all_operations()
              if op.getopname() == 'guard_true']
    for guard in guards:
        failargs = guard.fail_args
        assert len(failargs) == 3
        for box in failargs:
            assert isinstance(box, history.BoxInt)
    assert len(guards) == 1
def check_loop_count(self, count):
    """Assert that ``get_stats().compiled_count`` equals *count*.

    NB. This is a hack; use check_tree_loop_count() or check_enter_count()
    for the real thing.  The counter counts every bridge as 1, in addition
    to every loop.  It does not count the entry bridges from the
    interpreter at all, even though those are TreeLoops as well.
    """
    compiled = get_stats().compiled_count
    assert compiled == count
def test_inline(self):
    """``loop1`` must always be inlined into the traces of ``loop2``.

    This is not an example of reasonable code: loop1() is unrolled 'n/m'
    times, where n and m are given as red arguments.
    """
    myjitdriver1 = JitDriver(greens=[], reds=["n", "m"],
                             get_printable_location=getloc1)
    myjitdriver2 = JitDriver(greens=["g"], reds=["r"],
                             get_printable_location=getloc2)

    def loop1(n, m):
        while n > 0:
            if n > 1000:
                myjitdriver1.can_enter_jit(n=n, m=m)
            myjitdriver1.jit_merge_point(n=n, m=m)
            n -= m
        return n

    def loop2(g, r):
        set_param(None, "function_threshold", 0)
        while r > 0:
            myjitdriver2.can_enter_jit(g=g, r=r)
            myjitdriver2.jit_merge_point(g=g, r=r)
            r += loop1(r, g) - 1
        return r

    res = self.meta_interp(loop2, [4, 40], repeat=7, inline=True)
    assert res == loop2(4, 40)
    # We expect no loop at all for 'loop1': it should always be inlined.
    # We do, however, get several versions of 'loop2'.  All of them contain
    # at least one int_add, while there are no int_add's in 'loop1'.
    self.check_tree_loop_count(5)
    for compiled_loop in get_stats().loops:
        int_adds = compiled_loop.summary()["int_add"]
        assert int_adds >= 1
def test_list_pass_around(self):
    """Quasi-immutable list read through a helper that also sees a mutable list.

    Currently skipped.  Because ``g`` is also called with a freshly built
    list, ``lst`` inside ``g`` is annotated as modified.  The expectations
    below therefore describe the desired, not the current, result.
    """
    py.test.skip("think about a way to fix it")
    myjitdriver = JitDriver(greens=['foo'], reds=['x', 'total'])

    class Foo:
        _immutable_fields_ = ['lst?[*]']

        def __init__(self, lst):
            self.lst = lst

    def g(lst):
        # here, 'lst' is statically annotated as a "modified" list,
        # so the following doesn't generate a getarrayitem_gc_pure...
        return lst[1]

    def f(a, x):
        lst1 = [0, 0]
        g(lst1)
        lst1[1] = a
        foo = Foo(lst1)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(foo=foo, x=x, total=total)
            # read a quasi-immutable field out of a Constant
            total += g(foo.lst)
            x -= 1
        return total

    res = self.meta_interp(f, [100, 7])
    assert res == 700
    self.check_loops(guard_not_invalidated=2, getfield_gc=0,
                     getarrayitem_gc=0, getarrayitem_gc_pure=0,
                     everywhere=True)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        deps = compiled_loop.quasi_immutable_deps
        assert len(deps) == 1
        assert isinstance(deps.keys()[0], QuasiImmut)
def test_list_simple_1(self):
    """Item read from a quasi-immutable list reachable from a green.

    ``foo`` is a green, so the compiled loops are expected to contain no
    ``getfield_gc``, ``getarrayitem_gc`` or ``getarrayitem_gc_pure``.
    Each loop should record exactly one ``QuasiImmut`` dependency.
    """
    myjitdriver = JitDriver(greens=["foo"], reds=["x", "total"])

    class Foo:
        _immutable_fields_ = ["lst?[*]"]

        def __init__(self, lst):
            self.lst = lst

    def f(a, x):
        lst1 = [0, 0]
        lst1[1] = a
        foo = Foo(lst1)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(foo=foo, x=x, total=total)
            # read a quasi-immutable field out of a Constant
            total += foo.lst[1]
            x -= 1
        return total

    res = self.meta_interp(f, [100, 7])
    assert res == 700
    self.check_resops(getarrayitem_gc_pure=0, guard_not_invalidated=2,
                      getarrayitem_gc=0, getfield_gc=0)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        deps = compiled_loop.quasi_immutable_deps
        assert len(deps) == 1
        assert isinstance(deps.keys()[0], QuasiImmut)
def test_nonopt_1(self):
    """Quasi-immutable field read through a red (non-constant) object.

    ``lst[x].a`` is read from an object selected by a red index.  No
    ``guard_not_invalidated`` is therefore expected, and no compiled loop
    should carry quasi-immutable dependencies.
    """
    myjitdriver = JitDriver(greens=[], reds=['x', 'total', 'lst'])

    class Foo:
        _immutable_fields_ = ['a?']

        def __init__(self, a):
            self.a = a

    def setup(x):
        return [Foo(100 + i) for i in range(x)]

    def f(a, x):
        lst = setup(x)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(lst=lst, x=x, total=total)
            # read a quasi-immutable field out of a variable
            x -= 1
            total += lst[x].a
        return total

    # Sanity-check the plain Python result first.
    assert f(100, 7) == 721
    res = self.meta_interp(f, [100, 7])
    assert res == 721
    self.check_loops(guard_not_invalidated=0, getfield_gc=1)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        assert compiled_loop.quasi_immutable_deps is None
def test_simple_1(self):
    """Quasi-immutable field read from a green constant object.

    ``foo`` is green, so ``foo.a`` should need no ``getfield_gc`` in any
    loop.  Each loop should record exactly one ``QuasiImmut`` dependency.
    """
    myjitdriver = JitDriver(greens=['foo'], reds=['x', 'total'])

    class Foo:
        _immutable_fields_ = ['a?']

        def __init__(self, a):
            self.a = a

    def f(a, x):
        foo = Foo(a)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(foo=foo, x=x, total=total)
            # read a quasi-immutable field out of a Constant
            total += foo.a
            x -= 1
        return total

    res = self.meta_interp(f, [100, 7])
    assert res == 700
    self.check_loops(guard_not_invalidated=2, getfield_gc=0,
                     everywhere=True)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        deps = compiled_loop.quasi_immutable_deps
        assert len(deps) == 1
        assert isinstance(deps.keys()[0], QuasiImmut)
def test_loop(self):
    """Compile ``res += x`` repeated ``y`` times and inspect the loop.

    Expected: result 42 for ``f(6, 7)``, one compiled loop with the listed
    operation counts, and (when ``self.basic``) a single ``guard_true``
    whose three fail args are all ``BoxInt``.
    """
    myjitdriver = JitDriver(greens=[], reds=['x', 'y', 'res'])

    def f(x, y):
        res = 0
        while y > 0:
            myjitdriver.can_enter_jit(x=x, y=y, res=res)
            myjitdriver.jit_merge_point(x=x, y=y, res=res)
            res += x
            y -= 1
        return res

    res = self.meta_interp(f, [6, 7])
    assert res == 42
    self.check_loop_count(1)
    expected_ops = {'guard_true': 1, 'int_add': 1, 'int_sub': 1,
                    'int_gt': 1, 'jump': 1}
    self.check_loops(expected_ops)
    if self.basic:
        guard_count = 0
        for operation in get_stats().loops[0]._all_operations():
            if operation.getopname() != 'guard_true':
                continue
            failargs = operation.fail_args
            assert len(failargs) == 3
            for box in failargs:
                assert isinstance(box, history.BoxInt)
            guard_count += 1
        assert guard_count == 1
def test_simple_1(self):
    """Quasi-immutable field read from a green constant object.

    With ``foo`` green, ``foo.a`` is expected to compile without any
    ``getfield_gc``, guarded by ``guard_not_invalidated``.  Each loop must
    depend on exactly one ``QuasiImmut``.
    """
    myjitdriver = JitDriver(greens=['foo'], reds=['x', 'total'])

    class Foo:
        _immutable_fields_ = ['a?']

        def __init__(self, a):
            self.a = a

    def f(a, x):
        foo = Foo(a)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(foo=foo, x=x, total=total)
            # read a quasi-immutable field out of a Constant
            total += foo.a
            x -= 1
        return total

    res = self.meta_interp(f, [100, 7])
    assert res == 700
    self.check_resops(guard_not_invalidated=2, getfield_gc=0)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        deps = compiled_loop.quasi_immutable_deps
        assert len(deps) == 1
        assert isinstance(deps.keys()[0], QuasiImmut)
def test_nonopt_1(self):
    """Quasi-immutable field read through a red (non-constant) object.

    The object is chosen by the red index ``x``.  No
    ``guard_not_invalidated`` is therefore expected, and no compiled loop
    should carry quasi-immutable dependencies.
    """
    myjitdriver = JitDriver(greens=[], reds=['x', 'total', 'lst'])

    class Foo:
        _immutable_fields_ = ['a?']

        def __init__(self, a):
            self.a = a

    def setup(x):
        return [Foo(100 + i) for i in range(x)]

    def f(a, x):
        lst = setup(x)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(lst=lst, x=x, total=total)
            # read a quasi-immutable field out of a variable
            x -= 1
            total += lst[x].a
        return total

    # Sanity-check the plain Python result first.
    assert f(100, 7) == 721
    res = self.meta_interp(f, [100, 7])
    assert res == 721
    self.check_resops(guard_not_invalidated=0, getfield_gc=3)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        assert compiled_loop.quasi_immutable_deps is None
def check_max_trace_length(self, length):
    """Assert that no compiled loop, or guard suboperation list, is too long.

    Each loop's ``operations`` may have at most ``length + 5`` entries.
    The same bound applies to the ``_debug_suboperations`` of any guard
    whose ``descr`` has that attribute.
    """
    # The slack of 5 exists because we only check once per metainterp
    # bytecode.
    limit = length + 5
    for loop in get_stats().loops:
        assert len(loop.operations) <= limit
        for op in loop.operations:
            if not op.is_guard():
                continue
            descr = op.descr
            if hasattr(descr, '_debug_suboperations'):
                assert len(descr._debug_suboperations) <= limit
def test_location(self):
    """Every recorded merge-point location should carry the green value.

    ``n`` is the only green and stays 123 for the whole run.  At least four
    locations must be recorded, and each must equal ``(0, 0, 123)``.
    """
    def get_printable_location(n):
        return 'GREEN IS %d.' % n

    myjitdriver = JitDriver(greens=['n'], reds=['m'],
                            get_printable_location=get_printable_location)

    def f(n, m):
        while m > 0:
            myjitdriver.can_enter_jit(n=n, m=m)
            myjitdriver.jit_merge_point(n=n, m=m)
            m -= 1

    self.meta_interp(f, [123, 10])
    recorded = get_stats().locations
    assert len(recorded) >= 4
    for location in recorded:
        assert location == (0, 0, 123)
def test_list_length_1(self):
    """Item and length reads of a quasi-immutable list on a late constant.

    ``foo`` is laundered through a fresh ``A`` instance on every iteration,
    so it only becomes a constant after optimization.  The compiled loops
    are still expected to contain no ``getfield_gc``, ``getarrayitem_gc*``
    or ``arraylen_gc``.
    """
    myjitdriver = JitDriver(greens=['foo'], reds=['x', 'total'])

    class Foo:
        _immutable_fields_ = ['lst?[*]']

        def __init__(self, lst):
            self.lst = lst

    class A:
        pass

    def f(a, x):
        lst1 = [0, 0]
        lst1[1] = a
        foo = Foo(lst1)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(foo=foo, x=x, total=total)
            # make it a Constant after optimization only
            holder = A()
            holder.foo = foo
            foo = holder.foo
            # read a quasi-immutable field out of it
            total += foo.lst[1]
            # also read the length
            total += len(foo.lst)
            x -= 1
        return total

    res = self.meta_interp(f, [100, 7])
    assert res == 714
    self.check_loops(guard_not_invalidated=2, getfield_gc=0,
                     getarrayitem_gc=0, getarrayitem_gc_pure=0,
                     arraylen_gc=0, everywhere=True)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        deps = compiled_loop.quasi_immutable_deps
        assert len(deps) == 1
        assert isinstance(deps.keys()[0], QuasiImmut)
def test_interp_single_loop(self):
    """A tiny bytecode interpreter whose 'c' opcode jumps back to the start.

    One trace is expected.  ``int_eq`` is constant-folded away, and the
    number of ``guard_true`` depends on whether 'unroll' is enabled.
    When ``self.basic`` is set, each ``guard_true`` must keep exactly two
    ``BoxInt`` fail args (x and y).
    """
    myjitdriver = JitDriver(greens=['i'], reds=['x', 'y'])
    bytecode = "abcd"

    def f(x, y):
        i = 0
        while i < len(bytecode):
            myjitdriver.jit_merge_point(i=i, x=x, y=y)
            op = bytecode[i]
            if op == 'a':
                x += y
            elif op == 'b':
                y -= 1
            elif op == 'c':
                if y:
                    i = 0
                    myjitdriver.can_enter_jit(i=i, x=x, y=y)
                    continue
            else:
                x += 1
            i += 1
        return x

    res = self.meta_interp(f, [5, 8])
    assert res == 42
    self.check_trace_count(1)
    expected_guards = 2 if 'unroll' in self.enable_opts else 1
    # the 'int_eq' and following 'guard' should be constant-folded
    self.check_resops(int_eq=0, guard_true=expected_guards, guard_false=0)
    if not self.basic:
        return
    guards = [op for op in get_stats().loops[0]._all_operations()
              if op.getopname() == 'guard_true']
    for guard in guards:
        liveboxes = guard.getfailargs()
        assert len(liveboxes) == 2  # x, y (in some order)
        assert isinstance(liveboxes[0], history.BoxInt)
        assert isinstance(liveboxes[1], history.BoxInt)
    assert len(guards) == expected_guards
def test_list_length_1(self):
    """Item and length reads of a quasi-immutable list on a late constant.

    ``foo`` passes through a fresh ``A`` instance each iteration, so it is
    only constant after optimization.  No ``getfield_gc``,
    ``getarrayitem_gc*`` or ``arraylen_gc`` should remain in the loops.
    """
    myjitdriver = JitDriver(greens=['foo'], reds=['x', 'total'])

    class Foo:
        _immutable_fields_ = ['lst?[*]']

        def __init__(self, lst):
            self.lst = lst

    class A:
        pass

    def f(a, x):
        lst1 = [0, 0]
        lst1[1] = a
        foo = Foo(lst1)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(foo=foo, x=x, total=total)
            # make it a Constant after optimization only
            holder = A()
            holder.foo = foo
            foo = holder.foo
            # read a quasi-immutable field out of it
            total += foo.lst[1]
            # also read the length
            total += len(foo.lst)
            x -= 1
        return total

    res = self.meta_interp(f, [100, 7])
    assert res == 714
    self.check_resops(getarrayitem_gc_pure=0, guard_not_invalidated=2,
                      arraylen_gc=0, getarrayitem_gc=0, getfield_gc=0)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        deps = compiled_loop.quasi_immutable_deps
        assert len(deps) == 1
        assert isinstance(deps.keys()[0], QuasiImmut)
def test_interp_single_loop(self):
    """Trace a four-opcode interpreter loop ("abcd") and inspect its guards.

    'a' adds y to x, 'b' decrements y, 'c' restarts at 0 while y is
    non-zero, and anything else adds 1 to x.  ``f(5, 8)`` must return 42
    with exactly one trace compiled.
    """
    myjitdriver = JitDriver(greens=['i'], reds=['x', 'y'])
    bytecode = "abcd"

    def f(x, y):
        i = 0
        while i < len(bytecode):
            myjitdriver.jit_merge_point(i=i, x=x, y=y)
            op = bytecode[i]
            if op == 'a':
                x += y
            elif op == 'b':
                y -= 1
            elif op == 'c':
                if y:
                    i = 0
                    myjitdriver.can_enter_jit(i=i, x=x, y=y)
                    continue
            else:
                x += 1
            i += 1
        return x

    res = self.meta_interp(f, [5, 8])
    assert res == 42
    self.check_trace_count(1)
    unrolled = 'unroll' in self.enable_opts
    # the 'int_eq' and following 'guard' should be constant-folded
    if unrolled:
        self.check_resops(int_eq=0, guard_true=2, guard_false=0)
    else:
        self.check_resops(int_eq=0, guard_true=1, guard_false=0)
    if self.basic:
        seen = 0
        for operation in get_stats().loops[0]._all_operations():
            if operation.getopname() != 'guard_true':
                continue
            liveboxes = operation.getfailargs()
            assert len(liveboxes) == 2  # x, y (in some order)
            assert isinstance(liveboxes[0], history.BoxInt)
            assert isinstance(liveboxes[1], history.BoxInt)
            seen += 1
        if unrolled:
            assert seen == 2
        else:
            assert seen == 1
def test_list_pass_around(self):
    """Quasi-immutable list read through a helper that also sees a mutable list.

    Currently skipped: ``g`` is also called with a plain list, so its
    argument is annotated as modified.  The checks below describe the
    desired outcome.
    """
    py.test.skip("think about a way to fix it")
    myjitdriver = JitDriver(greens=['foo'], reds=['x', 'total'])

    class Foo:
        _immutable_fields_ = ['lst?[*]']

        def __init__(self, lst):
            self.lst = lst

    def g(lst):
        # here, 'lst' is statically annotated as a "modified" list,
        # so the following doesn't generate a getarrayitem_gc_pure...
        return lst[1]

    def f(a, x):
        lst1 = [0, 0]
        g(lst1)
        lst1[1] = a
        foo = Foo(lst1)
        total = 0
        while x > 0:
            myjitdriver.jit_merge_point(foo=foo, x=x, total=total)
            # read a quasi-immutable field out of a Constant
            total += g(foo.lst)
            x -= 1
        return total

    res = self.meta_interp(f, [100, 7])
    assert res == 700
    self.check_resops(guard_not_invalidated=2, getfield_gc=0,
                      getarrayitem_gc=0, getarrayitem_gc_pure=0)
    # This local import makes get_stats a local name for the whole method.
    from pypy.jit.metainterp.warmspot import get_stats
    for compiled_loop in get_stats().loops:
        deps = compiled_loop.quasi_immutable_deps
        assert len(deps) == 1
        assert isinstance(deps.keys()[0], QuasiImmut)
def test_interp_single_loop(self):
    """Trace a four-opcode interpreter loop and check its single guard.

    ``f(5, 8)`` must return 42 with one compiled loop.  The loop should
    contain no ``int_eq`` or ``guard_false`` and exactly one
    ``guard_true``.  When ``self.basic`` is set, that guard must keep two
    ``BoxInt`` fail args.
    """
    myjitdriver = JitDriver(greens=["i"], reds=["x", "y"])
    bytecode = "abcd"

    def f(x, y):
        i = 0
        while i < len(bytecode):
            myjitdriver.jit_merge_point(i=i, x=x, y=y)
            op = bytecode[i]
            if op == "a":
                x += y
            elif op == "b":
                y -= 1
            elif op == "c":
                if y:
                    i = 0
                    myjitdriver.can_enter_jit(i=i, x=x, y=y)
                    continue
            else:
                x += 1
            i += 1
        return x

    res = self.meta_interp(f, [5, 8])
    assert res == 42
    self.check_loop_count(1)
    # the 'int_eq' and following 'guard' should be constant-folded
    self.check_loops(int_eq=0, guard_true=1, guard_false=0)
    if not self.basic:
        return
    guards = [op for op in get_stats().loops[0]._all_operations()
              if op.getopname() == "guard_true"]
    for guard in guards:
        liveboxes = guard.fail_args
        assert len(liveboxes) == 2  # x, y (in some order)
        assert isinstance(liveboxes[0], history.BoxInt)
        assert isinstance(liveboxes[1], history.BoxInt)
    assert len(guards) == 1
def test_call_assembler_keep_alive(self):
    """With ``loop_longevity=4``, loops reached via call_assembler stay alive.

    ``f`` builds loops for ``h(u)`` and ``g(u)``, where ``g(u)`` reaches
    ``h(u)`` through call_assembler.  It then creates extra loops for
    ``g(u+1)``..``g(u+4)`` while keeping ``g(u)`` in use.  Afterwards some
    jitcell tokens must have been freed, but loop number 0 (``h``) must
    still be alive.
    """
    myjitdriver1 = JitDriver(greens=['m'], reds=['n'])
    myjitdriver2 = JitDriver(greens=['m'], reds=['n', 'rec'])

    def h(m, n):
        while True:
            myjitdriver1.can_enter_jit(n=n, m=m)
            myjitdriver1.jit_merge_point(n=n, m=m)
            n = n >> 1
            if n == 0:
                return 21

    def g(m, rec):
        n = 5
        while n > 0:
            myjitdriver2.can_enter_jit(n=n, m=m, rec=rec)
            myjitdriver2.jit_merge_point(n=n, m=m, rec=rec)
            if rec:
                h(m, rec)
            n = n - 1
        return 21

    def f(u):
        for i in range(8):
            h(u, 32)  # make a loop and an exit bridge for h(u)
        g(u, 8)       # make a loop for g(u) with a call_assembler
        g(u, 0); g(u+1, 0)     # \
        g(u, 0); g(u+2, 0)     #  \  make more loops for g(u+1) to g(u+4),
        g(u, 0); g(u+3, 0)     #  /  but keeps g(u) alive
        g(u, 0); g(u+4, 0)     # /
        g(u, 8)       # call g(u) again, with its call_assembler to h(u)
        return 42

    res = self.meta_interp(f, [1], loop_longevity=4, inline=True)
    assert res == 42
    self.check_jitcell_token_count(6)
    live_tokens = [wref() for wref in get_stats().jitcell_token_wrefs]
    # Some loops have been freed (their weakrefs are now dead).
    assert None in live_tokens
    # Loop number 0, h(), has not been freed.
    surviving_numbers = [tok.number for tok in live_tokens if tok]
    assert 0 in surviving_numbers
def test_inline(self):
    """``loop1`` must always be inlined into the traces of ``loop2``.

    This is not an example of reasonable code: loop1() is unrolled 'n/m'
    times, where n and m are given as red arguments.
    """
    myjitdriver1 = JitDriver(greens=[], reds=['n', 'm'],
                             get_printable_location=getloc1)
    myjitdriver2 = JitDriver(greens=['g'], reds=['r'],
                             get_printable_location=getloc2)

    def loop1(n, m):
        while n > 0:
            if n > 1000:
                myjitdriver1.can_enter_jit(n=n, m=m)
            myjitdriver1.jit_merge_point(n=n, m=m)
            n -= m
        return n

    def loop2(g, r):
        set_param(None, 'function_threshold', 0)
        while r > 0:
            myjitdriver2.can_enter_jit(g=g, r=r)
            myjitdriver2.jit_merge_point(g=g, r=r)
            r += loop1(r, g) - 1
        return r

    res = self.meta_interp(loop2, [4, 40], repeat=7, inline=True)
    assert res == loop2(4, 40)
    # We expect no loop at all for 'loop1': it should always be inlined.
    # We do, however, get several versions of 'loop2'.  All of them contain
    # at least one int_add, while there are no int_add's in 'loop1'.
    self.check_jitcell_token_count(1)
    for compiled_loop in get_stats().loops:
        int_adds = compiled_loop.summary()['int_add']
        assert int_adds >= 1
def test_multiple_jits_trace_too_long(self):
    """Two nested JIT drivers run under a tiny ``trace_limit``.

    ``loop1`` calls ``loop2`` on every iteration, and ``loop2`` calls an
    ``unroll_safe`` helper.  With ``trace_limit=6`` the result must still
    be 10, and the stats must record exactly two aborts, both with key
    ``None``.
    """
    myjitdriver1 = JitDriver(greens=["n"], reds=["i", "box"])
    myjitdriver2 = JitDriver(greens=["n"], reds=["i"])

    class IntBox(object):
        def __init__(self, val):
            self.val = val

    def loop1(n):
        i = 0
        box = IntBox(10)
        while i < n:
            myjitdriver1.can_enter_jit(n=n, i=i, box=box)
            myjitdriver1.jit_merge_point(n=n, i=i, box=box)
            i += 1
            loop2(box)
        return i

    def loop2(n):
        i = 0
        f(10)
        while i < n.val:
            myjitdriver2.can_enter_jit(n=n, i=i)
            myjitdriver2.jit_merge_point(n=n, i=i)
            i += 1

    @unroll_safe
    def f(n):
        i = 0
        while i < n:
            i += 1

    res = self.meta_interp(loop1, [10], inline=True, trace_limit=6)
    assert res == 10
    assert get_stats().aborted_keys == [None, None]
def check_trace_count(self, count):
    """Assert that the number of compiled traces equals *count*.

    Formerly named ``check_loop_count``; reads
    ``get_stats().compiled_count``.
    """
    traces = get_stats().compiled_count
    assert traces == count
def do_debug_merge_point(cpu, box1):
    """Record the location string carried by *box1* in the JIT stats.

    *cpu* is accepted for signature compatibility but is not used here.
    """
    from pypy.jit.metainterp.warmspot import get_stats
    # Read the location first, then fetch the stats (same order as before).
    location = box1._get_str()
    stats = get_stats()
    stats.add_merge_point_location(location)
def check_history(self, expected=None, **isns):
    """Forward to ``get_stats().check_history``; use after ``meta_interp``.

    *expected* is passed positionally and the keyword counts in *isns*
    are passed through unchanged.  Returns None.
    """
    stats = get_stats()
    stats.check_history(expected, **isns)
def check_aborted_count(self, count):
    """Assert that ``get_stats().aborted_count`` equals *count*."""
    aborted = get_stats().aborted_count
    assert aborted == count
def check_enter_count_at_most(self, count):
    """Assert that ``get_stats().enter_count`` does not exceed *count*."""
    entered = get_stats().enter_count
    assert entered <= count
def check_target_token_count(self, count):
    """Assert that the total number of target tokens equals *count*.

    The total is ``len(t.target_tokens)`` summed over every jitcell token
    returned by ``get_stats().get_all_jitcell_tokens()``.
    """
    total = 0
    for jitcell_token in get_stats().get_all_jitcell_tokens():
        total += len(jitcell_token.target_tokens)
    assert total == count
def check_enter_count(self, count):
    """Assert that ``get_stats().enter_count`` equals *count* exactly."""
    entered = get_stats().enter_count
    assert entered == count
def check_loop_count_at_most(self, count):
    """Assert that ``get_stats().compiled_count`` is at most *count*."""
    compiled = get_stats().compiled_count
    assert compiled <= count
def check_tree_loop_count(self, count):
    """Assert that ``get_stats().loops`` contains exactly *count* loops."""
    loops = get_stats().loops
    assert len(loops) == count
def check_jumps(self, maxcount):
    """Currently a no-op: the jump-count check is disabled.

    *maxcount* is accepted but ignored.  The intended check is kept below
    as a comment until it is re-enabled.
    """
    # FIXME: re-enable once exec_jumps is reliable again:
    #     assert get_stats().exec_jumps <= maxcount
    return
def check_loops(self, expected=None, everywhere=False, **check):
    """Forward to ``get_stats().check_loops`` with the same arguments.

    *expected*, *everywhere* and the per-operation counts in *check* are
    all passed as keyword arguments.  Returns None.
    """
    stats = get_stats()
    stats.check_loops(expected=expected, everywhere=everywhere, **check)
def check_trace_count_at_most(self, count):
    """Assert that at most *count* traces were compiled.

    Reads ``get_stats().compiled_count``.
    """
    traces = get_stats().compiled_count
    assert traces <= count
def check_jitcell_token_count(self, count):
    """Assert that exactly *count* jitcell token weakrefs were recorded.

    Formerly named ``check_tree_loop_count``; reads
    ``get_stats().jitcell_token_wrefs``.
    """
    wrefs = get_stats().jitcell_token_wrefs
    assert len(wrefs) == count
def check_simple_loop(self, expected=None, **check):
    """Forward to ``get_stats().check_simple_loop`` (keyword arguments).

    Returns None.
    """
    stats = get_stats()
    stats.check_simple_loop(expected=expected, **check)
def check_aborted_count_at_least(self, count):
    """Assert that ``get_stats().aborted_count`` is at least *count*."""
    aborted = get_stats().aborted_count
    assert aborted >= count
def check_target_token_count(self, count):
    """Assert that all jitcell tokens together own *count* target tokens."""
    jitcell_tokens = get_stats().get_all_jitcell_tokens()
    total = sum(len(tok.target_tokens) for tok in jitcell_tokens)
    assert total == count
def check_max_trace_length(self, length):
    """Assert that traces and guard suboperation lists stay within a bound.

    The bound is ``length + 5``.  It applies to each loop's
    ``operations`` and to ``_debug_suboperations`` on any guard whose
    ``getdescr()`` has that attribute.
    """
    # The slack of 5 exists because we only check once per metainterp
    # bytecode.
    limit = length + 5
    for loop in get_stats().loops:
        assert len(loop.operations) <= limit
        for op in loop.operations:
            if not op.is_guard():
                continue
            descr = op.getdescr()
            if hasattr(descr, '_debug_suboperations'):
                assert len(descr._debug_suboperations) <= limit
def check_loops(self, expected=None, **check):
    """Forward to ``get_stats().check_loops`` using keyword arguments.

    Returns None.
    """
    stats = get_stats()
    stats.check_loops(expected=expected, **check)
def check_resops(self, expected=None, **check):
    """Forward to ``get_stats().check_resops`` using keyword arguments.

    Returns None.
    """
    stats = get_stats()
    stats.check_resops(expected=expected, **check)
def check_jumps(self, maxcount):
    """Assert that ``get_stats().exec_jumps`` is at most *maxcount*."""
    jumps = get_stats().exec_jumps
    assert jumps <= maxcount