def test_kernel_template_mapper(): x = ti.var(ti.i32) y = ti.var(ti.f32) ti.root.place(x, y) mapper = ti.KernelTemplateMapper( (ti.template(), ti.template(), ti.template()), template_slot_locations=(0, 1, 2)) assert mapper.lookup((0, 0, 0))[0] == 0 assert mapper.lookup((0, 1, 0))[0] == 1 assert mapper.lookup((0, 0, 0))[0] == 0 assert mapper.lookup((0, 0, 1))[0] == 2 assert mapper.lookup((0, 1, 0))[0] == 1 mapper = ti.KernelTemplateMapper((ti.i32, ti.i32, ti.i32), ()) assert mapper.lookup((0, 0, 0))[0] == 0 assert mapper.lookup((0, 1, 0))[0] == 0 assert mapper.lookup((0, 0, 0))[0] == 0 assert mapper.lookup((0, 0, 1))[0] == 0 assert mapper.lookup((0, 1, 0))[0] == 0 mapper = ti.KernelTemplateMapper((ti.i32, ti.template(), ti.i32), (1, )) assert mapper.lookup((0, x, 0))[0] == 0 assert mapper.lookup((0, y, 0))[0] == 1 assert mapper.lookup((0, x, 1))[0] == 0
def test_kernel_template_mapper(): x = ti.var(ti.i32) y = ti.var(ti.f32) @ti.layout def layout(): ti.root.place(x, y) mapper = ti.KernelTemplateMapper(3, (0, 1, 2)) assert mapper.lookup((0, 0, 0)) == 0 assert mapper.lookup((0, 1, 0)) == 1 assert mapper.lookup((0, 0, 0)) == 0 assert mapper.lookup((0, 0, 1)) == 2 assert mapper.lookup((0, 1, 0)) == 1 mapper = ti.KernelTemplateMapper(3, ()) assert mapper.lookup((0, 0, 0)) == 0 assert mapper.lookup((0, 1, 0)) == 0 assert mapper.lookup((0, 0, 0)) == 0 assert mapper.lookup((0, 0, 1)) == 0 assert mapper.lookup((0, 1, 0)) == 0 mapper = ti.KernelTemplateMapper(3, (1,)) assert mapper.lookup((0, x, 0)) == 0 assert mapper.lookup((0, y, 0)) == 1 assert mapper.lookup((0, x, 1)) == 0
def test_kernel_template_mapper_numpy(): x = ti.var(ti.i32) y = ti.var(ti.f32) ti.root.place(x, y) annotations = (ti.template(), ti.template(), ti.ext_arr()) import numpy as np mapper = ti.KernelTemplateMapper(annotations, (0, 1, 2)) assert mapper.lookup((0, 0, np.ones(shape=(1, 2, 3), dtype=np.float32)))[0] == 0 assert mapper.lookup((0, 0, np.ones(shape=(1, 2, 4), dtype=np.float32)))[0] == 0 assert mapper.lookup((0, 0, np.ones(shape=(1, 2, 1), dtype=np.int32)))[0] == 1