def get_class(name):
    """Resolve a dotted ``package.Class`` name into the object it refers to.

    If ``name`` contains ``'cuda'`` (checked first) or ``'llvm'``, the calling
    test is skipped through ``pytest.skip`` when ``ek.has_backend`` reports
    that the matching JIT backend is unavailable.

    Before returning, the ``LoopRecord`` JIT flag is set to ``True``.
    """
    # Select the backend that has to be present; 'cuda' takes precedence
    # over 'llvm' when both substrings occur in the name.
    if 'cuda' in name:
        required_backend = ek.JitBackend.CUDA
        skip_reason = 'CUDA mode is unsupported'
    elif 'llvm' in name:
        required_backend = ek.JitBackend.LLVM
        skip_reason = 'LLVM mode is unsupported'
    else:
        required_backend = None
        skip_reason = None

    if required_backend is not None and not ek.has_backend(required_backend):
        pytest.skip(skip_reason)

    parts = name.split('.')

    # __import__ on a dotted path returns the top-level package, so every
    # component after the first is reached by attribute lookup.
    resolved = __import__(".".join(parts[:-1]))
    for attribute in parts[1:]:
        resolved = getattr(resolved, attribute)

    ek.set_flag(ek.JitFlag.LoopRecord, True)
    return resolved
def test05_side_effect_noloop(pkg):
    """Run a loop containing a ``scatter_reduce`` with ``LoopRecord`` disabled.

    Two integer counters and a float buffer of 10 entries each start at zero.
    The loop runs while the first counter is below 10, and the final counter,
    accumulator and buffer contents are checked afterwards.
    """
    p = get_class(pkg)

    counter = ek.zero(p.Int, 10)
    total = ek.zero(p.Int, 10)
    target_buf = ek.zero(p.Float, 10)

    # get_class() switched LoopRecord on; this test wants it off.
    ek.set_flag(ek.JitFlag.LoopRecord, False)

    loop = p.Loop("MyLoop", lambda: (counter, total))
    while loop(counter < 10):
        total += counter
        counter += 1
        # Side effect: every iteration adds the updated counter into slot 0.
        ek.scatter_reduce(
            op=ek.ReduceOp.Add,
            target=target_buf,
            value=p.Float(counter),
            index=0,
        )

    assert counter == p.Int([10] * 10)
    # Slot 0 holds the accumulated sum; the remaining 9 entries stay at zero.
    assert target_buf == p.Float(550, *([0] * 9))
    assert total == p.Int([45] * 10)
def test01_record_loop(pkg):
    """Check one loop under three ``LoopRecord`` / ``LoopOptimize`` settings.

    The settings are (record off, optimize off), (record on, optimize off)
    and (record on, optimize on). Each setting is run twice; only the first
    repetition calls ``ek.schedule`` on the results before comparing them.
    """
    p = get_class(pkg)

    # (LoopRecord, LoopOptimize) for each of the three passes.
    flag_settings = (
        (False, False),
        (True, False),
        (True, True),
    )

    for record, optimize in flag_settings:
        ek.set_flag(ek.JitFlag.LoopRecord, record)
        ek.set_flag(ek.JitFlag.LoopOptimize, optimize)

        for repetition in range(2):
            index = ek.arange(p.Int, 0, 10)
            acc = ek.zero(p.Float, 1)
            steps = p.Float(1)

            loop = p.Loop("MyLoop", lambda: (index, acc, steps))
            while loop(index < 5):
                acc += p.Float(index)
                index += 1
                steps = steps + 1

            if repetition == 0:
                ek.schedule(index, acc, steps)

            assert steps == p.Int(6, 5, 4, 3, 2, 1, 1, 1, 1, 1)
            assert acc == p.Int(10, 10, 9, 7, 4, 0, 0, 0, 0, 0)
            assert index == p.Int(5, 5, 5, 5, 5, 5, 6, 7, 8, 9)
def no_record():
    """Generator that turns ``JitFlag.LoopRecord`` off around a single yield.

    The current value of the flag is read, the flag is set to ``False``, and
    control is handed back to the caller via ``yield None``. When the
    generator is resumed, closed, or has an exception thrown into it, the
    saved value is written back.

    The restore sits in a ``finally`` clause so that a failure in the code
    running while the generator is suspended cannot leave ``LoopRecord``
    disabled for whatever runs next.

    NOTE(review): this has the shape expected by
    ``contextlib.contextmanager``, but no decorator is visible on this
    block -- confirm how callers consume it.

    Yields:
        None
    """
    value_before = ek.flag(ek.JitFlag.LoopRecord)
    ek.set_flag(ek.JitFlag.LoopRecord, False)
    try:
        yield None
    finally:
        # Runs on normal resumption, on an exception thrown in at the yield,
        # and on generator close().
        ek.set_flag(ek.JitFlag.LoopRecord, value_before)
def __exit__(self, type, value, traceback):
    """Set the ``ADEagerForward`` JIT flag to ``False`` on leaving the block.

    The exception arguments are ignored and nothing is returned, so an
    exception raised inside the ``with`` body is not suppressed.
    """
    eager_forward = ek.JitFlag.ADEagerForward
    ek.set_flag(eager_forward, False)
def __enter__(self):
    """Set the ``ADEagerForward`` JIT flag to ``True`` on entering the block.

    Nothing is returned, so ``with ... as x`` would bind ``x`` to ``None``.
    """
    eager_forward = ek.JitFlag.ADEagerForward
    ek.set_flag(eager_forward, True)
def teardown_function(function):
    """Set the ``LoopRecord`` JIT flag to ``False``.

    The name and signature match pytest's module-level xunit-style teardown
    hook; the ``function`` argument is not used.
    """
    loop_record = ek.JitFlag.LoopRecord
    ek.set_flag(loop_record, False)
def setup_function(function):
    """Set the ``LoopRecord`` JIT flag to ``True``.

    The name and signature match pytest's module-level xunit-style setup
    hook; the ``function`` argument is not used.
    """
    loop_record = ek.JitFlag.LoopRecord
    ek.set_flag(loop_record, True)