def backward(self): grad_pos, grad_vel = self.grad_out() # Run for 100 iterations it = m.UInt32(100) loop = m.Loop(it, grad_pos, grad_vel) n = ek.width(grad_pos) while loop.cond(it > 0): # Retrieve loop variables, reverse chronological order it -= 1 index = it * n + ek.arange(m.UInt32, n) pos = ek.gather(m.Array2f, self.temp_pos, index) vel = ek.gather(m.Array2f, self.temp_vel, index) # Differentiate loop body in reverse mode ek.enable_grad(pos, vel) pos_out, vel_out = self.timestep(pos, vel) ek.set_grad(pos_out, grad_pos) ek.set_grad(vel_out, grad_vel) ek.enqueue(pos_out, vel_out) ek.traverse(m.Float, reverse=True) # Update loop variables grad_pos.assign(ek.grad(pos)) grad_vel.assign(ek.grad(vel)) self.set_grad_in('pos', grad_pos) self.set_grad_in('vel', grad_vel)
def backward(self):
    """Reverse-mode pass that recomputes states instead of storing them.

    Each of the 100 iterations does two things:

    1. It recovers the previous state by calling ``self.timestep`` with a
       negated step (``dt=-0.02``).
    2. It re-runs the forward step (``dt=0.02``) from that state with
       gradients enabled, and back-propagates the running output gradients
       through it.

    The resulting gradients are handed to ``self.set_grad_in``.

    NOTE(review): this assumes ``timestep`` is (numerically) reversible by
    negating ``dt``. Presumably any error in the reconstructed states feeds
    straight into the gradients.

    NOTE(review): ``pos``/``vel`` are the same objects as
    ``self.pos``/``self.vel``, and ``.assign`` below writes into them. After
    this method returns, ``self.pos``/``self.vel`` therefore appear to hold
    the rewound state. Confirm that this side effect is intended.
    """
    grad_pos, grad_vel = self.grad_out()
    pos, vel = self.pos, self.vel

    # Run for 100 iterations
    it = m.UInt32(0)
    # Loop state objects; updated in place via `.assign(...)` / `+=` below.
    loop = m.Loop(it, pos, vel, grad_pos, grad_vel)
    while loop.cond(it < 100):
        # Take reverse step in time
        pos_rev, vel_rev = self.timestep(pos, vel, dt=-0.02)
        pos.assign(pos_rev)
        vel.assign(vel_rev)

        # Take a forward step in time, keep track of derivatives
        ek.enable_grad(pos_rev, vel_rev)
        pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
        ek.set_grad(pos_fwd, grad_pos)
        ek.set_grad(vel_fwd, grad_vel)
        ek.enqueue(pos_fwd, vel_fwd)
        ek.traverse(m.Float, reverse=True)

        # Pull the output gradients back to the start of this step
        grad_pos.assign(ek.grad(pos_rev))
        grad_vel.assign(ek.grad(vel_rev))
        it += 1

    self.set_grad_in('pos', grad_pos)
    self.set_grad_in('vel', grad_vel)
def backward(self):
    """Reverse-mode pass with reversible time stepping and a lambda-based loop.

    Each of the 100 iterations does two things:

    1. It rewinds ``(pos, vel)`` by one step using ``dt=-0.02``.
    2. It differentiates a forward step (``dt=0.02``) taken from fresh
       ``m.Array2f`` copies of that state, back-propagating the running
       output gradients through it.

    The resulting gradients are handed to ``self.set_grad_in``.

    NOTE(review): this assumes ``timestep`` is (numerically) reversible by
    negating ``dt``.
    """
    grad_pos, grad_vel = self.grad_out()
    pos, vel = self.pos, self.vel

    # Run for 100 iterations
    it = m.UInt32(0)
    # The loop state is supplied by a lambda that re-reads these names, so
    # the loop body updates state by plain rebinding (`pos, vel = ...`,
    # `grad_pos = ...`). Presumably this is why no `.assign` is needed here.
    loop = m.Loop("backward", lambda: (it, pos, vel, grad_pos, grad_vel))
    while loop(it < 100):
        # Take reverse step in time
        pos, vel = self.timestep(pos, vel, dt=-0.02)

        # Take a forward step in time, keep track of derivatives.
        # Gradients are enabled on copies rather than on the loop variables
        # themselves (presumably to keep AD tracking off the loop state).
        pos_rev, vel_rev = m.Array2f(pos), m.Array2f(vel)
        ek.enable_grad(pos_rev, vel_rev)
        pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
        ek.set_grad(pos_fwd, grad_pos)
        ek.set_grad(vel_fwd, grad_vel)
        ek.enqueue(pos_fwd, vel_fwd)
        ek.traverse(m.Float, reverse=True)

        # Pull the output gradients back to the start of this step
        grad_pos = ek.grad(pos_rev)
        grad_vel = ek.grad(vel_rev)
        it += 1

    self.set_grad_in('pos', grad_pos)
    self.set_grad_in('vel', grad_vel)
def test52_scatter_fwd_eager(m):
    """Eager forward-mode AD through ``ek.scatter``.

    ``x**2 * [1, 2, 3, 4]`` is written to the even slots of a 10-wide buffer.
    Slot 0 is then overwritten with a constant, whose tangent must read 0
    after a fresh forward propagation.
    """
    with EagerMode():
        src = m.Float(4.0)
        ek.enable_grad(src)
        ek.set_grad(src, 1)

        payload = src * src * ek.linspace(m.Float, 1, 4, 4)
        even_slots = 2 * ek.arange(m.UInt32, 4)
        target = ek.zero(m.Float, 10)
        ek.scatter(target, payload, even_slots)
        assert ek.grad_enabled(target)

        # Values 16 * [1..4] and tangents 2 * 4 * [1..4] on the even slots
        expected_vals = [16.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
        expected_tangents = [8.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
        assert ek.allclose(target, expected_vals)
        assert ek.allclose(ek.grad(target), expected_tangents)

        # Overwrite first value with non-diff value, resulting gradient entry should be 0
        constant = m.Float(3)
        first_slot = m.UInt32(0)
        ek.scatter(target, constant, first_slot)
        expected_vals = [3.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0]
        assert ek.allclose(target, expected_vals)

        ek.forward(src)
        expected_tangents = [0.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0]
        assert ek.allclose(ek.grad(target), expected_tangents)
def test49_eager_fwd(m):
    """Eager forward-mode derivative of ``x * x * x`` at x = 1, seeded with 10."""
    with EagerMode():
        val = m.Float(1)
        ek.enable_grad(val)
        ek.set_grad(val, 10)
        cubed = val * val * val
        # d/dx x^3 = 3 x^2 = 3 at x = 1, times the seed 10
        assert ek.grad(cubed) == 30
def test43_custom_reverse(m):
    """Reverse-mode AD through the custom ``Normalize`` operation."""
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)
    unit = ek.custom(Normalize, vec)

    # Seed the output adjoint and propagate it back to the input
    ek.set_grad(unit, m.Array3f(5, 6, 7))
    ek.enqueue(unit)
    ek.traverse(m.Float, reverse=True)

    vec_grad = ek.grad(vec)
    expected = m.Array3f(0.610883, 0.152721, -0.305441)
    assert ek.allclose(vec_grad, expected)
def test50_gather_fwd_eager(m):
    """Eager forward-mode AD through ``ek.gather`` applied to ``x * x``."""
    with EagerMode():
        grid = ek.linspace(m.Float, -1, 1, 10)
        ek.enable_grad(grid)
        ek.set_grad(grid, 1)

        squared = grid * grid
        picked = ek.gather(m.Float, squared, m.UInt(1, 1, 2, 3))

        # Tangent of x^2 is 2x, evaluated at grid[1], grid[1], grid[2], grid[3]
        expected = [-1.55556, -1.55556, -1.11111, -0.666667]
        assert ek.allclose(ek.grad(picked), expected)
def test44_custom_forward(m):
    """Forward-mode AD through the custom ``Normalize`` operation.

    Two traversals are run: the first with ``retain_graph=True`` and the
    second with ``retain_graph=False``. The output tangent must accumulate to
    twice the single-pass value.
    """
    def seed():
        return m.Array3f(5, 6, 7)

    def single_pass_tangent():
        return m.Array3f(0.610883, 0.152721, -0.305441)

    src = m.Array3f(1, 2, 3)
    ek.enable_grad(src)
    out = ek.custom(Normalize, src)

    # First traversal keeps the graph alive for the second one
    ek.set_grad(src, seed())
    ek.enqueue(src)
    ek.traverse(m.Float, reverse=False, retain_graph=True)
    assert ek.grad(src) == 0

    # Re-seed the input; the output still holds the first pass's tangent
    ek.set_grad(src, seed())
    assert ek.allclose(ek.grad(out), single_pass_tangent())

    # Second traversal accumulates onto the existing output tangent
    ek.enqueue(src)
    ek.traverse(m.Float, reverse=False, retain_graph=False)
    assert ek.allclose(ek.grad(out), single_pass_tangent() * 2)
def test_ad_operations(package):
    """AD helpers on a user struct declared via ``ENOKI_STRUCT``.

    Covers ``enable_grad``, ``detach``, ``grad``, ``set_grad`` and
    ``accum_grad`` applied to the struct as a whole.
    """
    Float = package.Float
    Array3f = package.Array3f
    prepare(package)

    class MyStruct:
        ENOKI_STRUCT = {'a': Array3f, 'b': Float}

        def __init__(self):
            self.a = Array3f()
            self.b = Float()

    def grad_flags(s):
        # Gradient-tracking status of each field, then of the struct itself
        return (ek.grad_enabled(s.a), ek.grad_enabled(s.b), ek.grad_enabled(s))

    foo = ek.zero(MyStruct, 4)
    assert not any(grad_flags(foo))

    ek.enable_grad(foo)
    assert all(grad_flags(foo))

    assert not any(grad_flags(ek.detach(foo)))

    # Forward-propagate through a = x and b = x^2 at x = 4
    x = Float(4.0)
    ek.enable_grad(x)
    foo.a += x
    foo.b += x * x
    ek.forward(x)
    g = ek.grad(foo)
    assert g.a == 1
    assert g.b == 8

    # set_grad overwrites (-> 5), accum_grad then adds on top (-> 10)
    for apply_grad, expected in ((ek.set_grad, 5.0), (ek.accum_grad, 10.0)):
        apply_grad(foo, 5.0)
        g = ek.grad(foo)
        assert g.a == expected
        assert g.b == expected
def test51_scatter_reduce_fwd_eager(m):
    """Eager forward-mode AD through two additive ``ek.scatter_reduce`` calls.

    Three passes seed different subsets of inputs:

    - pass 0 seeds both the buffer and the scattered values;
    - pass 1 seeds only the scattered values;
    - pass 2 seeds only the buffer.

    In each pass, the value and tangent of the squared norm of the result are
    checked.
    """
    with EagerMode():
        for case in range(3):
            seed_buffer = case % 2 == 0
            seed_values = case // 2 == 0

            first_slots = ek.arange(m.UInt, 5)
            second_slots = ek.arange(m.UInt, 4) + 3
            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if seed_buffer:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)
            if seed_values:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            x.label = "x"
            y.label = "y"
            buf.label = "buf"

            # Scatter into a new array constructed from `buf`
            accum = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, accum, x, first_slots)
            ek.scatter_reduce(ek.ReduceOp.Add, accum, y, second_slots)
            sq_norm = ek.dot_async(accum, accum)

            # Verified against Mathematica
            expected_grad = (25.1667 if seed_values else 0) + (17 if seed_buffer else 0)
            assert ek.allclose(ek.detach(sq_norm), 15.5972)
            assert ek.allclose(ek.grad(sq_norm), expected_grad)
def test53_scatter_fwd_permute_eager(m):
    """Eager forward-mode AD through two ``permute=False`` scatters.

    The two scatters write to disjoint halves of a 10-wide buffer.
    """
    with EagerMode():
        x = m.Float(4.0)
        ek.enable_grad(x)
        ek.set_grad(x, 1)

        lower = x * ek.linspace(m.Float, 1, 9, 5)
        upper = x * ek.linspace(m.Float, 11, 19, 5)
        out = ek.zero(m.Float, 10)

        lower_idx = ek.arange(m.UInt32, 5)
        upper_idx = ek.arange(m.UInt32, 5) + 5
        for vals, idx in ((lower, lower_idx), (upper, upper_idx)):
            ek.scatter(out, vals, idx, permute=False)

        # out = 4 * [1, 3, ..., 19], so d(out)/dx = [1, 3, ..., 19]
        expected_vals = [4.0, 12.0, 20.0, 28.0, 36.0, 44.0, 52.0, 60.0, 68.0, 76.0]
        expected_tangents = [1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0]
        assert ek.allclose(out, expected_vals)
        assert ek.allclose(ek.grad(out), expected_tangents)
def test21_scatter_add_fwd(m):
    """Forward-mode AD through two ``ek.scatter_add`` calls, using explicit
    enqueue/traverse.

    Three passes seed different subsets of inputs:

    - pass 0 seeds both the buffer and the scattered values;
    - pass 1 seeds only the scattered values;
    - pass 2 seeds only the buffer.

    In each pass, the value and tangent of the squared norm of the result are
    checked.
    """
    for case in range(3):
        seed_buffer = case % 2 == 0
        seed_values = case // 2 == 0

        first_slots = ek.arange(m.UInt, 5)
        second_slots = ek.arange(m.UInt, 4) + 3
        x = ek.linspace(m.Float, 0, 1, 5)
        y = ek.linspace(m.Float, 1, 2, 4)
        buf = ek.zero(m.Float, 10)

        if seed_buffer:
            ek.enable_grad(buf)
            ek.set_grad(buf, 1)
        if seed_values:
            ek.enable_grad(x, y)
            ek.set_grad(x, 1)
            ek.set_grad(y, 1)

        x.label = "x"
        y.label = "y"
        buf.label = "buf"

        # Scatter into a new array constructed from `buf`
        accum = m.Float(buf)
        ek.scatter_add(accum, x, first_slots)
        ek.scatter_add(accum, y, second_slots)
        sq_norm = ek.dot_async(accum, accum)

        # Only the seeded inputs are queued as traversal starting points
        if seed_buffer:
            ek.enqueue(buf)
        if seed_values:
            ek.enqueue(x, y)
        ek.traverse(m.Float, reverse=False)

        # Verified against Mathematica
        expected_grad = (25.1667 if seed_values else 0) + (17 if seed_buffer else 0)
        assert ek.allclose(ek.detach(sq_norm), 15.5972)
        assert ek.allclose(ek.grad(sq_norm), expected_grad)
def test02_add_fwd(m):
    """Forward-mode derivative of ``c = 2*a + b``.

    The derivative is computed once with ``ek.forward`` and once with manual
    ``enqueue``/``traverse``. Both variants check that a retained graph lets
    tangents accumulate across passes.
    """
    # --- Variant 1: ek.forward ---------------------------------------------
    a, b = m.Float(1), m.Float(2)
    ek.enable_grad(a, b)
    c = 2 * a + b

    ek.forward(a, retain_graph=True)
    assert ek.grad(c) == 2

    # Pre-set c's tangent; propagating from b adds dc/db = 1 on top
    ek.set_grad(c, 101)
    ek.forward(b)
    assert ek.grad(c) == 102

    # --- Variant 2: explicit enqueue + traverse -----------------------------
    a, b = m.Float(1), m.Float(2)
    ek.enable_grad(a, b)
    c = 2 * a + b

    ek.set_grad(a, 1.0)
    ek.enqueue(a)
    ek.traverse(m.Float, retain_graph=True, reverse=False)
    assert ek.grad(c) == 2
    assert ek.grad(a) == 0

    # The graph was retained, so a second pass accumulates into c
    ek.set_grad(a, 1.0)
    ek.enqueue(a)
    ek.traverse(m.Float, retain_graph=False, reverse=False)
    assert ek.grad(c) == 4