def test16_custom(cname):
    """Exercise zero/empty/select/resize/scatter/gather on a custom struct type.

    ``cname`` is resolved through ``get_class``; the resulting type is
    expected to expose ``state`` and ``inc`` fields.
    """
    struct_t = get_class(cname)

    # Both allocation routines must yield fields of the requested width.
    zeroed = ek.zero(struct_t, 100)
    scratch = ek.empty(struct_t, 100)
    assert len(zeroed.state) == 100
    assert len(scratch.inc) == 100

    # Masked selection between two struct instances: entries below 10 come
    # from the arange-filled instance, the rest from the zero state.
    scratch.state = zeroed.state
    zeroed.state = ek.arange(type(zeroed.state), 100)
    picked = ek.select(zeroed.state < 10, zeroed, scratch)
    assert picked.state[3] == 3
    assert picked.state[11] == 0

    # Resizing a scheduled width-1 instance must not disturb other instances.
    assert ek.width(picked) == 100
    resized = ek.zero(struct_t, 1)
    ek.schedule(resized)
    ek.resize(resized, 200)
    assert ek.width(resized) == 200

    # Same check again, this time without scheduling first.
    assert ek.width(picked) == 100
    resized = ek.zero(struct_t, 1)
    ek.resize(resized, 200)
    assert ek.width(resized) == 200

    # Scatter followed by gather over the same indices must round-trip.
    positions = ek.arange(type(zeroed.state), 100)
    ek.scatter(resized, zeroed, positions)
    round_trip = ek.gather(struct_t, resized, positions)
    ek.eval(round_trip)
    assert round_trip.state == zeroed.state
    assert round_trip.inc == zeroed.inc
def eval(self, pos, vel): pos, vel = m.Array2f(pos), m.Array2f(vel) # Run for 100 iterations it, max_it = m.UInt32(0), 100 # Allocate scratch space n = max(ek.width(pos), ek.width(vel)) self.temp_pos = ek.empty(m.Array2f, n * max_it) self.temp_vel = ek.empty(m.Array2f, n * max_it) loop = m.Loop(pos, vel, it) while loop.cond(it < max_it): # Store current loop variables index = it * n + ek.arange(m.UInt32, n) ek.scatter(self.temp_pos, pos, index) ek.scatter(self.temp_vel, vel, index) # Update loop variables pos_out, vel_out = self.timestep(pos, vel) pos.assign(pos_out) vel.assign(vel_out) it += 1 # Ensure output and temp. arrays are evaluated at this point ek.eval(pos, vel) return pos, vel
def test22_scatter_fwd(m):
    """Forward-mode gradients must propagate through ``ek.scatter``.

    A differentiable input is scattered into every other slot of a zero
    buffer; afterwards one slot is overwritten with a non-differentiable
    value and the gradient of that slot is expected to vanish.
    """
    x = m.Float(4.0)
    ek.enable_grad(x)

    # x^2 * [1, 2, 3, 4] written to the even slots 0, 2, 4, 6
    scattered = x * x * ek.linspace(m.Float, 1, 4, 4)
    index = 2 * ek.arange(m.UInt32, 4)

    buf = ek.zero(m.Float, 10)
    ek.scatter(buf, scattered, index)

    # The target buffer picks up gradient tracking from the scattered values
    assert ek.grad_enabled(buf)

    expected_values = [16.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
    assert ek.allclose(buf, expected_values)

    # d(x^2 * k)/dx = 2 * x * k = 8 * k at x = 4
    ek.forward(x, retain_graph=True)
    grad_buf = ek.grad(buf)

    expected_grad = [8.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
    assert ek.allclose(grad_buf, expected_grad)

    # Overwrite first value with non-diff value, resulting gradient entry
    # should be 0
    y = m.Float(3)
    index = m.UInt32(0)
    ek.scatter(buf, y, index)

    expected_values_after = [3.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
    assert ek.allclose(buf, expected_values_after)

    ek.forward(x)
    grad_buf = ek.grad(buf)

    expected_grad_after = [0.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
    assert ek.allclose(grad_buf, expected_grad_after)
def test07_loop_nest(pkg, variant):
    """Run a recorded Collatz loop from inside an outer loop.

    ``variant == 0`` drives the outer iteration with a recorded ``p.Loop``;
    any other value uses a plain Python ``for`` loop. Both variants must
    write the Collatz step counts of 1..10 into the first ten buffer slots
    and leave the remaining slots at their fill value of 1000.
    """
    p = get_class(pkg)

    def collatz(value: p.Int):
        """Return the number of Collatz steps needed for ``value`` to reach 1."""
        counter = p.Int(0)
        loop = p.Loop(value, counter)
        while loop.cond(ek.neq(value, 1)):
            is_even = ek.eq(value & 1, 0)
            value.assign(ek.select(is_even, value // 2, 3 * value + 1))
            counter += 1
        return counter

    i = p.Int(1)
    buf = ek.full(p.Int, 1000, 16)
    ek.eval(buf)

    if variant == 0:
        loop_1 = p.Loop(i)
        while loop_1.cond(i <= 10):
            # collatz() mutates its argument, so hand it a copy of the counter
            ek.scatter(buf, collatz(p.Int(i)), i - 1)
            i += 1
    else:
        # The ``for`` statement rebinds its target on every iteration, so no
        # manual increment is needed; a separate name also avoids shadowing
        # the JIT counter ``i`` used by the other variant.
        for start in range(1, 11):
            ek.scatter(buf, collatz(p.Int(start)), start - 1)

    assert buf == p.Int(0, 1, 7, 2, 5, 8, 16, 3, 19, 6,
                        1000, 1000, 1000, 1000, 1000, 1000)
def test22_scatter_rev(m):
    """Check reverse-mode gradients through two overlapping scatters.

    The loop covers three gradient configurations:
    i == 0 enables gradients on ``buf`` and on ``x``/``y``,
    i == 1 only on ``x``/``y``, and i == 2 only on ``buf``.
    """
    for i in range(3):
        # x goes to slots 0..4, y to slots 3..6, so y overwrites the last
        # two entries written by x
        idx1 = ek.arange(m.UInt, 5)
        idx2 = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if i % 2 == 0:
            ek.enable_grad(buf)
        if i // 2 == 0:
            ek.enable_grad(x, y)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        # Evaluate between the two scatters so the second one targets an
        # already evaluated buffer
        buf2 = m.Float(buf)
        ek.scatter(buf2, x, idx1)
        ek.eval(buf2)
        ek.scatter(buf2, y, idx2)

        ref_buf = m.Float(0.0000, 0.2500, 0.5000, 1.0000, 1.3333,
                          1.6667, 2.0000, 0.0000, 0.0000, 0.0000)

        assert ek.allclose(ref_buf, buf2, atol=1e-4)
        # The original buffer is expected to show the scattered values too
        # NOTE(review): presumably m.Float(buf) shares storage with buf - confirm
        assert ek.allclose(ref_buf, buf, atol=1e-4)

        s = ek.dot_async(buf2, buf2)

        ek.backward(s)

        # d(dot(b, b))/db = 2 * b; the two x entries overwritten by y
        # receive a zero gradient
        ref_x = m.Float(0.0000, 0.5000, 1.0000, 0.0000, 0.0000)
        ref_y = m.Float(2.0000, 2.6667, 3.3333, 4.0000)

        if i // 2 == 0:
            assert ek.allclose(ek.grad(y), ek.detach(ref_y), atol=1e-4)
            assert ek.allclose(ek.grad(x), ek.detach(ref_x), atol=1e-4)
        else:
            # No gradient tracking on x/y in this configuration
            assert ek.grad(x) == 0
            assert ek.grad(y) == 0

        # The gradient of buf is expected to be zero in every configuration
        if i % 2 == 0:
            assert ek.allclose(ek.grad(buf), 0, atol=1e-4)
        else:
            assert ek.grad(buf) == 0
def test22_scatter_fwd_permute(m):
    """Forward-mode gradients through two scatters that pass ``permute=False``.

    Two differentiable value blocks are written to disjoint halves of one
    buffer; both the buffer contents and its forward gradient are checked.
    """
    x = m.Float(4.0)
    ek.enable_grad(x)

    # x * [1, 3, 5, 7, 9] and x * [11, 13, 15, 17, 19]
    lower_half = x * ek.linspace(m.Float, 1, 9, 5)
    upper_half = x * ek.linspace(m.Float, 11, 19, 5)

    buf = ek.zero(m.Float, 10)

    lower_index = ek.arange(m.UInt32, 5)
    upper_index = ek.arange(m.UInt32, 5) + 5

    ek.scatter(buf, lower_half, lower_index, permute=False)
    ek.scatter(buf, upper_half, upper_index, permute=False)

    # Buffer holds 4 * k for the odd numbers k = 1, 3, ..., 19
    expected_values = [4.0 * k for k in range(1, 20, 2)]
    assert ek.allclose(buf, expected_values)

    ek.forward(x)
    grad_buf = ek.grad(buf)

    # d(x * k)/dx = k
    expected_grad = [float(k) for k in range(1, 20, 2)]
    assert ek.allclose(grad_buf, expected_grad)
def unravel(source, target, dim=3):
    """Scatter the components of ``source`` into ``target`` in interleaved order.

    Entry ``j`` of component ``c`` is written to position ``dim * j + c``
    of ``target``.

    Parameters:
        source: indexable by component (``source[c]`` for ``c < dim``).
        target: array that receives the scattered entries.
        dim: number of components to write (default 3).
    """
    entry_ids = UInt32.arange(ek.slices(source))
    for component in range(dim):
        ek.scatter(target, source[component], dim * entry_ids + component)
def scatter_(self, target, index, mask, permute):
    """Scatter this array into ``target`` one component at a time.

    Component ``k`` of ``self``, ``index`` and ``mask`` is forwarded to
    ``_ek.scatter`` together with ``permute``. The number of components is
    the largest length among ``self``, ``index`` and ``mask``.

    ``target`` must have a ``Depth`` of 1.
    """
    assert target.Depth == 1
    component_count = max(map(len, (self, index, mask)))
    for k in range(component_count):
        _ek.scatter(target, self[k], index[k], mask[k], permute)
def test54_scatter_implicit_detach(m):
    """Scattering a detached value into a detached target must not raise.

    The test has no assertions; it passes as long as the masked scatter
    call completes.
    """
    target = ek.detach(m.Float(0))
    value = ek.detach(m.Float(1))
    index = m.UInt32(0)
    mask = m.Bool(True)
    ek.scatter(target, value, index, mask)