def test51_scatter_reduce_fwd_eager(m):
    """Eager forward-mode AD through two ``scatter_reduce(Add)`` calls.

    Three rounds vary which inputs carry a unit tangent:
    round 0 seeds ``buf`` and ``x``/``y``, round 1 seeds only ``x``/``y``,
    and round 2 seeds only ``buf``. The value and tangent of the squared
    norm ``dot(buf2, buf2)`` are compared to Mathematica reference values.
    """
    with EagerMode():
        for round_idx in range(3):
            seed_buf = round_idx % 2 == 0
            seed_xy = round_idx // 2 == 0

            idx1 = ek.arange(m.UInt, 5)
            idx2 = ek.arange(m.UInt, 4) + 3
            x = ek.linspace(m.Float, 0, 1, 5)
            y = ek.linspace(m.Float, 1, 2, 4)
            buf = ek.zero(m.Float, 10)

            if seed_buf:
                ek.enable_grad(buf)
                ek.set_grad(buf, 1)
            if seed_xy:
                ek.enable_grad(x, y)
                ek.set_grad(x, 1)
                ek.set_grad(y, 1)

            x.label = "x"
            y.label = "y"
            buf.label = "buf"

            # Accumulate into ``buf2`` (constructed from ``buf``); indices
            # 3 and 4 receive contributions from both ``x`` and ``y``.
            buf2 = m.Float(buf)
            ek.scatter_reduce(ek.ReduceOp.Add, buf2, x, idx1)
            ek.scatter_reduce(ek.ReduceOp.Add, buf2, y, idx2)

            sq_norm = ek.dot_async(buf2, buf2)

            # Verified against Mathematica
            expected_tangent = (25.1667 if seed_xy else 0) + (17 if seed_buf else 0)
            assert ek.allclose(ek.detach(sq_norm), 15.5972)
            assert ek.allclose(ek.grad(sq_norm), expected_tangent)
def test19_gather_fwd(m):
    """Forward-mode derivative of ``gather`` applied to ``x * x``.

    With a unit tangent on ``x``, each gathered lane carries d(x^2)/dx = 2x
    evaluated at the gathered index (1, 1, 2, 3).
    """
    src = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(src)

    picked = ek.gather(m.Float, src * src, m.UInt(1, 1, 2, 3))
    ek.forward(src)

    expected = [-1.55556, -1.55556, -1.11111, -0.666667]
    assert ek.allclose(ek.grad(picked), expected)
def backward(self):
    """Propagate output gradients back through 100 time steps by reversal.

    Intermediate states are not stored. Each iteration steps the state
    backwards in time (``dt=-0.02``) and then re-runs a forward step
    (``dt=0.02``) from that state with AD enabled. This pulls the incoming
    gradients back by one step. Finally the input gradients are written via
    ``set_grad_in``.

    NOTE(review): correctness presumably relies on ``self.timestep`` being
    (approximately) invertible by negating ``dt``, and on ``self.pos`` /
    ``self.vel`` holding the final forward state — confirm.
    """
    grad_pos, grad_vel = self.grad_out()
    pos, vel = self.pos, self.vel

    # Run for 100 iterations
    it = m.UInt32(0)
    # Loop state is passed explicitly; the body below updates it in place via
    # ``assign`` (presumably so the recorded loop observes the new values).
    loop = m.Loop(it, pos, vel, grad_pos, grad_vel)
    while loop.cond(it < 100):
        # Take reverse step in time
        pos_rev, vel_rev = self.timestep(pos, vel, dt=-0.02)
        pos.assign(pos_rev)
        vel.assign(vel_rev)

        # Take a forward step in time, keep track of derivatives
        ek.enable_grad(pos_rev, vel_rev)
        pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
        ek.set_grad(pos_fwd, grad_pos)
        ek.set_grad(vel_fwd, grad_vel)
        ek.enqueue(pos_fwd, vel_fwd)
        ek.traverse(m.Float, reverse=True)

        # The adjoints of the reversed state become the new incoming grads.
        grad_pos.assign(ek.grad(pos_rev))
        grad_vel.assign(ek.grad(vel_rev))
        it += 1

    self.set_grad_in('pos', grad_pos)
    self.set_grad_in('vel', grad_vel)
def test22_scatter_fwd(m):
    """Forward-mode derivative through ``scatter``.

    ``x = 4`` is scattered as ``x^2 * (1, 2, 3, 4)`` into even slots. Slot 0
    is then overwritten with a non-differentiable constant, and its tangent
    is expected to drop to 0.
    """
    x = m.Float(4.0)
    ek.enable_grad(x)

    scaled = x * x * ek.linspace(m.Float, 1, 4, 4)
    slots = 2 * ek.arange(m.UInt32, 4)
    target = ek.zero(m.Float, 10)
    ek.scatter(target, scaled, slots)

    assert ek.grad_enabled(target)
    assert ek.allclose(target, [16.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0])

    # retain_graph=True so the second ek.forward(x) below can traverse again.
    ek.forward(x, retain_graph=True)
    # d(x^2 * k)/dx = 2 * 4 * k = 8k
    first_tangent = ek.grad(target)
    assert ek.allclose(first_tangent, [8.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0])

    # Overwrite first value with non-diff value, resulting gradient entry should be 0
    ek.scatter(target, m.Float(3), m.UInt32(0))
    assert ek.allclose(target, [3.0, 0.0, 32.0, 0.0, 48.0, 0.0, 64.0, 0.0, 0.0, 0.0])

    ek.forward(x)
    second_tangent = ek.grad(target)
    assert ek.allclose(second_tangent, [0.0, 0.0, 16.0, 0.0, 24.0, 0.0, 32.0, 0.0, 0.0, 0.0])
def test49_eager_fwd(m):
    """Eager forward mode: the tangent seeded on ``x`` is visible on ``x^3``
    without any explicit ``ek.forward`` call."""
    with EagerMode():
        x = m.Float(1)
        ek.enable_grad(x)
        ek.set_grad(x, 10)

        cube = x * x * x
        # d(x^3)/dx = 3x^2 = 3 at x = 1, scaled by the seed 10.
        assert ek.grad(cube) == 30
def test14_rsqrt(m):
    """Reverse-mode derivative of ``rsqrt``: d/dx x^(-1/2) = -x^(-3/2) / 2."""
    inp = m.Float(1, .25, 0.0625)
    ek.enable_grad(inp)

    out = ek.rsqrt(inp)
    ek.backward(out)

    expected_value = [1, 2, 4]
    expected_grad = [-.5, -4, -32]
    assert ek.allclose(ek.detach(out), expected_value)
    assert ek.allclose(ek.grad(inp), expected_grad)
def test15_abs(m):
    """Reverse-mode derivative of ``abs`` is the sign of the input."""
    inp = m.Float(-2, 2)
    ek.enable_grad(inp)

    out = ek.abs(inp)
    ek.backward(out)

    expected_value = [2, 2]
    expected_grad = [-1, 1]
    assert ek.allclose(ek.detach(out), expected_value)
    assert ek.allclose(ek.grad(inp), expected_grad)
def test13_sqrt(m):
    """Reverse-mode derivative of ``sqrt``: d/dx sqrt(x) = 1 / (2 sqrt(x))."""
    inp = m.Float(1, 4, 16)
    ek.enable_grad(inp)

    out = ek.sqrt(inp)
    ek.backward(out)

    expected_value = [1, 2, 4]
    expected_grad = [.5, .25, .125]
    assert ek.allclose(ek.detach(out), expected_value)
    assert ek.allclose(ek.grad(inp), expected_grad)
def test12_hmax_fwd(m):
    """Forward-mode derivative of a horizontal max with a tied maximum."""
    values = m.Float(1, 2, 8, 5, 8)
    ek.enable_grad(values)

    peak = ek.hmax_async(values)
    ek.forward(values)

    assert len(peak) == 1
    assert ek.allclose(peak[0], 8)
    # Approximation: the maximum 8 occurs at two positions, and the reference
    # tangent 2 presumably counts the unit tangent of each tied lane.
    assert ek.allclose(ek.grad(peak), [2])
def test11_hmax_rev(m):
    """Reverse-mode derivative of a horizontal max: every lane equal to the
    maximum (indices 2 and 4 here) receives a unit adjoint."""
    values = m.Float(1, 2, 8, 5, 8)
    ek.enable_grad(values)

    peak = ek.hmax_async(values)
    ek.backward(peak)

    assert len(peak) == 1
    assert ek.allclose(peak[0], 8)
    assert ek.allclose(ek.grad(values), [0, 0, 1, 0, 1])
def test17_cos(m):
    """Reverse-mode derivative of ``cos``: d/dx cos(x) = -sin(x)."""
    x = ek.linspace(m.Float, 0.01, 10, 10)
    ek.enable_grad(x)

    y = ek.cos(x)
    ek.backward(y)

    x_val = ek.detach(x)
    assert ek.allclose(ek.detach(y), ek.cos(x_val))
    assert ek.allclose(ek.grad(x), -ek.sin(x_val))
def test01_add_rev(m):
    """Reverse-mode adjoints of ``2 * a + b`` are the constant coefficients."""
    lhs = m.Float(1)
    rhs = m.Float(2)
    ek.enable_grad(lhs, rhs)

    total = 2 * lhs + rhs
    ek.backward(total)

    assert ek.grad(lhs) == 2
    assert ek.grad(rhs) == 1
def test04_div(m):
    """Reverse-mode derivative of a quotient: d(a/b) = (1/b, -a/b^2)."""
    num = m.Float(2)
    den = m.Float(3)
    ek.enable_grad(num, den)

    quotient = num / den
    ek.backward(quotient)

    assert ek.allclose(ek.grad(num), 1.0 / 3.0)
    assert ek.allclose(ek.grad(den), -2.0 / 9.0)
def test06_hsum_0_fwd(m):
    """Forward-mode derivative of sum(x^2) over 10 points in [0, 1].

    Value: sum (i/9)^2 = 285/81 = 95/27.  Tangent with unit seed on every
    lane: sum 2 * (i/9) = 10.
    """
    samples = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(samples)

    total = ek.hsum_async(samples * samples)
    ek.forward(samples)

    assert len(total) == 1
    assert ek.allclose(ek.detach(total), 95.0 / 27.0)

    tangent = ek.grad(total)
    assert len(tangent) == 1
    assert ek.allclose(tangent, 10)
def test05_hsum_0_rev(m):
    """Reverse-mode derivative of sum(x^2): each lane's adjoint is 2x."""
    samples = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(samples)

    total = ek.hsum_async(samples * samples)
    ek.backward(total)

    assert len(total) == 1
    assert ek.allclose(total, 95.0 / 27.0)
    assert ek.allclose(ek.grad(samples), 2 * ek.detach(samples))
def test11_hprod(m):
    """Reverse-mode derivative of a horizontal product.

    For non-zero inputs, d(prod)/dx_i = prod / x_i, i.e. 80 / (1, 2, 5, 8).
    """
    factors = m.Float(1, 2, 5, 8)
    ek.enable_grad(factors)

    product = ek.hprod_async(factors)
    ek.backward(product)

    assert len(product) == 1
    assert ek.allclose(product[0], 80)
    assert ek.allclose(ek.grad(factors), [80, 40, 16, 10])
def backward(self): grad_pos, grad_vel = self.grad_out() # Run for 100 iterations it = m.UInt32(100) loop = m.Loop(it, grad_pos, grad_vel) n = ek.width(grad_pos) while loop.cond(it > 0): # Retrieve loop variables, reverse chronological order it -= 1 index = it * n + ek.arange(m.UInt32, n) pos = ek.gather(m.Array2f, self.temp_pos, index) vel = ek.gather(m.Array2f, self.temp_vel, index) # Differentiate loop body in reverse mode ek.enable_grad(pos, vel) pos_out, vel_out = self.timestep(pos, vel) ek.set_grad(pos_out, grad_pos) ek.set_grad(vel_out, grad_vel) ek.enqueue(pos_out, vel_out) ek.traverse(m.Float, reverse=True) # Update loop variables grad_pos.assign(ek.grad(pos)) grad_vel.assign(ek.grad(vel)) self.set_grad_in('pos', grad_pos) self.set_grad_in('vel', grad_vel)
def backward(self):
    """Propagate output gradients back through 100 time steps by reversal.

    Each iteration steps the state backwards (``dt=-0.02``). It then re-runs
    a forward step (``dt=0.02``) from a grad-enabled copy of that state,
    which pulls the incoming gradients back by one step.

    NOTE(review): presumably relies on ``self.timestep`` being (approximately)
    invertible by negating ``dt``, and on ``self.pos`` / ``self.vel`` holding
    the final forward state — confirm.
    """
    grad_pos, grad_vel = self.grad_out()
    pos, vel = self.pos, self.vel

    # Run for 100 iterations
    it = m.UInt32(0)
    # The lambda reads the *current* bindings of these locals each time it is
    # called, so the plain reassignments in the body are (presumably) how the
    # loop state gets updated.
    loop = m.Loop("backward", lambda: (it, pos, vel, grad_pos, grad_vel))
    while loop(it < 100):
        # Take reverse step in time
        pos, vel = self.timestep(pos, vel, dt=-0.02)

        # Take a forward step in time, keep track of derivatives
        # Copies are made before enable_grad, presumably so gradient tracking
        # is attached to fresh variables rather than the loop state itself.
        pos_rev, vel_rev = m.Array2f(pos), m.Array2f(vel)
        ek.enable_grad(pos_rev, vel_rev)
        pos_fwd, vel_fwd = self.timestep(pos_rev, vel_rev, dt=0.02)
        ek.set_grad(pos_fwd, grad_pos)
        ek.set_grad(vel_fwd, grad_vel)
        ek.enqueue(pos_fwd, vel_fwd)
        ek.traverse(m.Float, reverse=True)

        grad_pos = ek.grad(pos_rev)
        grad_vel = ek.grad(vel_rev)
        it += 1

    self.set_grad_in('pos', grad_pos)
    self.set_grad_in('vel', grad_vel)
def test23_exp(m):
    """Reverse-mode derivative of exp(x^2): d/dx = 2x * exp(x^2)."""
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)

    y = ek.exp(x * x)
    ek.backward(y)

    x_val = ek.detach(x)
    ref = ek.exp(ek.sqr(x_val))
    assert ek.allclose(y, ref)
    assert ek.allclose(ek.grad(x), 2 * x_val * ref)
def test24_log(m):
    """Reverse-mode derivative of log(x^2): d/dx = 2 / x."""
    x = ek.linspace(m.Float, 0.01, 1, 10)
    ek.enable_grad(x)

    y = ek.log(x * x)
    ek.backward(y)

    x_val = ek.detach(x)
    ref = ek.log(ek.sqr(x_val))
    assert ek.allclose(y, ref)
    assert ek.allclose(ek.grad(x), 2 / x_val)
def test50_gather_fwd_eager(m):
    """Eager-mode counterpart of the forward ``gather`` test.

    The unit tangent is seeded before the computation, and no explicit
    ``ek.forward`` call is made.
    """
    with EagerMode():
        src = ek.linspace(m.Float, -1, 1, 10)
        ek.enable_grad(src)
        ek.set_grad(src, 1)

        picked = ek.gather(m.Float, src * src, m.UInt(1, 1, 2, 3))

        # 2x at the gathered indices 1, 1, 2, 3
        expected = [-1.55556, -1.55556, -1.11111, -0.666667]
        assert ek.allclose(ek.grad(picked), expected)
def test03_sub_mul(m):
    """Reverse-mode adjoints of ``a * b - c`` are (b, a, -1)."""
    a = m.Float(2)
    b = m.Float(3)
    c = m.Float(4)
    ek.enable_grad(a, b, c)

    result = a * b - c
    ek.backward(result)

    assert ek.grad(a) == ek.detach(b)
    assert ek.grad(b) == ek.detach(a)
    assert ek.grad(c) == -1
def test18_gather(m):
    """Reverse-mode derivative through ``gather`` followed by a horizontal sum.

    Index 1 is gathered twice, so its adjoint 2 * x[1] is accumulated twice.
    Indices that are never gathered receive 0.
    """
    src = ek.linspace(m.Float, -1, 1, 10)
    ek.enable_grad(src)

    picked = ek.gather(m.Float, src * src, m.UInt(1, 1, 2, 3))
    total = ek.hsum_async(picked)
    ek.backward(total)

    expected = [0, -1.55556 * 2, -1.11111, -0.666667, 0, 0, 0, 0, 0, 0]
    assert ek.allclose(ek.grad(src), expected)
def test09_hsum_2_rev(m):
    """Nested horizontal sums: z = (sum x^2)^2, so dz/dx = 4 * (sum x^2) * x.

    With 11 points in [0, 1], sum x^2 = 3.85 and the gradient is 15.4 * x.
    """
    pts = ek.linspace(m.Float, 0, 1, 11)
    ek.enable_grad(pts)

    sq_sum = ek.hsum_async(pts * pts)
    z = ek.hsum_async(sq_sum * pts * pts)
    ek.backward(z)

    expected = [0., 1.54, 3.08, 4.62, 6.16, 7.7, 9.24, 10.78, 12.32, 13.86, 15.4]
    assert ek.allclose(ek.grad(pts), expected)
def test43_custom_reverse(m):
    """Reverse-mode AD through the custom ``Normalize`` operation.

    The operation is defined elsewhere in the file. The expected adjoint
    equals (I - n n^T) g / |d| for d = (1, 2, 3), g = (5, 6, 7) and
    n = d / |d|.
    """
    vec = m.Array3f(1, 2, 3)
    ek.enable_grad(vec)

    normalized = ek.custom(Normalize, vec)
    ek.set_grad(normalized, m.Array3f(5, 6, 7))
    ek.enqueue(normalized)
    ek.traverse(m.Float, reverse=True)

    grad_vec = ek.grad(vec)
    assert ek.allclose(grad_vec, m.Array3f(0.610883, 0.152721, -0.305441))
def test_56_diffloop_simple_rev(m, no_record):
    """Reverse-mode derivative through a loop with a data-dependent trip count.

    ``fo`` accumulates ``fi`` while ``fo < 10``. The lanes with ``fi`` = 1, 2, 3
    therefore run 10, 5 and 4 iterations, and d(fo)/d(fi) equals that count.
    ``no_record`` is a fixture defined elsewhere; presumably it disables
    loop recording — confirm.
    """
    fi, fo = m.Float(1, 2, 3), m.Float(0, 0, 0)
    ek.enable_grad(fi)

    # The lambda returns the current binding of ``fo`` as the loop state.
    loop = m.Loop("MyLoop", lambda: fo)
    while loop(fo < 10):
        fo += fi
    ek.backward(fo)

    assert ek.grad(fi) == m.Float(10, 5, 4)
def test28_tan(m):
    """Reverse-mode derivative of tan(x^2): d/dx = 2x / cos^2(x^2)."""
    x = ek.linspace(m.Float, 0, 1, 10)
    ek.enable_grad(x)

    y = ek.tan(x * x)
    ek.backward(y)

    ref = ek.tan(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, ref)

    expected_grad = m.Float(0., 0.222256, 0.44553, 0.674965, 0.924494,
                            1.22406, 1.63572, 2.29919, 3.58948, 6.85104)
    assert ek.allclose(ek.grad(x), expected_grad)
def test29_asin(m):
    """Reverse-mode derivative of asin(x^2): d/dx = 2x / sqrt(1 - x^4)."""
    x = ek.linspace(m.Float, -.8, .8, 10)
    ek.enable_grad(x)

    y = ek.asin(x * x)
    ek.backward(y)

    ref = ek.asin(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, ref)

    expected_grad = m.Float(-2.08232, -1.3497, -0.906755, -0.534687, -0.177783,
                            0.177783, 0.534687, 0.906755, 1.3497, 2.08232)
    assert ek.allclose(ek.grad(x), expected_grad)
def test30_acos(m):
    """Reverse-mode derivative of acos(x^2): d/dx = -2x / sqrt(1 - x^4)."""
    x = ek.linspace(m.Float, -.8, .8, 10)
    ek.enable_grad(x)

    y = ek.acos(x * x)
    ek.backward(y)

    ref = ek.acos(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, ref)

    expected_grad = m.Float(2.08232, 1.3497, 0.906755, 0.534687, 0.177783,
                            -0.177783, -0.534687, -0.906755, -1.3497, -2.08232)
    assert ek.allclose(ek.grad(x), expected_grad)
def test31_atan(m):
    """Reverse-mode derivative of atan(x^2): d/dx = 2x / (1 + x^4)."""
    x = ek.linspace(m.Float, -.8, .8, 10)
    ek.enable_grad(x)

    y = ek.atan(x * x)
    ek.backward(y)

    ref = ek.atan(ek.sqr(ek.detach(x)))
    assert ek.allclose(y, ref)

    expected_grad = m.Float(-1.13507, -1.08223, -0.855508, -0.53065, -0.177767,
                            0.177767, 0.53065, 0.855508, 1.08223, 1.13507)
    assert ek.allclose(ek.grad(x), expected_grad)