def test51_scatter_reduce_fwd_eager(m):
    """Forward-mode AD through two overlapping Add scatter-reductions, in eager mode.

    Three configurations are exercised (``case`` = 0, 1, 2):
      * case 0: unit gradients seeded on ``buf`` and on ``x``/``y``
      * case 1: unit gradients seeded on ``x``/``y`` only
      * case 2: unit gradients seeded on ``buf`` only

    ``x`` is scattered to positions 0..4 and ``y`` to positions 3..6 of a
    copy of ``buf``. The forward gradient of ``s = dot(buf2, buf2)`` must
    equal the sum of the contributions of whichever inputs were seeded.
    """
    with EagerMode():
        for case in range(3):
            seed_buf = case % 2 == 0
            seed_xy = case // 2 == 0

            idx_x = ek.arange(m.UInt, 5)
            idx_y = ek.arange(m.UInt, 4) + 3

            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if seed_buf:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)
            if seed_xy:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            x.label, y.label, buf.label = "x", "y", "buf"

            # Scatter into buf2 (constructed from buf), not into buf itself.
            buf2 = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, buf2, x, idx_x)
            ek.scatter_reduce(ek.ReduceOp.Add, buf2, y, idx_y)

            s = ek.dot_async(buf2, buf2)

            # 25.1667 is the contribution of the x/y seeds, 17 that of the buf seed.
            expected_grad = (25.1667 if seed_xy else 0) + (17 if seed_buf else 0)

            # Verified against Mathematica
            assert ek.allclose(ek.detach(s), 15.5972)
            assert ek.allclose(ek.grad(s), expected_grad)
def test_58_diffloop_masking_rev(m, no_record):
    """Reverse-mode AD through a Loop whose lanes run different trip counts.

    Lane 0 starts its counter at 0 and iterates 5 times. Lane 1 starts at 5
    and never enters the loop body. Each iteration scatter-adds ``src`` into
    ``target`` at the lane's counter. After backpropagating from ``target``,
    lane 0 must report a gradient of 5 and the inactive lane 1 must report 0.

    ``no_record`` is a pytest fixture that is not referenced in the body.
    NOTE(review): it is presumably requested for a loop-recording side
    effect; it is defined elsewhere, so confirm.
    """
    target = ek.zero(m.Float, 10)
    src = m.Float(1, 2)
    counter = m.UInt32(0, 5)
    ek.enable_grad(src)

    loop = m.Loop("MyLoop", lambda: counter)
    while loop(counter < 5):
        ek.scatter_reduce(ek.ReduceOp.Add, target, src, counter)
        counter += 1

    ek.backward(target)
    assert ek.grad(src) == m.Float(5, 0)
def test20_scatter_reduce_rev(m):
    """Reverse-mode AD through two overlapping Add scatter-reductions.

    Gradients are tracked as follows:
      * case 0: on ``buf`` and ``x``/``y``
      * case 1: on ``x``/``y`` only
      * case 2: on ``buf`` only

    After backpropagating ``s = dot(buf2, buf2)``, every tracked input must
    carry the reference gradient (2 * buf2 gathered at the positions it was
    scattered to). Untracked inputs must report a zero gradient.
    """
    for case in range(3):
        track_buf = case % 2 == 0
        track_xy = case // 2 == 0

        idx_x = ek.arange(m.UInt, 5)
        idx_y = ek.arange(m.UInt, 4) + 3

        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if track_buf:
            ek.enable_grad(buf)
        if track_xy:
            ek.enable_grad(x, y)

        x.label, y.label, buf.label = "x", "y", "buf"

        buf2 = m.Float(buf)
        ek.scatter_reduce(ek.ReduceOp.Add, buf2, x, idx_x)
        ek.scatter_reduce(ek.ReduceOp.Add, buf2, y, idx_y)

        # Positions 3 and 4 receive contributions from both x and y.
        ref_buf = m.Float(0.0000, 0.2500, 0.5000, 1.7500, 2.3333,
                          1.6667, 2.0000, 0.0000, 0.0000, 0.0000)
        assert ek.allclose(ref_buf, buf2, atol=1e-4)

        s = ek.dot_async(buf2, buf2)
        ek.backward(s)

        ref_x = m.Float(0.0000, 0.5000, 1.0000, 3.5000, 4.6667)
        ref_y = m.Float(3.5000, 4.6667, 3.3333, 4.0000)

        if not track_xy:
            assert ek.grad(x) == 0
            assert ek.grad(y) == 0
        else:
            assert ek.allclose(ek.grad(y), ek.detach(ref_y), atol=1e-4)
            assert ek.allclose(ek.grad(x), ek.detach(ref_x), atol=1e-4)

        if not track_buf:
            assert ek.grad(buf) == 0
        else:
            assert ek.allclose(ek.grad(buf), ek.detach(ref_buf) * 2, atol=1e-4)
def test05_side_effect_noloop(pkg):
    """Check a scatter_reduce side effect inside a Loop with loop recording off.

    Each of the 10 lanes counts ``i`` from 0 to 10 and accumulates
    ``j = 0 + 1 + ... + 9 = 45``. On every iteration the lane scatter-adds
    ``Float(i)`` (values 1..10) into ``buf[0]``, so ``buf[0]`` must end at
    10 lanes * 55 = 550, with every other entry left at 0.

    The global ``JitFlag.LoopRecord`` flag is disabled only for the
    duration of this test. Its previous value is restored afterwards, even
    when an assertion fails.
    """
    p = get_class(pkg)

    i = ek.zero(p.Int, 10)
    j = ek.zero(p.Int, 10)
    buf = ek.zero(p.Float, 10)

    # LoopRecord is global JIT state: remember it so the override below
    # does not leak into tests that run after this one.
    prev_loop_record = ek.flag(ek.JitFlag.LoopRecord)
    ek.set_flag(ek.JitFlag.LoopRecord, False)
    try:
        loop = p.Loop("MyLoop", lambda: (i, j))
        while loop(i < 10):
            j += i
            i += 1
            ek.scatter_reduce(op=ek.ReduceOp.Add, target=buf, value=p.Float(i), index=0)

        assert i == p.Int([10] * 10)
        assert buf == p.Float(550, *([0] * 9))
        assert j == p.Int([45] * 10)
    finally:
        ek.set_flag(ek.JitFlag.LoopRecord, prev_loop_record)
def scatter_reduce_(self, op, target, index, mask):
    """Scatter-reduce every component of this array into ``target``.

    Component ``k`` of ``self`` is combined into ``target`` at ``index[k]``
    under ``mask[k]``. The combination uses reduction ``op`` via
    ``_ek.scatter_reduce``.

    Args:
        op: Reduction operation, forwarded unchanged to ``_ek.scatter_reduce``.
        target: Destination array; must have ``Depth == 1`` (asserted).
        index: Per-component indices, indexed alongside ``self``.
        mask: Per-component activity masks, indexed alongside ``self``.
    """
    assert target.Depth == 1
    # Iterate over the longest of the three operands. NOTE(review): shorter
    # operands are presumably broadcast by their ``__getitem__``; confirm.
    n_components = max(map(len, (self, index, mask)))
    for k in range(n_components):
        value, idx, active = self[k], index[k], mask[k]
        _ek.scatter_reduce(op, target, value, idx, active)