Exemplo n.º 1
0
def get_class(name):
    """Look up the object named by a dotted ``package.module.Class`` string.

    If the name mentions ``cuda`` (or, failing that, ``llvm``) and the
    matching JIT backend is not available, the calling test is skipped via
    ``pytest.skip``.  After a successful lookup the ``LoopRecord`` JIT flag
    is switched on before the resolved object is returned.
    """
    wants_cuda = 'cuda' in name
    wants_llvm = not wants_cuda and 'llvm' in name

    if wants_cuda and not ek.has_backend(ek.JitBackend.CUDA):
        pytest.skip('CUDA mode is unsupported')
    if wants_llvm and not ek.has_backend(ek.JitBackend.LLVM):
        pytest.skip('LLVM mode is unsupported')

    parts = name.split('.')
    module_path = ".".join(parts[:-1])

    # __import__ hands back the top-level package, so every component after
    # the first one is reached by attribute access.
    resolved = __import__(module_path)
    for attribute in parts[1:]:
        resolved = getattr(resolved, attribute)

    ek.set_flag(ek.JitFlag.LoopRecord, True)
    return resolved
Exemplo n.º 2
0
def test05_side_effect_noloop(pkg):
    """Run a loop containing a scatter-reduce with ``LoopRecord`` disabled.

    Ten lanes count from 0 to 10.  On every iteration each lane adds its
    counter to a running sum and then scatter-adds the incremented counter
    into slot 0 of a float buffer.  Expected results: every counter ends at
    10, every running sum is 0 + 1 + ... + 9 = 45, and slot 0 of the buffer
    holds 10 * (1 + 2 + ... + 10) = 550 while the other slots stay zero.
    """
    p = get_class(pkg)

    counter = ek.zero(p.Int, 10)
    running_sum = ek.zero(p.Int, 10)
    accum_buf = ek.zero(p.Float, 10)

    # get_class() switched loop recording on; this test wants it off.
    ek.set_flag(ek.JitFlag.LoopRecord, False)

    loop = p.Loop("MyLoop", lambda: (counter, running_sum))
    while loop(counter < 10):
        running_sum += counter
        counter += 1
        ek.scatter_reduce(target=accum_buf,
                          index=0,
                          value=p.Float(counter),
                          op=ek.ReduceOp.Add)

    assert counter == p.Int([10] * 10)
    assert accum_buf == p.Float(550, *([0] * 9))
    assert running_sum == p.Int([45] * 10)
Exemplo n.º 3
0
def test01_record_loop(pkg):
    """Check a simple counting loop under three JIT flag configurations.

    The loop is executed with (LoopRecord, LoopOptimize) set to
    (off, off), (on, off) and (on, on).  Each configuration is run twice:
    the first pass calls ``ek.schedule`` on the results before checking
    them, the second pass checks them directly.
    """
    p = get_class(pkg)

    # (LoopRecord, LoopOptimize) for each of the three configurations
    flag_settings = ((False, False), (True, False), (True, True))

    for record, optimize in flag_settings:
        ek.set_flag(ek.JitFlag.LoopRecord, record)
        ek.set_flag(ek.JitFlag.LoopOptimize, optimize)

        for schedule_first in (True, False):
            index = ek.arange(p.Int, 0, 10)
            total = ek.zero(p.Float, 1)
            steps = p.Float(1)

            loop = p.Loop("MyLoop", lambda: (index, total, steps))
            while loop(index < 5):
                total += p.Float(index)
                index += 1
                steps = steps + 1

            if schedule_first:
                ek.schedule(index, total, steps)

            assert steps == p.Int(6, 5, 4, 3, 2, 1, 1, 1, 1, 1)
            assert total == p.Int(10, 10, 9, 7, 4, 0, 0, 0, 0, 0)
            assert index == p.Int(5, 5, 5, 5, 5, 5, 6, 7, 8, 9)
Exemplo n.º 4
0
def no_record():
    """Disable the ``LoopRecord`` JIT flag for the duration of one ``yield``.

    The current flag value is saved, the flag is switched off, and control
    is handed to the consumer with a single ``yield`` of ``None``.  When the
    generator is resumed, closed, or has an exception thrown into it, the
    saved value is written back.

    NOTE(review): no decorator is visible in this view; presumably this is
    wrapped as a pytest fixture or a context manager -- confirm.
    """
    value_before = ek.flag(ek.JitFlag.LoopRecord)
    ek.set_flag(ek.JitFlag.LoopRecord, False)
    try:
        yield None
    finally:
        # Restore even when the consumer fails, so the disabled flag does
        # not leak into whatever runs next.
        ek.set_flag(ek.JitFlag.LoopRecord, value_before)
Exemplo n.º 5
0
 def __exit__(self, type, value, traceback):
     """Switch the ``ADEagerForward`` JIT flag off when the block is left.

     The exception arguments are ignored and nothing is returned, so an
     exception raised inside the ``with`` block keeps propagating.
     """
     eager_forward = ek.JitFlag.ADEagerForward
     ek.set_flag(eager_forward, False)
Exemplo n.º 6
0
 def __enter__(self):
     """Switch the ``ADEagerForward`` JIT flag on when the block is entered.

     Nothing is returned, so ``with ... as x`` binds ``x`` to ``None``.
     """
     eager_forward = ek.JitFlag.ADEagerForward
     ek.set_flag(eager_forward, True)
Exemplo n.º 7
0
def teardown_function(function):
    """Switch the ``LoopRecord`` JIT flag off.

    ``function`` is accepted but not used.  The name matches pytest's
    per-test teardown hook -- presumably that is how this gets called.
    """
    loop_record = ek.JitFlag.LoopRecord
    ek.set_flag(loop_record, False)
Exemplo n.º 8
0
def setup_function(function):
    """Switch the ``LoopRecord`` JIT flag on.

    ``function`` is accepted but not used.  The name matches pytest's
    per-test setup hook -- presumably that is how this gets called.
    """
    loop_record = ek.JitFlag.LoopRecord
    ek.set_flag(loop_record, True)